@@ -204,7 +204,7 @@ namespace Flux {
             // return: [ModulationOut, ModulationOut]
             auto lin = std::dynamic_pointer_cast<Linear>(blocks["lin"]);

-            auto out = ggml_silu(ctx, vec);
+            auto out = ggml_silu_inplace(ctx, vec);
             out = lin->forward(ctx, out);  // [N, multiplier*dim]

             auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]);  // [N, multiplier, dim]
@@ -235,8 +235,8 @@ namespace Flux {
         // shift: [N, C]
         scale = ggml_reshape_3d(ctx, scale, scale->ne[0], 1, scale->ne[1]);  // [N, 1, C]
         shift = ggml_reshape_3d(ctx, shift, shift->ne[0], 1, shift->ne[1]);  // [N, 1, C]
-        x = ggml_add(ctx, x, ggml_mul(ctx, x, scale));
-        x = ggml_add(ctx, x, shift);
+        x = ggml_add_inplace(ctx, x, ggml_mul(ctx, x, scale));
+        x = ggml_add_inplace(ctx, x, shift);
         return x;
     }

@@ -346,22 +346,22 @@ namespace Flux {
             img_attn_out = ggml_cont(ctx, ggml_permute(ctx, img_attn_out, 0, 2, 1, 3));  // [N, n_img_token, hidden_size]

             // calculate the img bloks
-            img = ggml_add(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));
+            img = ggml_add_inplace(ctx, img, ggml_mul(ctx, img_attn->post_attention(ctx, img_attn_out), img_mod1.gate));

             auto img_mlp_out = img_mlp_0->forward(ctx, Flux::modulate(ctx, img_norm2->forward(ctx, img), img_mod2.shift, img_mod2.scale));
             img_mlp_out = ggml_gelu_inplace(ctx, img_mlp_out);
             img_mlp_out = img_mlp_2->forward(ctx, img_mlp_out);

-            img = ggml_add(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate));
+            img = ggml_add_inplace(ctx, img, ggml_mul(ctx, img_mlp_out, img_mod2.gate));

             // calculate the txt bloks
-            txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));
+            txt = ggml_add_inplace(ctx, txt, ggml_mul(ctx, txt_attn->post_attention(ctx, txt_attn_out), txt_mod1.gate));

             auto txt_mlp_out = txt_mlp_0->forward(ctx, Flux::modulate(ctx, txt_norm2->forward(ctx, txt), txt_mod2.shift, txt_mod2.scale));
             txt_mlp_out = ggml_gelu_inplace(ctx, txt_mlp_out);
             txt_mlp_out = txt_mlp_2->forward(ctx, txt_mlp_out);

-            txt = ggml_add(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate));
+            txt = ggml_add_inplace(ctx, txt, ggml_mul(ctx, txt_mlp_out, txt_mod2.gate));

             return {img, txt};
         }
@@ -448,7 +448,7 @@ namespace Flux {
             auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0);  // [N, n_token, hidden_size + mlp_hidden_dim]
             auto output = linear2->forward(ctx, attn_mlp);  // [N, n_token, hidden_size]

-            output = ggml_add(ctx, x, ggml_mul(ctx, output, mod.gate));
+            output = ggml_add_inplace(ctx, x, ggml_mul(ctx, output, mod.gate));
             return output;
         }
     };
@@ -473,7 +473,7 @@ namespace Flux {
             auto linear = std::dynamic_pointer_cast<Linear>(blocks["linear"]);
             auto adaLN_modulation_1 = std::dynamic_pointer_cast<Linear>(blocks["adaLN_modulation.1"]);

-            auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c));  // [N, 2 * hidden_size]
+            auto m = adaLN_modulation_1->forward(ctx, ggml_silu_inplace(ctx, c));  // [N, 2 * hidden_size]
             m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]);  // [N, 2, hidden_size]
             m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3));  // [2, N, hidden_size]

@@ -741,10 +741,10 @@ namespace Flux {
                 auto guidance_in = std::dynamic_pointer_cast<MLPEmbedder>(blocks["guidance_in"]);
                 // bf16 and fp16 result is different
                 auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f);
-                vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in));
+                vec = ggml_add_inplace(ctx, vec, guidance_in->forward(ctx, g_in));
             }

-            vec = ggml_add(ctx, vec, vector_in->forward(ctx, y));
+            vec = ggml_add_inplace(ctx, vec, vector_in->forward(ctx, y));
             txt = txt_in->forward(ctx, txt);

             for (int i = 0; i < params.depth; i++) {
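
Every hunk makes the same substitution: an out-of-place ggml op (ggml_add, ggml_silu) becomes its _inplace counterpart, which returns a view of its first operand instead of allocating a fresh result tensor, at the cost of overwriting that operand when the graph runs. Below is a minimal standalone sketch of the difference; it is illustrative only, not taken from this commit, and assumes a recent ggml with the legacy context allocator (no_alloc = false), whereas the code above may run under ggml-alloc/backend buffers where the allocation mechanics differ but the in-place contract is the same.

// Illustrative sketch only (not part of this commit).
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params ip = {
        /* mem_size   */ 16 * 1024 * 1024,
        /* mem_buffer */ NULL,
        /* no_alloc   */ false,
    };
    struct ggml_context* ctx = ggml_init(ip);

    struct ggml_tensor* x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);
    struct ggml_tensor* d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 8);

    // Out-of-place: the result is a brand-new tensor; x is left intact.
    struct ggml_tensor* y0 = ggml_add(ctx, x, d);

    // In-place: the result is a view of x, so no extra tensor data is
    // allocated, but x is overwritten when the graph is computed. Only
    // safe when no other node still needs the old value of x.
    struct ggml_tensor* y1 = ggml_add_inplace(ctx, x, d);

    printf("out-of-place result aliases x: %d\n", (int)(y0->data == x->data));  // 0
    printf("in-place result aliases x:     %d\n", (int)(y1->data == x->data));  // 1

    ggml_free(ctx);
    return 0;
}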