
Commit 5b5e9e3

Cleanup Qwen3TornadoVMLayerPlanner
1 parent 6224ae0 commit 5b5e9e3

File tree

1 file changed: 4 additions, 71 deletions

src/main/java/com/example/tornadovm/Qwen3TornadoVMLayerPlanner.java

Lines changed: 4 additions & 71 deletions
@@ -2,7 +2,6 @@
 
 import com.example.auxiliary.Tuple2;
 import com.example.inference.state.Qwen3State;
-import com.example.inference.state.State;
 import com.example.inference.weights.tornado.Qwen3TornadoWeights;
 import com.example.model.Model;
 import com.example.model.qwen3.Qwen3Configuration;
@@ -109,8 +108,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 config.dim(),
                 config.rmsNormEps(),
                 state.localSize)
-                //.task("reductionFinalNormalization" , TransformerComputeKernelsLayered::reductionFinalNormalization, context,
-                //state.temp, config.dim(), config.rmsNormEps())
                 .task("mapContext",
                         TransformerComputeKernelsLayered::reductionOneBlock2WithLayer,
                         context,
@@ -119,16 +116,9 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         weights.rms_att_weightLayered[layerIndex],
                         state.temp);
 
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
-
-        // // dbg copy out
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.temp);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
-
         int qDim0 = nEmbdHeadK * config.numberOfHeads();
         int kvDim0 = nEmbdGqa;
         int qkvDim1 = config.dim();
-        //qkvMatmuls = new TaskGraph("qkvMatmuls_layer_" + layerIndex);
         unifiedLayer.task("qmatmul",
                 TransformerComputeKernelsLayered::matrixVectorGeneric,
                 context,
@@ -157,11 +147,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 kvDim0,
                 LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        // dbg copy out
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapV);
-
         // Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
         //rmsnorm(state.q, state.q, weights.attnQNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
         unifiedLayer
@@ -173,23 +158,14 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         state.localSize, // currently 128, should be variable of global nEmbHead
                         nEmbdHead, // for normalization
                         config.rmsNormEps()) // for normalization
-                // .task("rmsnormFinalNormalization_Qcur",
-                //         Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
-                //         context,
-                //         state.tempQcur, // output
-                //         config.numberOfHeads(),
-                //         nEmbdHead,
-                //         config.rmsNormEps())
                 .task("rmsnormMapIndexInPlace_Qcur",
                         Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
                         context,
                         state.wrapQ, // output
                         weights.rms_att_QNormLayered[layerIndex],
                         nEmbdHead,
                         state.tempQcur);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        // unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-        //
+
 
         // Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
         //rmsnorm(state.k, state.k, weights.attnKNorm[curLayer], i * nEmbdHead, nEmbdHead, config.rmsNormEps());
         unifiedLayer
@@ -201,24 +177,13 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         state.localSize, // currently 128, should be variable of global nEmbHead
                         nEmbdHead, // for normalization
                         config.rmsNormEps()) // for normalization
-                // .task("rmsnormFinalNormalization_Kcur",
-                //         Qwen3Kernels::rmsnormFinalNormalizationWithParallelOffset,
-                //         context,
-                //         state.tempKcur, // output
-                //         config.numberOfKeyValueHeads(),
-                //         nEmbdHead,
-                //         config.rmsNormEps())
                 .task("rmsnormMapIndexInPlace_Kcur",
                         Qwen3Kernels::rmsnormMapIndexInPlaceWithParallelOffset,
                         context,
                         state.wrapK, // output
                         weights.rms_att_KNormLayered[layerIndex],
                         nEmbdHead,
                         state.tempKcur);
-        // dbg copy out
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapV);
 
         // rope rotation task graph
         unifiedLayer.task("ropeRotation",
@@ -230,10 +195,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 config.numberOfKeyValueHeads(),
                 nEmbdHead);
 
-        // dbg copy out
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
-
         unifiedLayer.task("copyToCaches",
                 TransformerComputeKernelsLayered::copyToCache,
                 state.wrapKeyCache, // out
@@ -245,7 +206,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 layerIndex,
                 config.contextLength());
 
-        // global size = numberOfHeads * 8 = 16 * 8 = 128
         unifiedLayer.task("parallel-attention",
                 TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt,
                 context,
@@ -261,7 +221,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 layerIndex,
                 config.contextLength());
 
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
         unifiedLayer.task("matmul1", Qwen3Kernels::matrixVectorGenericWithResidual,
                 context,
                 state.wrapXb, // vector
@@ -271,7 +230,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 config.dim(), // dim0 = 1024
                 LOCAL_WORK_GROUP_SIZE_ALLOC);
 
-        //unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX);
         unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
                 context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
                 .task("reductionFinalNormalizationFFN" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempFFN,
@@ -283,7 +241,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
                 .task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
                 state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
-                //.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX)
                 .persistOnDevice(
                         state.wrapX
                 );
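Note: persistOnDevice at the end of this chain is what keeps state.wrapX resident on the GPU, so the logits graph in the next hunk can pick it up with consumeFromDevice instead of copying through the host. A minimal sketch of that producer/consumer pairing between two TornadoVM task graphs; the kernels, buffer size, and graph names below are illustrative, not from this commit:

// Sketch only: the persistOnDevice / consumeFromDevice pairing used by this planner.
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

class PersistConsumeSketch {
    static void fill(FloatArray x) {
        for (int i = 0; i < x.getSize(); i++) {
            x.set(i, 1.0f); // stand-in for a real layer kernel
        }
    }

    static void scale(FloatArray x) {
        for (int i = 0; i < x.getSize(); i++) {
            x.set(i, x.get(i) * 2.0f); // stand-in for the logits kernel
        }
    }

    static void build() {
        FloatArray x = new FloatArray(1024);

        // Producer graph: computes into x and leaves it resident on the device.
        TaskGraph producer = new TaskGraph("layer_0")
                .transferToDevice(DataTransferMode.FIRST_EXECUTION, x)
                .task("compute", PersistConsumeSketch::fill, x)
                .persistOnDevice(x);

        // Consumer graph: reads x straight from the producer's device allocation,
        // avoiding a host round-trip between the two graphs.
        TaskGraph consumer = new TaskGraph("logits")
                .consumeFromDevice(producer.getTaskGraphName(), x)
                .task("use", PersistConsumeSketch::scale, x)
                .transferToHost(DataTransferMode.EVERY_EXECUTION, x);
    }
}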
@@ -295,14 +252,12 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 .consumeFromDevice(lastUnifiedLayer.getTaskGraphName(),
                         state.wrapX
                 )
-                //.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
                 .transferToDevice(DataTransferMode.EVERY_EXECUTION,
                         state.tempLogits,
                         state.wrapLogits
                 )
                 .transferToDevice(DataTransferMode.FIRST_EXECUTION,
                         context,
-                        //state.wrapLogits,
                         weights.wclsHalfFloat,
                         weights.rms_final_weight_as_floatArray
                 )
@@ -313,13 +268,8 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                         config.dim(),
                         config.rmsNormEps(),
                         state.localSize)
-                // .transferToHost(DataTransferMode.EVERY_EXECUTION, state.tempLogits)
-                // .transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX)
-                // .task("reductionFinalNormalizationLogits" , TransformerComputeKernelsLayered::reductionFinalNormalization, context, state.tempLogits,
-                //         config.dim(), config.rmsNormEps())
                 .task("mapContextLogits", TransformerComputeKernels::reductionOneBlock2WithLogits, context, state.wrapX,
                         weights.rms_final_weight_as_floatArray, state.tempLogits);
-        //.transferToHost(DataTransferMode.EVERY_EXECUTION, state.tempLogits);
         logits = configureQuantizedMatrixVectorFinalWeight(logits);
         logits.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapLogits);
         taskGraphs.add(logits.snapshot());
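Note: this hunk also shows the planner's overall lifecycle: each mutable TaskGraph is frozen with snapshot() into an ImmutableTaskGraph and collected into taskGraphs for later execution. A minimal sketch of that build/snapshot/collect pattern, with a hypothetical kernel and buffer standing in for the real ones:

// Sketch only: the build -> snapshot() -> collect lifecycle visible in this hunk.
import java.util.ArrayList;
import java.util.List;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TaskGraph;
import uk.ac.manchester.tornado.api.enums.DataTransferMode;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

class SnapshotSketch {
    static void mul2(FloatArray a) {
        for (int i = 0; i < a.getSize(); i++) {
            a.set(i, a.get(i) * 2.0f); // stand-in for a real layer kernel
        }
    }

    static List<ImmutableTaskGraph> plan(int numberOfLayers) {
        List<ImmutableTaskGraph> taskGraphs = new ArrayList<>();
        FloatArray buf = new FloatArray(1024);
        for (int i = 0; i < numberOfLayers; i++) {
            TaskGraph layer = new TaskGraph("layer_" + i)
                    .transferToDevice(DataTransferMode.FIRST_EXECUTION, buf)
                    .task("mul2", SnapshotSketch::mul2, buf)
                    .transferToHost(DataTransferMode.EVERY_EXECUTION, buf);
            // snapshot() freezes the graph; the immutable form is what gets executed.
            taskGraphs.add(layer.snapshot());
        }
        return taskGraphs;
    }
}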
@@ -357,25 +307,13 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
         curWorker.setLocalWork(128, 1, 1); // Set local work size to 256 (standard efficient size)
 
         // Qcur
-        // config.numberOfHeads() = 16
-        // nEmbdHead = 128
-        // total = 2048
         WorkerGrid qCurWorker = new WorkerGrid1D(config.numberOfHeads() * nEmbdHead);
         qCurWorker.setLocalWork(nEmbdHead, 1, 1);
 
-        // WorkerGrid qCurWorker2 = new WorkerGrid1D(config.numberOfHeads());
-        // qCurWorker2.setLocalWork(1, 1, 1);
-
         // Kcur
-        // config.numberOfKeyValueHeads() = 8
-        // nEmbdHead = 128
-        // total = 1024
         WorkerGrid kCurWorker = new WorkerGrid1D(config.numberOfKeyValueHeads() * nEmbdHead);
         kCurWorker.setLocalWork(nEmbdHead, 1, 1);
 
-        // WorkerGrid kCurWorker2 = new WorkerGrid1D(config.numberOfKeyValueHeads());
-        // kCurWorker2.setLocalWork(1, 1, 1);
-
         int h = config.numberOfHeads();
         int ic = nEmbdHead / 2;
         WorkerGrid ropeWorker = new WorkerGrid2D(h, ic);
@@ -384,13 +322,12 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
 
         WorkerGrid copyToCachesWorker = new WorkerGrid1D(nEmbdGqa);
         copyToCachesWorker.setGlobalWork(nEmbdGqa, 1, 1);
-        copyToCachesWorker.setLocalWork(128, 1, 1); // Set local work size to 32 (for copying to caches)
+        copyToCachesWorker.setLocalWork(128, 1, 1);
 
         // Parallel attention worker configuration
-        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads()); // qwen ok
-        // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4
+        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
         parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * 32, 1, 1);
-        parallelAttentionWorker.setLocalWork(32, 1, 1); // Set local work size to 4 (for parallel attention)
+        parallelAttentionWorker.setLocalWork(32, 1, 1);
 
         int matmul1Global = config.dim() * LOCAL_WORK_GROUP_SIZE_ALLOC;
         WorkerGrid matmul1Worker = new WorkerGrid1D(matmul1Global);
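Note: the comments deleted in these two hunks quoted the concrete sizes (16 query heads, 8 key/value heads, nEmbdHead = 128), which make the grid arithmetic easy to check: qCurWorker covers 16 * 128 = 2048 work-items in 128-wide groups, kCurWorker covers 8 * 128 = 1024, and parallelAttentionWorker launches one 32-wide group per head. A sketch of the same sizing with those figures hard-coded; the real values come from Qwen3Configuration:

// Sketch only: the grid-sizing arithmetic behind these worker grids.
// Head counts follow the figures quoted in the deleted comments.
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;
import uk.ac.manchester.tornado.api.WorkerGrid2D;

class GridSizingSketch {
    static void configure() {
        int numberOfHeads = 16;
        int numberOfKeyValueHeads = 8;
        int nEmbdHead = 128;

        // Qcur: 16 * 128 = 2048 work-items, one 128-wide group per head.
        WorkerGrid qCurWorker = new WorkerGrid1D(numberOfHeads * nEmbdHead);
        qCurWorker.setLocalWork(nEmbdHead, 1, 1);

        // Kcur: 8 * 128 = 1024 work-items, same per-head grouping.
        WorkerGrid kCurWorker = new WorkerGrid1D(numberOfKeyValueHeads * nEmbdHead);
        kCurWorker.setLocalWork(nEmbdHead, 1, 1);

        // RoPE rotates nEmbdHead / 2 = 64 element pairs per head: a 16 x 64 2D grid.
        WorkerGrid ropeWorker = new WorkerGrid2D(numberOfHeads, nEmbdHead / 2);

        // Attention: one 32-wide work-group per head, 16 * 32 = 512 items total.
        WorkerGrid parallelAttentionWorker = new WorkerGrid1D(numberOfHeads);
        parallelAttentionWorker.setGlobalWork(numberOfHeads * 32, 1, 1);
        parallelAttentionWorker.setLocalWork(32, 1, 1);
    }
}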
@@ -408,7 +345,6 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
         gridScheduler.addWorkerGrid("activationUpdate.updateX", singleWorker);
         for (int i = 0; i < config.numberOfLayers(); i++) {
             gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlock", rmsNormWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalization", rmsNormWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".mapContext", rmsNormWorker);
 
             gridScheduler.addWorkerGrid("layer_" + i + ".qmatmul", matmulQRowMajorWorker);
@@ -417,20 +353,17 @@ private GridScheduler setupQwen3GridSchedulersLayeredNonNvidia() {
 
             // Qcur
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Qcur", qCurWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Qcur", qCurWorker2);
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Qcur", qCurWorker);
 
             // Kcur
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormReduction_Kcur", kCurWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormFinalNormalization_Kcur", kCurWorker2);
             gridScheduler.addWorkerGrid("layer_" + i + ".rmsnormMapIndexInPlace_Kcur", kCurWorker);
 
             gridScheduler.addWorkerGrid("layer_" + i + ".ropeRotation", ropeWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".copyToCaches", copyToCachesWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".parallel-attention", parallelAttentionWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".matmul1", matmul1Worker);
             gridScheduler.addWorkerGrid("layer_" + i + ".reductionsOneBlockFFN", rmsNormWorker);
-            //gridScheduler.addWorkerGrid("layer_" + i + ".reductionFinalNormalizationFFN", rmsNormWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".mapContextFFN", rmsNormWorker);
             gridScheduler.addWorkerGrid("layer_" + i + ".fused_ffn_w1_w3", fusedFFNW1W3Worker);
             gridScheduler.addWorkerGrid("layer_" + i + ".projectionTwo", projectionTwoWorker);
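Note: worker grids are resolved by the key "graphName.taskName", which is why every registration above concatenates "layer_" + i with the task id used when the planner builds each graph. A minimal sketch of how such a GridScheduler might be attached when the snapshots are executed; the grid size and plan wiring here are illustrative assumptions, not part of this commit:

// Sketch only: attaching a GridScheduler to an execution plan.
// addWorkerGrid keys must match the "graphName.taskName" ids exactly.
import uk.ac.manchester.tornado.api.GridScheduler;
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
import uk.ac.manchester.tornado.api.WorkerGrid;
import uk.ac.manchester.tornado.api.WorkerGrid1D;

class SchedulerSketch {
    static void run(ImmutableTaskGraph... immutableGraphs) {
        GridScheduler gridScheduler = new GridScheduler();

        // One 128-wide work-group layout for a reduction task of layer 0.
        WorkerGrid rmsNormWorker = new WorkerGrid1D(1024);
        rmsNormWorker.setLocalWork(128, 1, 1);
        gridScheduler.addWorkerGrid("layer_0.reductionsOneBlock", rmsNormWorker);

        // The plan runs every snapshot with the per-task grids applied.
        TornadoExecutionPlan plan = new TornadoExecutionPlan(immutableGraphs);
        plan.withGridScheduler(gridScheduler).execute();
    }
}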
