@@ -850,12 +850,15 @@ namespace Flux {
850850
851851 // auto arrange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends
852852 auto arrange = y;
853- auto modulation_index = ggml_nn_timestep_embedding (ctx, arrange, 32 , 10000 , 1000 .f );
853+ auto modulation_index = ggml_nn_timestep_embedding (ctx, arrange, 32 , 10000 , 1000 .f );// [1, 344, 32]
854+
855+ // Batch broadcast (will it ever be useful)
856+ modulation_index = ggml_repeat (ctx, modulation_index, ggml_new_tensor_4d (ctx, GGML_TYPE_F32, modulation_index->ne [0 ], modulation_index->ne [1 ], img->ne [2 ], modulation_index->ne [3 ]));// [N, 344, 32]
854857
855- auto timestep_guidance = ggml_concat (ctx, distill_timestep, distill_guidance, 0 );
856- timestep_guidance = ggml_repeat (ctx, distill_timestep, modulation_index);
857- // TODO Batch broadcast?
858858
859+ auto timestep_guidance = ggml_concat (ctx, distill_timestep, distill_guidance, 0 ); // [N, 1, 32]
860+ timestep_guidance = ggml_repeat (ctx, timestep_guidance, modulation_index); // [N, 344, 32]
861+
859862 vec = ggml_concat (ctx, timestep_guidance, modulation_index, 0 ); // [N, 344, 64]
860863 vec = approx->forward (ctx, vec); // [N, 344, hidden_size]
861864
@@ -1064,7 +1067,7 @@ namespace Flux {
10641067 set_backend_tensor_data (y, range.data ());
10651068 }
10661069 timesteps = to_backend (timesteps);
1067- if (flux_params.guidance_embed ) {
1070+ if (flux_params.guidance_embed || flux_params. is_chroma ) {
10681071 guidance = to_backend (guidance);
10691072 }
10701073
0 commit comments