Skip to content

Commit be8001f

Browse files
Maratyszcza authored and Danztee committed
Support concurrent sampling from multiple Contexts (openai#83)
Move activation buffers from Model to Context, so they are no longer shared across contexts and multiple contexts can sample in parallel
1 parent d6c3ff7 commit be8001f

File tree

3 files changed

+94
-92
lines changed

3 files changed

+94
-92
lines changed

gpt_oss/metal/source/context.c

Lines changed: 82 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,41 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
4747
atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed);
4848
context->max_tokens = context_length;
4949

50+
// Activation buffers
51+
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->residual_activation_buffer);
52+
if (status != gptoss_status_success) {
53+
goto cleanup;
54+
}
55+
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->rmsnorm_activation_buffer);
56+
if (status != gptoss_status_success) {
57+
goto cleanup;
58+
}
59+
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &context->qkv_activation_buffer);
60+
if (status != gptoss_status_success) {
61+
goto cleanup;
62+
}
63+
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &context->sdpa_activation_buffer);
64+
if (status != gptoss_status_success) {
65+
goto cleanup;
66+
}
67+
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(float), NULL, &context->gate_activation_buffer);
68+
if (status != gptoss_status_success) {
69+
goto cleanup;
70+
}
71+
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &context->expert_activation_buffer);
72+
if (status != gptoss_status_success) {
73+
goto cleanup;
74+
}
75+
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);
76+
if (status != gptoss_status_success) {
77+
goto cleanup;
78+
}
79+
status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->moe_activation_buffer);
80+
if (status != gptoss_status_success) {
81+
goto cleanup;
82+
}
83+
84+
// Input/output buffers
5085
status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);
5186
if (status != gptoss_status_success) {
5287
goto cleanup;
@@ -73,7 +108,11 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
73108
}
74109

75110
context->kvcache_size = context->kvcache_buffer.size;
76-
context->allocation_size = context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;
111+
context->allocation_size =
112+
context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +
113+
context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +
114+
context->gate_activation_buffer.size + context->expert_activation_buffer.size + context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
115+
context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;
77116

