Skip to content

Commit ad39011

Browse files
committed
Fix small mistake (still broken)
1 parent f7ad456 commit ad39011

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

flux.hpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -850,12 +850,15 @@ namespace Flux {
850850

851851
// auto arrange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends
852852
auto arrange = y;
853-
auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f);
853+
auto modulation_index = ggml_nn_timestep_embedding(ctx, arrange, 32, 10000, 1000.f);// [1, 344, 32]
854+
855+
// Batch broadcast (will it ever be useful)
856+
modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_4d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2], modulation_index->ne[3]));// [N, 344, 32]
854857

855-
auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0);
856-
timestep_guidance = ggml_repeat(ctx, distill_timestep, modulation_index);
857-
// TODO Batch broadcast?
858858

859+
auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32]
860+
timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32]
861+
859862
vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64]
860863
vec = approx->forward(ctx, vec); // [N, 344, hidden_size]
861864

@@ -1064,7 +1067,7 @@ namespace Flux {
10641067
set_backend_tensor_data(y, range.data());
10651068
}
10661069
timesteps = to_backend(timesteps);
1067-
if (flux_params.guidance_embed) {
1070+
if (flux_params.guidance_embed || flux_params.is_chroma) {
10681071
guidance = to_backend(guidance);
10691072
}
10701073

0 commit comments

Comments
 (0)