Skip to content

Commit 4a29063

Browse files
Clean up dbg buffers functionality
1 parent 717257a commit 4a29063

File tree

2 files changed

+1
-109
lines changed

2 files changed

+1
-109
lines changed

src/main/java/com/example/inference/state/Qwen3State.java

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,6 @@ public final class Qwen3State extends State {
1818
public FloatArray tempQcur;
1919
public FloatArray tempKcur;
2020

21-
// dbg buffer
22-
public FloatArray dbgQ;
23-
public FloatArray dbgKeyCache;
24-
public FloatArray dbgValueCache;
25-
public FloatArray dbgX;
26-
public FloatArray dbgXb;
27-
2821
public Qwen3State(Configuration config, int batchsize) {
2922
super(config, batchsize);
3023
// Initialize Qwen3-specific field
@@ -33,20 +26,6 @@ public Qwen3State(Configuration config, int batchsize) {
3326
this.kq = ArrayFloatTensor.allocate(config.numberOfHeads(), 32, 15);
3427
this.tempQcur = new FloatArray(nEmbdHead);
3528
this.tempKcur = new FloatArray(nEmbdHead);
36-
37-
// dbg buffers
38-
int nHeadKv = qwen3config.numberOfKeyValueHeads();
39-
int nEmbdHeadK = qwen3config.numberOfHeadsKey();
40-
int nEmbdKGqa = nEmbdHeadK * nHeadKv;
41-
int nEmbdHeadV = qwen3config.numberOfHeadsValue();
42-
int nEmbdVGqa = nEmbdHeadV * nHeadKv;
43-
int nEmbdGqa = nEmbdVGqa;
44-
45-
this.dbgQ = new FloatArray(nEmbdHeadK * qwen3config.numberOfHeads());
46-
this.dbgKeyCache = new FloatArray(qwen3config.contextLength() * nEmbdGqa * qwen3config.numberOfLayers());
47-
this.dbgValueCache = new FloatArray(qwen3config.contextLength() * nEmbdGqa * qwen3config.numberOfLayers());
48-
this.dbgX = new FloatArray(config.dim());
49-
this.dbgXb = new FloatArray(nEmbdHeadK * qwen3config.numberOfHeads());
5029
}
5130

5231
@Override

src/main/java/com/example/tornadovm/Qwen3TornadoVMLayerPlanner.java

Lines changed: 1 addition & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,14 @@ protected TaskGraph configureLayerDataTransfers(TaskGraph unifiedLayer, int laye
5151
context, state.wrapXb, state.wrapXb2, //
5252
state.wrapQ, state.wrapK, state.wrapV, //
5353
state.wrapKeyCache, state.wrapValueCache, //
54-
state.wrapAtt, state.wrapHb);//,
55-
// dbg buffers
56-
//state.dbgQ, state.dbgKeyCache, state.dbgValueCache, state.dbgXb, state.dbgX); //
54+
state.wrapAtt, state.wrapHb);//
5755
} else {
5856
// Subsequent layers: Consume data already on device from previous layer
5957
unifiedLayer.consumeFromDevice(context, state.wrapXb, state.wrapXb2, //
6058
state.wrapQ, state.wrapK, state.wrapV, //
6159
state.wrapKeyCache, state.wrapValueCache, //
6260
state.wrapAtt, state.wrapHb, //
6361
state.positionHolder //
64-
//state.dbgQ, state.dbgKeyCache, state.dbgValueCache, state.dbgXb, state.dbgX
6562
);
6663
}
6764
return unifiedLayer;
@@ -76,10 +73,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
7673
state.tempLogits.init(0.0f);
7774
state.wrapLogits.init(0.0f);
7875

79-
// state.dbgQ.init(0.0f);
80-
// state.dbgKeyCache.init(0.0f);
81-
// state.dbgValueCache.init(0.0f);
82-
8376
// @formatter:off
8477
TaskGraph activationUpdate = new TaskGraph("activationUpdate")
8578
.transferToDevice(DataTransferMode.EVERY_EXECUTION, state.wrapX)
@@ -108,12 +101,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
108101
weights.w3Layered[layerIndex]
109102
);
110103
unifiedLayer = configureLayerDataTransfers(unifiedLayer, layerIndex);
111-
// unifiedLayer.task("dbg_copy_out_x",
112-
// Qwen3Kernels::dbgCopy,
113-
// state.wrapX,
114-
// state.dbgX,
115-
// state.positionHolder,
116-
// layerIndex);
117104
unifiedLayer.task("reductionsOneBlock",
118105
TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
119106
context,
@@ -170,13 +157,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
170157
kvDim0,
171158
LOCAL_WORK_GROUP_SIZE_ALLOC);
172159

