@@ -47,6 +47,41 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
47
47
atomic_store_explicit (& context -> ref_count , 1 , memory_order_relaxed );
48
48
context -> max_tokens = context_length ;
49
49
50
+ // Activation buffers
51
+ status = gptoss_metal_buffer_create (& model -> device , model -> max_batch_tokens * model -> embedding_dim * sizeof (float ), NULL , & context -> residual_activation_buffer );
52
+ if (status != gptoss_status_success ) {
53
+ goto cleanup ;
54
+ }
55
+ status = gptoss_metal_buffer_create (& model -> device , model -> max_batch_tokens * model -> embedding_dim * sizeof (float ), NULL , & context -> rmsnorm_activation_buffer );
56
+ if (status != gptoss_status_success ) {
57
+ goto cleanup ;
58
+ }
59
+ status = gptoss_metal_buffer_create (& model -> device , model -> max_batch_tokens * model -> head_dim * (model -> num_heads + 2 * model -> num_kv_heads ) * sizeof (float ), NULL , & context -> qkv_activation_buffer );
60
+ if (status != gptoss_status_success ) {
61
+ goto cleanup ;
62
+ }
63
+ status = gptoss_metal_buffer_create (& model -> device , model -> max_batch_tokens * model -> head_dim * model -> num_heads * sizeof (float ), NULL , & context -> sdpa_activation_buffer );
64
+ if (status != gptoss_status_success ) {
65
+ goto cleanup ;
66
+ }
67
+ status = gptoss_metal_buffer_create (& model -> device , model -> max_batch_tokens * model -> num_experts * sizeof (float ), NULL , & context -> gate_activation_buffer );
68
+ if (status != gptoss_status_success ) {
69
+ goto cleanup ;
70
+ }
71
+ status = gptoss_metal_buffer_create (& model -> device , model -> max_batch_tokens * model -> num_experts * sizeof (struct gptoss_expert_prediction ), NULL , & context -> expert_activation_buffer );
72
+ if (status != gptoss_status_success ) {
73
+ goto cleanup ;
74
+ }
75
+ status = gptoss_metal_buffer_create (& model -> device , model -> max_batch_tokens * model -> num_active_experts * model -> mlp_dim * sizeof (float ), NULL , & context -> swiglu_activation_buffer );
76
+ if (status != gptoss_status_success ) {
77
+ goto cleanup ;
78
+ }
79
+ status = gptoss_metal_buffer_create (& model -> device , model -> max_batch_tokens * model -> num_active_experts * model -> embedding_dim * sizeof (float ), NULL , & context -> moe_activation_buffer );
80
+ if (status != gptoss_status_success ) {
81
+ goto cleanup ;
82
+ }
83
+
84
+ // Input/output buffers
50
85
status = gptoss_metal_buffer_create (& model -> device , context_length * sizeof (uint32_t ), NULL , & context -> token_buffer );
51
86
if (status != gptoss_status_success ) {
52
87
goto cleanup ;
@@ -73,7 +108,11 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
73
108
}
74
109
75
110
context -> kvcache_size = context -> kvcache_buffer .size ;
76
- context -> allocation_size = context -> token_buffer .size + context -> kvcache_buffer .size + context -> score_buffer .size + context -> argmax_buffer .size ;
111
+ context -> allocation_size =
112
+ context -> residual_activation_buffer .size + context -> rmsnorm_activation_buffer .size +
113
+ context -> qkv_activation_buffer .size + context -> sdpa_activation_buffer .size +
114
+ context -> gate_activation_buffer .size + context -> expert_activation_buffer .size + context -> swiglu_activation_buffer .size + context -> moe_activation_buffer .size +
115
+ context -> token_buffer .size + context -> kvcache_buffer .size + context -> score_buffer .size + context -> argmax_buffer .size ;
77
116
78
117
context -> model = model ;
79
118
gptoss_model_retain (model );
@@ -139,7 +178,7 @@ static enum gptoss_status process_batch(
139
178
(context -> num_tokens - context -> num_batch_tokens ) * sizeof (uint32_t ),
140
179
& model -> shared_weight_buffer ,
141
180
/*weight_offset=*/ 0 ,
142
- & model -> residual_activation_buffer ,
181
+ & context -> residual_activation_buffer ,
143
182
/*output_offset=*/ 0 ,
144
183
/*num_tokens=*/ context -> num_batch_tokens ,
145
184
/*num_channels=*/ model -> embedding_dim );
@@ -154,11 +193,11 @@ static enum gptoss_status process_batch(
154
193
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm (
155
194
& command_buffer ,
156
195
& model -> f32_bf16w_rmsnorm_fn ,
157
- & model -> residual_activation_buffer ,
196
+ & context -> residual_activation_buffer ,
158
197
/*input_offset=*/ 0 ,
159
198
& model -> shared_weight_buffer ,
160
199
/*weight_offset=*/ model -> attn_rmsnorm_gain_offset + model -> per_block_shared_weights_size * n ,
161
- & model -> rmsnorm_activation_buffer ,
200
+ & context -> rmsnorm_activation_buffer ,
162
201
/*output_offset=*/ 0 ,
163
202
/*num_tokens=*/ context -> num_batch_tokens ,
164
203
/*num_channels=*/ model -> embedding_dim ,
@@ -171,13 +210,13 @@ static enum gptoss_status process_batch(
171
210
& command_buffer ,
172
211
& model -> f32_bf16w_matmul_fn ,
173
212
/*threadgroup_size=*/ 256 ,
174
- & model -> rmsnorm_activation_buffer ,
213
+ & context -> rmsnorm_activation_buffer ,
175
214
/*input_offset=*/ 0 ,
176
215
& model -> shared_weight_buffer ,
177
216
/*weight_offset=*/ model -> attn_qkv_weight_offset + model -> per_block_shared_weights_size * n ,
178
217
& model -> shared_weight_buffer ,
179
218
/*bias_offset=*/ model -> attn_qkv_bias_offset + model -> per_block_shared_weights_size * n ,
180
- & model -> qkv_activation_buffer ,
219
+ & context -> qkv_activation_buffer ,
181
220
/*output_offset=*/ 0 ,
182
221
/*num_tokens=*/ context -> num_batch_tokens ,
183
222
/*num_cols=*/ model -> embedding_dim ,
@@ -191,7 +230,7 @@ static enum gptoss_status process_batch(
191
230
& command_buffer ,
192
231
& model -> f32_rope_fn ,
193
232
/*threadgroup_size=*/ 32 ,
194
- & model -> qkv_activation_buffer ,
233
+ & context -> qkv_activation_buffer ,
195
234
model -> rope_theta ,
196
235
model -> interpolation_scale ,
197
236
model -> yarn_offset ,
@@ -209,7 +248,7 @@ static enum gptoss_status process_batch(
209
248
for (uint32_t t = 0 ; t < context -> num_batch_tokens ; t ++ ) {
210
249
status = gptoss_metal_command_buffer_encode_copy_buffer (
211
250
& command_buffer ,
212
- & model -> qkv_activation_buffer ,
251
+ & context -> qkv_activation_buffer ,
213
252
/*input_offset=*/ (t * attn_qkv_dim + model -> num_heads * model -> head_dim ) * sizeof (float ),
214
253
& context -> kvcache_buffer ,
215
254
/*output_offset=*/ (n * context -> max_tokens + context -> num_kv_tokens + t ) * 2 * model -> num_kv_heads * model -> head_dim * sizeof (float ),
@@ -223,15 +262,15 @@ static enum gptoss_status process_batch(
223
262
status = gptoss_metal_command_buffer_encode_launch_f32_sdpa (
224
263
& command_buffer ,
225
264
& model -> f32_sdpa_q8_d64_fn ,
226
- & model -> qkv_activation_buffer ,
265
+ & context -> qkv_activation_buffer ,
227
266
/*q_offset=*/ attn_qkv_dim * (context -> num_batch_tokens - num_output_tokens ) * sizeof (float ),
228
267
& context -> kvcache_buffer ,
229
268
/*k_offset=*/ n * context -> max_tokens * 2 * model -> num_kv_heads * model -> head_dim * sizeof (float ),
230
269
& context -> kvcache_buffer ,
231
270
/*v_offset=*/ (n * context -> max_tokens * 2 + 1 ) * model -> num_kv_heads * model -> head_dim * sizeof (float ),
232
271
& model -> shared_weight_buffer ,
233
272
/*s_offset=*/ model -> attn_sdpa_sink_offset + model -> per_block_shared_weights_size * n ,
234
- & model -> sdpa_activation_buffer , /*output_offset=*/ 0 ,
273
+ & context -> sdpa_activation_buffer , /*output_offset=*/ 0 ,
235
274
/*window=*/ n % 2 == 0 ? model -> attention_window : UINT32_MAX ,
236
275
num_output_tokens , context -> num_kv_tokens + (context -> num_batch_tokens - num_output_tokens ),
237
276
model -> num_heads , model -> num_kv_heads , model -> head_dim );
@@ -243,13 +282,13 @@ static enum gptoss_status process_batch(
243
282
& command_buffer ,
244
283
& model -> f32_bf16w_matmul_fn ,
245
284
/*threadgroup_size=*/ 256 ,
246
- & model -> sdpa_activation_buffer ,
285
+ & context -> sdpa_activation_buffer ,
247
286
/*input_offset=*/ 0 ,
248
287
& model -> shared_weight_buffer ,
249
288
/*weight_offset=*/ model -> attn_out_weight_offset + model -> per_block_shared_weights_size * n ,
250
289
& model -> shared_weight_buffer ,
251
290
/*bias_offset=*/ model -> attn_out_bias_offset + model -> per_block_shared_weights_size * n ,
252
- & model -> residual_activation_buffer ,
291
+ & context -> residual_activation_buffer ,
253
292
/*output_offset=*/ model -> embedding_dim * (context -> num_batch_tokens - num_output_tokens ) * sizeof (float ),
254
293
/*num_tokens=*/ num_output_tokens ,
255
294
/*num_cols=*/ model -> num_heads * model -> head_dim ,
@@ -262,11 +301,11 @@ static enum gptoss_status process_batch(
262
301
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm (
263
302
& command_buffer ,
264
303
& model -> f32_bf16w_rmsnorm_fn ,
265
- & model -> residual_activation_buffer ,
304
+ & context -> residual_activation_buffer ,
266
305
/*input_offset=*/ model -> embedding_dim * (context -> num_batch_tokens - num_output_tokens ) * sizeof (float ),
267
306
& model -> shared_weight_buffer ,
268
307
/*weight_offset=*/ model -> mlp_rmsnorm_gain_offset + model -> per_block_shared_weights_size * n ,
269
- & model -> rmsnorm_activation_buffer ,
308
+ & context -> rmsnorm_activation_buffer ,
270
309
/*output_offset=*/ 0 ,
271
310
num_output_tokens ,
272
311
model -> embedding_dim ,
@@ -280,13 +319,13 @@ static enum gptoss_status process_batch(
280
319
& command_buffer ,
281
320
& model -> f32_bf16w_matmul_fn ,
282
321
/*threadgroup_size=*/ 256 ,
283
- & model -> rmsnorm_activation_buffer ,
322
+ & context -> rmsnorm_activation_buffer ,
284
323
/*input_offset=*/ 0 ,
285
324
& model -> shared_weight_buffer ,
286
325
/*weight_offset=*/ model -> mlp_gate_weight_offset + model -> per_block_shared_weights_size * n ,
287
326
& model -> shared_weight_buffer ,
288
327
/*bias_offset=*/ model -> mlp_gate_bias_offset + model -> per_block_shared_weights_size * n ,
289
- & model -> gate_activation_buffer ,
328
+ & context -> gate_activation_buffer ,
290
329
/*output_offset=*/ 0 ,
291
330
/*num_tokens=*/ num_output_tokens ,
292
331
/*num_cols=*/ model -> embedding_dim ,
@@ -303,8 +342,8 @@ static enum gptoss_status process_batch(
303
342
status = gptoss_metal_command_buffer_encode_launch_f32_topk (
304
343
& command_buffer ,
305
344
& model -> f32_topk_softmax_e32_k4_fn ,
306
- & model -> gate_activation_buffer , /*input_offset=*/ 0 ,
307
- & model -> expert_activation_buffer , /*output_offset=*/ 0 ,
345
+ & context -> gate_activation_buffer , /*input_offset=*/ 0 ,
346
+ & context -> expert_activation_buffer , /*output_offset=*/ 0 ,
308
347
num_output_tokens ,
309
348
model -> num_experts ,
310
349
model -> num_active_experts );
@@ -314,8 +353,8 @@ static enum gptoss_status process_batch(
314
353
status = gptoss_metal_command_buffer_encode_launch_f32_topk (
315
354
& command_buffer ,
316
355
& model -> f32_topk_softmax_e128_k4_fn ,
317
- & model -> gate_activation_buffer , /*input_offset=*/ 0 ,
318
- & model -> expert_activation_buffer , /*output_offset=*/ 0 ,
356
+ & context -> gate_activation_buffer , /*input_offset=*/ 0 ,
357
+ & context -> expert_activation_buffer , /*output_offset=*/ 0 ,
319
358
num_output_tokens ,
320
359
model -> num_experts ,
321
360
model -> num_active_experts );
@@ -334,12 +373,12 @@ static enum gptoss_status process_batch(
334
373
& command_buffer ,
335
374
& model -> f32_mf4w_moe_matmul_swiglu_fn ,
336
375
/*threadgroup_size=*/ 512 ,
337
- & model -> rmsnorm_activation_buffer , /*input_offset=*/ 0 ,
338
- & model -> expert_activation_buffer , /*expert_offset=*/ 0 ,
376
+ & context -> rmsnorm_activation_buffer , /*input_offset=*/ 0 ,
377
+ & context -> expert_activation_buffer , /*expert_offset=*/ 0 ,
339
378
& model -> block_weight_buffers [n ], /*weight_block_offset=*/ 0 ,
340
379
& model -> block_weight_buffers [n ], /*weight_scale_offset=*/ model -> mlp_swiglu_scale_offset ,
341
380
& model -> block_weight_buffers [n ], /*bias_offset=*/ model -> mlp_swiglu_bias_offset ,
342
- & model -> swiglu_activation_buffer , /*output_offset=*/ 0 ,
381
+ & context -> swiglu_activation_buffer , /*output_offset=*/ 0 ,
343
382
model -> swiglu_limit ,
344
383
model -> per_expert_block_weight_size ,
345
384
num_output_tokens ,
@@ -355,12 +394,12 @@ static enum gptoss_status process_batch(
355
394
& command_buffer ,
356
395
& model -> f32_mf4w_moe_matmul_fn ,
357
396
/*threadgroup_size=*/ 512 ,
358
- & model -> swiglu_activation_buffer , /*input_offset=*/ 0 ,
359
- & model -> expert_activation_buffer , /*expert_offset=*/ 0 ,
397
+ & context -> swiglu_activation_buffer , /*input_offset=*/ 0 ,
398
+ & context -> expert_activation_buffer , /*expert_offset=*/ 0 ,
360
399
& model -> block_weight_buffers [n ], /*weight_block_offset=*/ model -> mlp_out_block_offset ,
361
400
& model -> block_weight_buffers [n ], /*weight_scale_offset=*/ model -> mlp_out_scale_offset ,
362
401
& model -> block_weight_buffers [n ], /*bias_offset=*/ model -> mlp_out_bias_offset ,
363
- & model -> moe_activation_buffer , /*output_offset=*/ 0 ,
402
+ & context -> moe_activation_buffer , /*output_offset=*/ 0 ,
364
403
model -> per_expert_block_weight_size ,
365
404
num_output_tokens ,
366
405
model -> num_active_experts ,
@@ -376,11 +415,11 @@ static enum gptoss_status process_batch(
376
415
& model -> f32_accumulate_e4_fn ,
377
416
/*threadgroup_size=*/ 256 ,
378
417
model -> max_threadgroups ,
379
- & model -> moe_activation_buffer ,
418
+ & context -> moe_activation_buffer ,
380
419
/*input_offset=*/ 0 ,
381
- & model -> expert_activation_buffer ,
420
+ & context -> expert_activation_buffer ,
382
421
/*expert_offset=*/ 0 ,
383
- & model -> residual_activation_buffer ,
422
+ & context -> residual_activation_buffer ,
384
423
/*output_offset=*/ model -> embedding_dim * (context -> num_batch_tokens - num_output_tokens ) * sizeof (float ),
385
424
model -> embedding_dim ,
386
425
num_output_tokens ,
@@ -395,11 +434,11 @@ static enum gptoss_status process_batch(
395
434
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm (
396
435
& command_buffer ,
397
436
& model -> f32_bf16w_rmsnorm_fn ,
398
- & model -> residual_activation_buffer ,
437
+ & context -> residual_activation_buffer ,
399
438
/*input_offset=*/ model -> embedding_dim * (context -> num_batch_tokens - num_output_tokens ) * sizeof (float ),
400
439
& model -> shared_weight_buffer ,
401
440
/*weight_offset=*/ model -> rmsnorm_weight_offset ,
402
- & model -> rmsnorm_activation_buffer ,
441
+ & context -> rmsnorm_activation_buffer ,
403
442
/*output_offset=*/ 0 ,
404
443
/*num_tokens=*/ num_output_tokens ,
405
444
/*num_channels=*/ model -> embedding_dim ,
@@ -424,7 +463,7 @@ static enum gptoss_status process_batch(
424
463
& model -> f32_bf16w_unembedding_fn ,
425
464
/*threadgroup_size=*/ 256 ,
426
465
model -> max_threadgroups ,
427
- & model -> rmsnorm_activation_buffer ,
466
+ & context -> rmsnorm_activation_buffer ,
428
467
/*input_offset=*/ 0 ,
429
468
& model -> shared_weight_buffer ,
430
469
/*weight_offset=*/ model -> unembedding_weight_offset ,
@@ -700,6 +739,17 @@ enum gptoss_status GPTOSS_ABI gptoss_context_release(
700
739
{
701
740
if (context != NULL ) {
702
741
if (atomic_fetch_sub_explicit (& context -> ref_count , 1 , memory_order_acq_rel ) == 1 ) {
742
+ // Activation buffers
743
+ gptoss_metal_buffer_release (& context -> residual_activation_buffer );
744
+ gptoss_metal_buffer_release (& context -> rmsnorm_activation_buffer );
745
+ gptoss_metal_buffer_release (& context -> qkv_activation_buffer );
746
+ gptoss_metal_buffer_release (& context -> sdpa_activation_buffer );
747
+ gptoss_metal_buffer_release (& context -> gate_activation_buffer );
748
+ gptoss_metal_buffer_release (& context -> expert_activation_buffer );
749
+ gptoss_metal_buffer_release (& context -> swiglu_activation_buffer );
750
+ gptoss_metal_buffer_release (& context -> moe_activation_buffer );
751
+
752
+ // Input/output buffers
703
753
gptoss_metal_buffer_release (& context -> token_buffer );
704
754
gptoss_metal_buffer_release (& context -> score_buffer );
705
755
gptoss_metal_buffer_release (& context -> prob_buffer );
0 commit comments