@@ -2,7 +2,6 @@
 
 import com.example.auxiliary.Tuple2;
 import com.example.inference.state.Qwen3State;
-import com.example.inference.state.State;
 import com.example.inference.weights.tornado.Qwen3TornadoWeights;
 import com.example.model.Model;
 import com.example.model.qwen3.Qwen3Configuration;
@@ -109,8 +108,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         config.dim(),
                         config.rmsNormEps(),
                         state.localSize)
-                //.task("reductionFinalNormalization", TransformerComputeKernelsLayered::reductionFinalNormalization, context,
-                //state.temp, config.dim(), config.rmsNormEps())
                 .task("mapContext",
                         TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
                         context,
@@ -119,16 +116,9 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         weights.rms_att_weightLayered[layerIndex],
                         state.temp);
 
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
-
-        // // dbg copy out
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.temp);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
-
         int qDim0 = nEmbdHeadK * config.numberOfHeads();
         int kvDim0 = nEmbdGqa;
         int qkvDim1 = config.dim();
-        //qkvMatmuls = new TaskGraph("qkvMatmuls_layer_" + layerIndex);
         unifiedLayer.task("qmatmul",
                 TransformerComputeKernelsLayered::matrixVectorGeneric,
                 context,
@@ -157,11 +147,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 kvDim0,
                 LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        // dbg copy out
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapV);
-
        // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
        //rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
        unifiedLayer
@@ -173,23 +158,14 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         state.localSize, // currently 128, should be variable of global nEmbHead
                         nEmbdHead, // for normalization
                         config.rmsNormEps()) // for normalization
-                // .task("rmsnormFinalNormalization_Qcur",
-                //         Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
-                //         context,
-                //         state.tempQcur, // output
-                //         config.numberOfHeads(),
-                //         nEmbdHead,
-                //         config.rmsNormEps())
                 .task("rmsnormMapIndexInPlace_Qcur",
                         Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
                         context,
                         state.wrapQ, // output
                         weights.rms_att_QNormLayered[layerIndex],
                         nEmbdHead,
                         state.tempQcur);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-        //
+
        // Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
        //rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
        unifiedLayer
@@ -201,24 +177,13 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         state.localSize, // currently 128, should be variable of global nEmbHead
                         nEmbdHead, // for normalization
                         config.rmsNormEps()) // for normalization
-                // .task("rmsnormFinalNormalization_Kcur",
-                //         Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
-                //         context,
-                //         state.tempKcur, // output
-                //         config.numberOfKeyValueHeads(),
-                //         nEmbdHead,
-                //         config.rmsNormEps())
                 .task("rmsnormMapIndexInPlace_Kcur",
                         Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
                         context,
                         state.wrapK, // output
                         weights.rms_att_KNormLayered[layerIndex],
                         nEmbdHead,
                         state.tempKcur);
-        // dbg copy out
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapV);
 
        // rope rotation task graph
        unifiedLayer.task("ropeRotation",
@@ -230,10 +195,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 config.numberOfKeyValueHeads(),
                 nEmbdHead);
 
-        // dbg copy out
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-
        unifiedLayer.task("copyToCaches",
                TransformerComputeKernelsLayered::copyToCache,
                state.wrapKeyCache, // out
@@ -245,7 +206,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 layerIndex,
                 config.contextLength());
 
-        // global size = numberOfHeads * 8 = 16 * 8 = 128
        unifiedLayer.task("parallel-attention",
                TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt,
                context,
@@ -261,7 +221,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 layerIndex,
                 config.contextLength());
 
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
        unifiedLayer.task("matmul1", Qwen3Kernels::matrixVectorGenericWithResidual,
                context,
                state.wrapXb, // vector
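
For orientation, the `parallel-attention` task above runs `processHeadsFlashAttentionOpt` over the K/V caches, one head per work-group. A scalar Java reference of single-query causal attention for one head — an illustrative assumption: the on-device kernel uses a flash-attention style streaming softmax, and the cache layout here is simplified to one head's contiguous slice:

final class AttentionSketch {
    // keys/values hold positions 0..pos for one head, nEmbdHead floats each.
    static void attendOneHead(float[] q, float[] keys, float[] values, float[] out,
                              int head, int nEmbdHead, int pos) {
        int qOff = head * nEmbdHead;
        float scale = 1f / (float) Math.sqrt(nEmbdHead);
        float max = Float.NEGATIVE_INFINITY;
        float[] scores = new float[pos + 1];
        for (int t = 0; t <= pos; t++) {            // scaled dot products, causal range
            float dot = 0f;
            for (int i = 0; i < nEmbdHead; i++) {
                dot += q[qOff + i] * keys[t * nEmbdHead + i];
            }
            scores[t] = dot * scale;
            max = Math.max(max, scores[t]);
        }
        float sum = 0f;
        for (int t = 0; t <= pos; t++) {            // softmax, max-subtracted for stability
            scores[t] = (float) Math.exp(scores[t] - max);
            sum += scores[t];
        }
        for (int i = 0; i < nEmbdHead; i++) {       // weighted sum of values
            float acc = 0f;
            for (int t = 0; t <= pos; t++) {
                acc += scores[t] / sum * values[t * nEmbdHead + i];
            }
            out[qOff + i] = acc;
        }
    }
}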
@@ -271,7 +230,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 config.dim(), // dim0 = 1024
                 LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX);
        unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                .task("reductionFinalNormalizationFFN", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN,
@@ -283,7 +241,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
                         state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                //.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX)
                 .persistOnDevice(
                         state.wrapX
                 );
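
The FFN tail above (`fused_ffn_w1_w3` followed by `projectionTwo`) matches the usual SiLU-gated (SwiGLU) feed-forward block; that form is an assumption based on the w1/w3/w2 naming, since the kernel bodies are not part of this diff. A scalar Java sketch of what the pair computes per token:

final class FfnSketch {
    // x: residual stream (dim), xb: normalized input (dim), hb: hidden (hiddenDim).
    // w1, w3: hiddenDim x dim; w2: dim x hiddenDim. Illustrative assumption only.
    static void ffnSwiGLU(float[] x, float[] xb, float[] hb,
                          float[][] w1, float[][] w3, float[][] w2,
                          int dim, int hiddenDim) {
        for (int i = 0; i < hiddenDim; i++) {
            float a = 0f, b = 0f;
            for (int j = 0; j < dim; j++) {
                a += w1[i][j] * xb[j];                        // w1 projection
                b += w3[i][j] * xb[j];                        // w3 projection
            }
            float silu = a / (1f + (float) Math.exp(-a));     // SiLU(a) = a * sigmoid(a)
            hb[i] = silu * b;                                 // gate ("fused_ffn_w1_w3")
        }
        for (int i = 0; i < dim; i++) {                       // "projectionTwo" with residual add
            float acc = 0f;
            for (int j = 0; j < hiddenDim; j++) {
                acc += w2[i][j] * hb[j];
            }
            x[i] += acc;
        }
    }
}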
@@ -295,14 +252,12 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(),
                         state.wrapX
                 )
-                //.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
                 .transferToDevice(DataTransferMode.EVERY_EXECUTION,
                         state.tempLogits,
                         state.wrapLogits
                 )
                 .transferToDevice(DataTransferMode.FIRST_EXECUTION,
                         context,
-                        //state.wrapLogits,
                         weights.wclsHalfFloat,
                         weights.rms_final_weight_as_floatArray
                 )
@@ -313,13 +268,8 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         config.dim(),
                         config.rmsNormEps(),
                         state.localSize)
-                // .transferToHost(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
-                // .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX)
-                // .task("reductionFinalNormalizationLogits", TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits,
-                //         config.dim(), config.rmsNormEps())
                 .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
                         weights.rms_final_weight_as_floatArray, state.tempLogits);
-        //.transferToHost(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
        logits = configureQuantizedMatrixVectorFinalWeight(logits);
        logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
        taskGraphs.add(logits.snapshot());
@@ -357,25 +307,13 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
         curWorker.setLocalWork(128, 1, 1); // Set local work size to 256 (standard efficient size)
 
         // Qcur
-        // config.numberOfHeads() = 16
-        // nEmbdHead = 128
-        // total = 2048
         WorkerGrid qCurWorker = new WorkerGrid1D(config.numberOfHeads() * nEmbdHead);
         qCurWorker.setLocalWork(nEmbdHead, 1, 1);
 
-        // WorkerGrid qCurWorker2 = new WorkerGrid1D(config.numberOfHeads());
-        // qCurWorker2.setLocalWork(1, 1, 1);
-
         // Kcur
-        // config.numberOfKeyValueHeads() = 8
-        // nEmbdHead = 128
-        // total = 1024
         WorkerGrid kCurWorker = new WorkerGrid1D(config.numberOfKeyValueHeads() * nEmbdHead);
         kCurWorker.setLocalWork(nEmbdHead, 1, 1);
 
-        // WorkerGrid kCurWorker2 = new WorkerGrid1D(config.numberOfKeyValueHeads());
-        // kCurWorker2.setLocalWork(1, 1, 1);
-
         int h = config.numberOfHeads();
         int ic = nEmbdHead / 2;
         WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
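
A note on the grid sizing above: the 1D global size is the total thread count and the local size is the work-group size, so `qCurWorker` launches one work-group per attention head with one thread per embedding element. A hypothetical concrete instantiation using the values named in the comments this change deletes (16 heads, `nEmbdHead` = 128):

import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;

final class GridSizingSketch {
    // 16 * 128 = 2048 global threads in work-groups of 128, so each
    // work-group normalizes exactly one head. The literal values are
    // illustrative, standing in for config.numberOfHeads() and nEmbdHead.
    static WorkerGrid qCurWorkerExample() {
        WorkerGrid qCurWorker = new WorkerGrid1D(16 * 128);
        qCurWorker.setLocalWork(128, 1, 1);
        return qCurWorker;
    }
}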
@@ -384,13 +322,12 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
 
         WorkerGrid copyToCachesWorker = new WorkerGrid1D(nEmbdGqa);
         copyToCachesWorker.setGlobalWork(nEmbdGqa, 1, 1);
-        copyToCachesWorker.setLocalWork(128, 1, 1); // Set local work size to 32 (for copying to caches)
+        copyToCachesWorker.setLocalWork(128, 1, 1);
 
         // Parallel attention worker configuration
-        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); // qwen ok
-        // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4
+        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
         parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 32, 1, 1);
-        parallelAttentionWorker.setLocalWork(32, 1, 1); // Set local work size to 4 (for parallel attention)
+        parallelAttentionWorker.setLocalWork(32, 1, 1);
 
         int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
         WorkerGrid matmul1Worker = new WorkerGrid1D(matmul1Global);
@@ -408,7 +345,6 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
         gridScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
         for (int i = 0; i < config.numberOfLayers(); i++) {
             gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization", rmsNormWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
 
             gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker);
@@ -417,20 +353,17 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
 
             // Qcur
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Qcur", qCurWorker2);
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker);
 
             // Kcur
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Kcur", kCurWorker2);
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker);
 
             gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker);
             gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalizationFFN", rmsNormWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker);
             gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker);