Skip to content

Commit e1a4632

Browse files
Provide an optimized rmsnorm kernel that fuses steps 1 and 2
1 parent 4541c14 commit e1a4632

File tree

2 files changed

+71
-24
lines changed

2 files changed

+71
-24
lines changed

src/main/java/com/example/tornadovm/Qwen3Kernels.java

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,49 @@ public static void rmsnormMapIndexInPlaceWithParallelOffset(
119119
// }
120120
}
121121

122+
/**
123+
* RmsNorm with parallel offset:
124+
*
125+
* Optimized kernel that combines Step 1 (Reduction) and Step 2 (Normalization).
*
* <p>Each thread loads one element of {@code x} (at its global index), squares it into a
* local array, and the work-group then sums those squares with a stride-halving reduction.
* The thread with local index 0 finally converts the sum into
* {@code 1 / sqrt(sum / size + ermsNorm)} and writes it to {@code output} at the
* work-group's index, so {@code output} receives one scale factor per work-group.
*
* <p>The halving loop only folds every element into slot 0 when the work-group size is a
* power of two; with any other size some squared values are never added.
*
* @param context      kernel context supplying the global/local/group indices, the group
*                     size and the local barrier
* @param output       receives one value per work-group, written at index {@code groupIdx}
* @param x            input values, read at {@code globalIdx}
* @param localMemSize length of the local scratch array; it is indexed by the local thread
*                     id, so it must be at least the work-group size
* @param size         divisor applied to the sum of squares (the visible call sites pass the
*                     per-head dimension {@code nEmbdHead})
* @param ermsNorm     value added to the mean of squares before the inverse square root
*                     (the visible call sites pass {@code config.rmsNormEps()})
126+
*/
127+
public static void rmsnormWithParallelOffset(
128+
KernelContext context,
129+
FloatArray output,
130+
FloatArray x,
131+
int localMemSize,
132+
int size,
133+
float ermsNorm) {
134+
135+
int gid = context.globalIdx;
136+
int lid = context.localIdx;
137+
int groupId = context.groupIdx;
138+
int groupSize = context.localGroupSizeX;
139+
140+
// Allocate local memory with the provided size
// (indexed by lid below, so localMemSize must cover the whole work-group)
141+
float[] localX = context.allocateFloatLocalArray(localMemSize);
142+
143+
// Load input value and compute square
144+
localX[lid] = x.get(gid);
145+
localX[lid] = localX[lid] * localX[lid];
146+
147+
// Perform parallel reduction within the work group:
// each step adds the upper half of the active range onto the lower half,
// leaving the sum of squares in localX[0].
148+
for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
149+
// The barrier sits at the top of the step so the initial squares (first pass) and the
// previous step's partial sums (later passes) are written before they are read.
// NOTE(review): assumes localBarrier() synchronizes all threads of the work-group - confirm.
context.localBarrier();
150+
if (lid < stride) {
151+
localX[lid] += localX[lid + stride];
152+
}
153+
}
154+
155+
// Each workgroup performs the normalization
// Only thread 0 runs this; it reads localX[0], which it wrote itself in the last step.
156+
if (lid == 0) {
157+
// Convert the work-group's sum of squares into the scale factor
// 1 / sqrt(sum / size + ermsNorm) and store it at this work-group's slot.
158+
localX[0] /= size;
159+
localX[0] += ermsNorm;
160+
localX[0] = 1.0f / TornadoMath.sqrt(localX[0]);
161+
output.set(groupId, localX[0]);
162+
}
163+
}
164+
122165
public static void reductionOneBlockWithLayerWithOffset(
123166
KernelContext context,
124167
FloatArray output,

src/main/java/com/example/tornadovm/Qwen3TornadoVMLayerPlanner.java

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -166,18 +166,20 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
166166
//rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
167167
unifiedLayer
168168
.task("rmsnormReduction_Qcur",
169-
Qwen3Kernels::rmsnormReductionWithOffset,
169+
Qwen3Kernels::rmsnormWithParallelOffset,
170170
context,
171171
state.tempQcur, // output
172172
state.wrapQ, // input
173-
state.localSize) // currently 128, should be variable of global nEmbHead
174-
.task("rmsnormFinalNormalization_Qcur",
175-
Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
176-
context,
177-
state.tempQcur, // output
178-
config.numberOfHeads(),
179-
nEmbdHead,
180-
config.rmsNormEps())
173+
state.localSize, // currently 128, should be variable of global nEmbHead
174+
nEmbdHead, // for normalization
175+
config.rmsNormEps()) // for normalization
176+
// .task("rmsnormFinalNormalization_Qcur",
177+
// Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
178+
// context,
179+
// state.tempQcur, // output
180+
// config.numberOfHeads(),
181+
// nEmbdHead,
182+
// config.rmsNormEps())
181183
.task("rmsnormMapIndexInPlace_Qcur",
182184
Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
183185
context,
@@ -192,18 +194,20 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
192194
//rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
193195
unifiedLayer
194196
.task("rmsnormReduction_Kcur",
195-
Qwen3Kernels::rmsnormReductionWithOffset,
197+
Qwen3Kernels::rmsnormWithParallelOffset,
196198
context,
197199
state.tempKcur, // output
198200
state.wrapK, // input
199-
state.localSize) // currently 128, should be variable of global nEmbHead
200-
.task("rmsnormFinalNormalization_Kcur",
201-
Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
202-
context,
203-
state.tempKcur, // output
204-
config.numberOfKeyValueHeads(),
205-
nEmbdHead,
206-
config.rmsNormEps())
201+
state.localSize, // currently 128, should be variable of global nEmbHead
202+
nEmbdHead, // for normalization
203+
config.rmsNormEps()) // for normalization
204+
// .task("rmsnormFinalNormalization_Kcur",
205+
// Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
206+
// context,
207+
// state.tempKcur, // output
208+
// config.numberOfKeyValueHeads(),
209+
// nEmbdHead,
210+
// config.rmsNormEps())
207211
.task("rmsnormMapIndexInPlace_Kcur",
208212
Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
209213
context,
@@ -359,8 +363,8 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
359363
WorkerGrid qCurWorker = new WorkerGrid1D(config.numberOfHeads() * nEmbdHead);
360364
qCurWorker.setLocalWork(nEmbdHead, 1, 1);
361365

362-
WorkerGrid qCurWorker2 = new WorkerGrid1D(config.numberOfHeads());
363-
qCurWorker2.setLocalWork(1, 1, 1);
366+
// WorkerGrid qCurWorker2 = new WorkerGrid1D(config.numberOfHeads());
367+
// qCurWorker2.setLocalWork(1, 1, 1);
364368

365369
// Kcur
366370
// config.numberOfKeyValueHeads() = 8
@@ -369,8 +373,8 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
369373
WorkerGrid kCurWorker = new WorkerGrid1D(config.numberOfKeyValueHeads() * nEmbdHead);
370374
kCurWorker.setLocalWork(nEmbdHead, 1, 1);
371375

372-
WorkerGrid kCurWorker2 = new WorkerGrid1D(config.numberOfKeyValueHeads());
373-
kCurWorker2.setLocalWork(1, 1, 1);
376+
// WorkerGrid kCurWorker2 = new WorkerGrid1D(config.numberOfKeyValueHeads());
377+
// kCurWorker2.setLocalWork(1, 1, 1);
374378

375379
int h = config.numberOfHeads();
376380
int ic = nEmbdHead / 2;
@@ -413,12 +417,12 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
413417

414418
// Qcur
415419
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker);
416-
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Qcur", qCurWorker2);
420+
//gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Qcur", qCurWorker2);
417421
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker);
418422

419423
// Kcur
420424
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker);
421-
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Kcur", kCurWorker2);
425+
//gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Kcur", kCurWorker2);
422426
gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker);
423427

424428
gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker);

0 commit comments

Comments
 (0)