@@ -166,18 +166,20 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
166
166
//rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
167
167
unifiedLayer
168
168
.task ("rmsnormReduction_Qcur" ,
169
- Qwen3Kernels ::rmsnormReductionWithOffset ,
169
+ Qwen3Kernels ::rmsnormWithParallelOffset ,
170
170
context ,
171
171
state .tempQcur , // output
172
172
state .wrapQ , // input
173
- state .localSize ) // currently 128, should be variable of global nEmbHead
174
- .task ("rmsnormFinalNormalization_Qcur" ,
175
- Qwen3Kernels ::rmsnormFinalNormalizationWithParallelOffset ,
176
- context ,
177
- state .tempQcur , // output
178
- config .numberOfHeads (),
179
- nEmbdHead ,
180
- config .rmsNormEps ())
173
+ state .localSize , // currently 128; should vary with the global nEmbdHead
174
+ nEmbdHead , // for normalization
175
+ config .rmsNormEps ()) // for normalization
176
+ // .task("rmsnormFinalNormalization_Qcur",
177
+ // Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
178
+ // context,
179
+ // state.tempQcur, // output
180
+ // config.numberOfHeads(),
181
+ // nEmbdHead,
182
+ // config.rmsNormEps())
181
183
.task ("rmsnormMapIndexInPlace_Qcur" ,
182
184
Qwen3Kernels ::rmsnormMapIndexInPlaceWithParallelOffset ,
183
185
context ,
@@ -192,18 +194,20 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
192
194
//rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
193
195
unifiedLayer
194
196
.task ("rmsnormReduction_Kcur" ,
195
- Qwen3Kernels ::rmsnormReductionWithOffset ,
197
+ Qwen3Kernels ::rmsnormWithParallelOffset ,
196
198
context ,
197
199
state .tempKcur , // output
198
200
state .wrapK , // input
199
- state .localSize ) // currently 128, should be variable of global nEmbHead
200
- .task ("rmsnormFinalNormalization_Kcur" ,
201
- Qwen3Kernels ::rmsnormFinalNormalizationWithParallelOffset ,
202
- context ,
203
- state .tempKcur , // output
204
- config .numberOfKeyValueHeads (),
205
- nEmbdHead ,
206
- config .rmsNormEps ())
201
+ state .localSize , // currently 128; should vary with the global nEmbdHead
202
+ nEmbdHead , // for normalization
203
+ config .rmsNormEps ()) // for normalization
204
+ // .task("rmsnormFinalNormalization_Kcur",
205
+ // Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
206
+ // context,
207
+ // state.tempKcur, // output
208
+ // config.numberOfKeyValueHeads(),
209
+ // nEmbdHead,
210
+ // config.rmsNormEps())
207
211
.task ("rmsnormMapIndexInPlace_Kcur" ,
208
212
Qwen3Kernels ::rmsnormMapIndexInPlaceWithParallelOffset ,
209
213
context ,
@@ -359,8 +363,8 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
359
363
WorkerGrid qCurWorker = new WorkerGrid1D (config .numberOfHeads () * nEmbdHead );
360
364
qCurWorker .setLocalWork (nEmbdHead , 1 , 1 );
361
365
362
- WorkerGrid qCurWorker2 = new WorkerGrid1D (config .numberOfHeads ());
363
- qCurWorker2 .setLocalWork (1 , 1 , 1 );
366
+ // WorkerGrid qCurWorker2 = new WorkerGrid1D(config.numberOfHeads());
367
+ // qCurWorker2.setLocalWork(1, 1, 1);
364
368
365
369
// Kcur
366
370
// config.numberOfKeyValueHeads() = 8
@@ -369,8 +373,8 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
369
373
WorkerGrid kCurWorker = new WorkerGrid1D (config .numberOfKeyValueHeads () * nEmbdHead );
370
374
kCurWorker .setLocalWork (nEmbdHead , 1 , 1 );
371
375
372
- WorkerGrid kCurWorker2 = new WorkerGrid1D (config .numberOfKeyValueHeads ());
373
- kCurWorker2 .setLocalWork (1 , 1 , 1 );
376
+ // WorkerGrid kCurWorker2 = new WorkerGrid1D(config.numberOfKeyValueHeads());
377
+ // kCurWorker2.setLocalWork(1, 1, 1);
374
378
375
379
int h = config .numberOfHeads ();
376
380
int ic = nEmbdHead / 2 ;
@@ -413,12 +417,12 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
413
417
414
418
// Qcur
415
419
gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormReduction_Qcur" , qCurWorker );
416
- gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormFinalNormalization_Qcur" , qCurWorker2 );
420
+ // gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Qcur", qCurWorker2);
417
421
gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormMapIndexInPlace_Qcur" , qCurWorker );
418
422
419
423
// Kcur
420
424
gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormReduction_Kcur" , kCurWorker );
421
- gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormFinalNormalization_Kcur" , kCurWorker2 );
425
+ // gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Kcur", kCurWorker2);
422
426
gridScheduler .addWorkerGrid ("layer_" + i + ".rmsnormMapIndexInPlace_Kcur" , kCurWorker );
423
427
424
428
gridScheduler .addWorkerGrid ("layer_" + i + ".ropeRotation" , ropeWorker );
0 commit comments