78117
context->model = model;
79118
gptoss_model_retain(model);
@@ -139,7 +178,7 @@ static enum gptoss_status process_batch(
139178
(context->num_tokens - context->num_batch_tokens) * sizeof(uint32_t),
140179
&model->shared_weight_buffer,
141180
/*weight_offset=*/0,
142-
&model->residual_activation_buffer,
181+
&context->residual_activation_buffer,
143182
/*output_offset=*/0,
144183
/*num_tokens=*/context->num_batch_tokens,
145184
/*num_channels=*/model->embedding_dim);
@@ -154,11 +193,11 @@ static enum gptoss_status process_batch(
154193
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
155194
&command_buffer,
156195
&model->f32_bf16w_rmsnorm_fn,
157-
&model->residual_activation_buffer,
196+
&context->residual_activation_buffer,
158197
/*input_offset=*/0,
159198
&model->shared_weight_buffer,
160199
/*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
161-
&model->rmsnorm_activation_buffer,
200+
&context->rmsnorm_activation_buffer,
162201
/*output_offset=*/0,
163202
/*num_tokens=*/context->num_batch_tokens,
164203
/*num_channels=*/model->embedding_dim,
@@ -171,13 +210,13 @@ static enum gptoss_status process_batch(
171210
&command_buffer,
172211
&model->f32_bf16w_matmul_fn,
173212
/*threadgroup_size=*/256,
174-
&model->rmsnorm_activation_buffer,
213+
&context->rmsnorm_activation_buffer,
175214
/*input_offset=*/0,
176215
&model->shared_weight_buffer,
177216
/*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
178217
&model->shared_weight_buffer,
179218
/*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
180-
&model->qkv_activation_buffer,
219+
&context->qkv_activation_buffer,
181220
/*output_offset=*/0,
182221
/*num_tokens=*/context->num_batch_tokens,
183222
/*num_cols=*/model->embedding_dim,
@@ -191,7 +230,7 @@ static enum gptoss_status process_batch(
191230
&command_buffer,
192231
&model->f32_rope_fn,
193232
/*threadgroup_size=*/32,
194-
&model->qkv_activation_buffer,
233+
&context->qkv_activation_buffer,
195234
model->rope_theta,
196235
model->interpolation_scale,
197236
model->yarn_offset,
@@ -209,7 +248,7 @@ static enum gptoss_status process_batch(
209248
for (uint32_t t = 0; t < context->num_batch_tokens; t++) {
210249
status = gptoss_metal_command_buffer_encode_copy_buffer(
211250
&command_buffer,
212-
&model->qkv_activation_buffer,
251+
&context->qkv_activation_buffer,
213252
/*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float),
214253
&context->kvcache_buffer,
215254
/*output_offset=*/(n * context->max_tokens + context->num_kv_tokens + t) * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
@@ -223,15 +262,15 @@ static enum gptoss_status process_batch(
223262
status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
224263
&command_buffer,
225264
&model->f32_sdpa_q8_d64_fn,
226-
&model->qkv_activation_buffer,
265+
&context->qkv_activation_buffer,
227266
/*q_offset=*/attn_qkv_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
228267
&context->kvcache_buffer,
229268
/*k_offset=*/n * context->max_tokens * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
230269
&context->kvcache_buffer,
231270
/*v_offset=*/(n * context->max_tokens * 2 + 1) * model->num_kv_heads * model->head_dim * sizeof(float),
232271
&model->shared_weight_buffer,
233272
/*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
234-
&model->sdpa_activation_buffer, /*output_offset=*/0,
273+
&context->sdpa_activation_buffer, /*output_offset=*/0,
235274
/*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
236275
num_output_tokens, context->num_kv_tokens + (context->num_batch_tokens - num_output_tokens),
237276
model->num_heads, model->num_kv_heads, model->head_dim);
@@ -243,13 +282,13 @@ static enum gptoss_status process_batch(
243282
&command_buffer,
244283
&model->f32_bf16w_matmul_fn,
245284
/*threadgroup_size=*/256,
246-
&model->sdpa_activation_buffer,
285+
&context->sdpa_activation_buffer,
247286
/*input_offset=*/0,
248287
&model->shared_weight_buffer,
249288
/*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
250289
&model->shared_weight_buffer,
251290
/*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
252-
&model->residual_activation_buffer,
291+
&context->residual_activation_buffer,
253292
/*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
254293
/*num_tokens=*/num_output_tokens,
255294
/*num_cols=*/model->num_heads * model->head_dim,
@@ -262,11 +301,11 @@ static enum gptoss_status process_batch(
262301
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
263302
&command_buffer,
264303
&model->f32_bf16w_rmsnorm_fn,
265-
&model->residual_activation_buffer,
304+
&context->residual_activation_buffer,
266305
/*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
267306
&model->shared_weight_buffer,
268307
/*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
269-
&model->rmsnorm_activation_buffer,
308+
&context->rmsnorm_activation_buffer,
270309
/*output_offset=*/0,
271310
num_output_tokens,
272311
model->embedding_dim,
@@ -280,13 +319,13 @@ static enum gptoss_status process_batch(
280319
&command_buffer,
281320
&model->f32_bf16w_matmul_fn,
282321
/*threadgroup_size=*/256,
283-
&model->rmsnorm_activation_buffer,
322+
&context->rmsnorm_activation_buffer,
284323
/*input_offset=*/0,
285324
&model->shared_weight_buffer,
286325
/*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
287326
&model->shared_weight_buffer,
288327
/*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
289-
&model->gate_activation_buffer,
328+
&context->gate_activation_buffer,
290329
/*output_offset=*/0,
291330
/*num_tokens=*/num_output_tokens,
292331
/*num_cols=*/model->embedding_dim,
@@ -303,8 +342,8 @@ static enum gptoss_status process_batch(
303342
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
304343
&command_buffer,
305344
&model->f32_topk_softmax_e32_k4_fn,
306-
&model->gate_activation_buffer, /*input_offset=*/0,
307-
&model->expert_activation_buffer, /*output_offset=*/0,
345+
&context->gate_activation_buffer, /*input_offset=*/0,
346+
&context->expert_activation_buffer, /*output_offset=*/0,
308347
num_output_tokens,
309348
model->num_experts,
310349
model->num_active_experts);
@@ -314,8 +353,8 @@ static enum gptoss_status process_batch(
314353
status = gptoss_metal_command_buffer_encode_launch_f32_topk(
315354
&command_buffer,
316355
&model->f32_topk_softmax_e128_k4_fn,
317-
&model->gate_activation_buffer, /*input_offset=*/0,
318-
&model->expert_activation_buffer, /*output_offset=*/0,
356+
&context->gate_activation_buffer, /*input_offset=*/0,
357+
&context->expert_activation_buffer, /*output_offset=*/0,
319358
num_output_tokens,
320359
model->num_experts,
321360
model->num_active_experts);
@@ -334,12 +373,12 @@ static enum gptoss_status process_batch(
334373
&command_buffer,
335374
&model->f32_mf4w_moe_matmul_swiglu_fn,
336375
/*threadgroup_size=*/512,
337-
&model->rmsnorm_activation_buffer, /*input_offset=*/0,
338-
&model->expert_activation_buffer, /*expert_offset=*/0,
376+
&context->rmsnorm_activation_buffer, /*input_offset=*/0,
377+
&context->expert_activation_buffer, /*expert_offset=*/0,
339378
&model->block_weight_buffers[n], /*weight_block_offset=*/0,
340379
&model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
341380
&model->block_weight_buffers[n], /*bias_offset=*/model->mlp_swiglu_bias_offset,
342-
&model->swiglu_activation_buffer, /*output_offset=*/0,
381+
&context->swiglu_activation_buffer, /*output_offset=*/0,
343382
model->swiglu_limit,
344383
model->per_expert_block_weight_size,
345384
num_output_tokens,
@@ -355,12 +394,12 @@ static enum gptoss_status process_batch(
355394
&command_buffer,
356395
&model->f32_mf4w_moe_matmul_fn,
357396
/*threadgroup_size=*/512,
358-
&model->swiglu_activation_buffer, /*input_offset=*/0,
359-
&model->expert_activation_buffer, /*expert_offset=*/0,
397+
&context->swiglu_activation_buffer, /*input_offset=*/0,
398+
&context->expert_activation_buffer, /*expert_offset=*/0,
360399
&model->block_weight_buffers[n], /*weight_block_offset=*/model->mlp_out_block_offset,
361400
&model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_out_scale_offset,
362401
&model->block_weight_buffers[n], /*bias_offset=*/model->mlp_out_bias_offset,
363-
&model->moe_activation_buffer, /*output_offset=*/0,
402+
&context->moe_activation_buffer, /*output_offset=*/0,
364403
model->per_expert_block_weight_size,
365404
num_output_tokens,
366405
model->num_active_experts,
@@ -376,11 +415,11 @@ static enum gptoss_status process_batch(
376415
&model->f32_accumulate_e4_fn,
377416
/*threadgroup_size=*/256,
378417
model->max_threadgroups,
379-
&model->moe_activation_buffer,
418+
&context->moe_activation_buffer,
380419
/*input_offset=*/0,
381-
&model->expert_activation_buffer,
420+
&context->expert_activation_buffer,
382421
/*expert_offset=*/0,
383-
&model->residual_activation_buffer,
422+
&context->residual_activation_buffer,
384423
/*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
385424
model->embedding_dim,
386425
num_output_tokens,
@@ -395,11 +434,11 @@ static enum gptoss_status process_batch(
395434
status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
396435
&command_buffer,
397436
&model->f32_bf16w_rmsnorm_fn,
398-
&model->residual_activation_buffer,
437+
&context->residual_activation_buffer,
399438
/*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
400439
&model->shared_weight_buffer,
401440
/*weight_offset=*/model->rmsnorm_weight_offset,
402-
&model->rmsnorm_activation_buffer,
441+
&context->rmsnorm_activation_buffer,
403442
/*output_offset=*/0,
404443
/*num_tokens=*/num_output_tokens,
405444
/*num_channels=*/model->embedding_dim,
@@ -424,7 +463,7 @@ static enum gptoss_status process_batch(
424463
&model->f32_bf16w_unembedding_fn,
425464
/*threadgroup_size=*/256,
426465
model->max_threadgroups,
427-
&model->rmsnorm_activation_buffer,
466+
&context->rmsnorm_activation_buffer,
428467
/*input_offset=*/0,
429468
&model->shared_weight_buffer,
430469
/*weight_offset=*/model->unembedding_weight_offset,
@@ -700,6 +739,17 @@ enum gptoss_status GPTOSS_ABI gptoss_context_release(
700739
{
701740
if (context != NULL) {
702741
if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) == 1) {
742+
// Activation buffers
743+
gptoss_metal_buffer_release(&context->residual_activation_buffer);
744+
gptoss_metal_buffer_release(&context->rmsnorm_activation_buffer);
745+
gptoss_metal_buffer_release(&context->qkv_activation_buffer);
746+
gptoss_metal_buffer_release(&context->sdpa_activation_buffer);
747+
gptoss_metal_buffer_release(&context->gate_activation_buffer);
748+
gptoss_metal_buffer_release(&context->expert_activation_buffer);
749+
gptoss_metal_buffer_release(&context->swiglu_activation_buffer);
750+
gptoss_metal_buffer_release(&context->moe_activation_buffer);
751+
752+
// Input/output buffers
703753
gptoss_metal_buffer_release(&context->token_buffer);
704754
gptoss_metal_buffer_release(&context->score_buffer);
705755
gptoss_metal_buffer_release(&context->prob_buffer);

gpt_oss/metal/source/include/internal/model.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -75,17 +75,6 @@ struct gptoss_model {
7575
struct gptoss_metal_function f32_sdpa_q8_d64_fn;
7676
struct gptoss_metal_function f32_softmax_fn;
7777

78-
// Activation buffers.
79-
// TODO: merge into a single buffer.
80-
struct gptoss_metal_buffer residual_activation_buffer; // Residual stream
81-
struct gptoss_metal_buffer rmsnorm_activation_buffer; // Both attention & MLP RMSNorm output
82-
struct gptoss_metal_buffer qkv_activation_buffer; // QKV projection output
83-
struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output
84-
struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output
85-
struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions
86-
struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output
87-
struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert)
88-
8978
size_t per_block_shared_weights_size;
9079
size_t per_expert_block_weight_size;
9180

@@ -135,6 +124,18 @@ struct gptoss_context {
135124
size_t kvcache_size;
136125
size_t allocation_size;
137126

127+
// Activation buffers.
128+
// TODO: merge into a single buffer.
129+
struct gptoss_metal_buffer residual_activation_buffer; // Residual stream
130+
struct gptoss_metal_buffer rmsnorm_activation_buffer; // Both attention & MLP RMSNorm output
131+
struct gptoss_metal_buffer qkv_activation_buffer; // QKV projection output
132+
struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output
133+
struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output
134+
struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions
135+
struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output
136+
struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert)
137+
138+
// Input/output buffers.
138139
struct gptoss_metal_buffer token_buffer; // uint32 token IDs
139140
struct gptoss_metal_buffer score_buffer; // unembedding outputs
140141
struct gptoss_metal_buffer prob_buffer;

0 commit comments

Comments (0)