173-
// unifiedLayer.task("dbg_copy_out_wrapQ",
174-
// Qwen3Kernels::dbgCopy,
175-
// state.wrapQ,
176-
// state.dbgQ,
177-
// state.positionHolder,
178-
// layerIndex);
179-
180160
// dbg copy out
181161
// unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
182162
// unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
@@ -205,13 +185,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
205185
weights.rms_att_QNormLayered[layerIndex],
206186
nEmbdHead,
207187
state.tempQcur);
208-
209-
// unifiedLayer.task("dbg_copy_out_wrapQ",
210-
// Qwen3Kernels::dbgCopy,
211-
// state.wrapQ,
212-
// state.dbgQ,
213-
// state.positionHolder,
214-
// layerIndex);
215188
// unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
216189
// unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
217190
//
@@ -253,13 +226,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
253226
config.numberOfKeyValueHeads(),
254227
nEmbdHead);
255228

256-
// unifiedLayer.task("dbg_copy_out_wrapQ",
257-
// Qwen3Kernels::dbgCopy,
258-
// state.wrapQ,
259-
// state.dbgQ,
260-
// state.positionHolder,
261-
// layerIndex);
262-
263229
// dbg copy out
264230
//unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapQ);
265231
//unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapK);
@@ -275,27 +241,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
275241
layerIndex,
276242
config.contextLength());
277243

278-
// unifiedLayer.task("dbg_copy_out_q",
279-
// Qwen3Kernels::dbgCopy,
280-
// state.wrapQ,
281-
// state.dbgQ,
282-
// state.positionHolder,
283-
// layerIndex);
284-
//
285-
// unifiedLayer.task("dbg_copy_out_keyCache",
286-
// Qwen3Kernels::dbgCopy,
287-
// state.wrapKeyCache,
288-
// state.dbgKeyCache,
289-
// state.positionHolder,
290-
// layerIndex);
291-
//
292-
// unifiedLayer.task("dbg_copy_out_ValueCache",
293-
// Qwen3Kernels::dbgCopy,
294-
// state.wrapValueCache,
295-
// state.dbgValueCache,
296-
// state.positionHolder,
297-
// layerIndex);
298-
299244
// global size = numberOfHeads * 8 = 16 * 8 = 128
300245
unifiedLayer.task("parallel-attention",
301246
TransformerComputeKernelsLayered::processHeadsFlashAttentionOpt,
@@ -312,20 +257,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
312257
layerIndex,
313258
config.contextLength());
314259

315-
// unifiedLayer.task("dbg_copy_out_x",
316-
// Qwen3Kernels::dbgCopy,
317-
// state.wrapX,
318-
// state.dbgX,
319-
// state.positionHolder,
320-
// layerIndex);
321-
//
322-
// unifiedLayer.task("dbg_copy_out_xb",
323-
// Qwen3Kernels::dbgCopy,
324-
// state.wrapXb,
325-
// state.dbgXb,
326-
// state.positionHolder,
327-
// layerIndex);
328-
329260
//unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapXb);
330261
unifiedLayer.task("matmul1", Qwen3Kernels::matrixVectorGenericWithResidual,
331262
context,
@@ -336,13 +267,6 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
336267
config.dim(), // dim0 = 1024
337268
LOCAL_WORK_GROUP_SIZE_ALLOC);
338269

339-
// unifiedLayer.task("dbg_copy_out_x",
340-
// Qwen3Kernels::dbgCopy,
341-
// state.wrapX,
342-
// state.dbgX,
343-
// state.positionHolder,
344-
// layerIndex);
345-
346270
//unifiedLayer.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX);
347271
unifiedLayer.task("reductionsOneBlockFFN", TransformerComputeKernelsLayered::reductionOneBlockWithLayer,
348272
context, state.tempFFN, state.wrapX, config.dim(), config.rmsNormEps(), state.localSize)
@@ -351,22 +275,11 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
351275
.task("mapContextFFN", TransformerComputeKernelsLayered::reductionOneBlock2WithLayer, context, state.wrapXb,
352276
state.wrapX, weights.rms_ffn_weightLayered[layerIndex], state.tempFFN);
353277

354-
// unifiedLayer.task("dbg_copy_out_xb",
355-
// Qwen3Kernels::dbgCopy,
356-
// state.wrapXb,
357-
// state.dbgXb,
358-
// state.positionHolder,
359-
// layerIndex);
360-
361278
unifiedLayer.task("fused_ffn_w1_w3", TransformerComputeKernelsLayered::fusedFeedForwardWithSiLUAndGLUActivation, context,
362279
state.wrapXb, state.wrapHb, weights.w1Layered[layerIndex], weights.w3Layered[layerIndex], config.dim(), config.hiddenDim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
363280
.task("projectionTwo", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual, context,
364281
state.wrapHb, state.wrapX, weights.w2Layered[layerIndex], config.hiddenDim(), config.dim(), LOCAL_WORK_GROUP_SIZE_ALLOC)
365282
//.transferToHost(DataTransferMode.EVERY_EXECUTION, state.wrapX)
366-
// dbg copy out
367-
//.transferToHost(DataTransferMode.EVERY_EXECUTION, state.dbgQ, state.dbgKeyCache, state.dbgValueCache)
368-
//.transferToHost(DataTransferMode.EVERY_EXECUTION, state.dbgX)//, state.dbgXb)
369-
//.transferToHost(DataTransferMode.EVERY_EXECUTION, state.dbgValueCache)
370283
.persistOnDevice(
371284
state.wrapX
372285
);

0 commit comments

Comments
 (0)