@@ -744,10 +744,10 @@ namespace Flux {
744744 return ids;
745745 }
746746
747+
747748 // Generate positional embeddings
748749 std::vector<float> gen_pe(int h, int w, int patch_size, int bs, int context_len, std::vector<ggml_tensor*> ref_latents, int theta, const std::vector<int>& axes_dim) {
749750 std::vector<std::vector<float>> ids = gen_ids(h, w, patch_size, bs, context_len, ref_latents);
750-
751751 std::vector<std::vector<float>> trans_ids = transpose(ids);
752752 size_t pos_len = ids.size();
753753 int num_axes = axes_dim.size();
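
For reference, transpose flips the id table from [pos_len][num_axes] into [num_axes][pos_len], so each positional axis can be embedded separately against its entry in axes_dim. A minimal sketch of such a helper, assuming rectangular input as produced by gen_ids (illustrative only, not necessarily the repository's implementation):

    #include <vector>

    // Flip a rectangular [rows][cols] table of position ids into [cols][rows].
    std::vector<std::vector<float>> transpose_ids(const std::vector<std::vector<float>>& m) {
        if (m.empty()) return {};
        std::vector<std::vector<float>> t(m[0].size(), std::vector<float>(m.size()));
        for (size_t i = 0; i < m.size(); i++) {
            for (size_t j = 0; j < m[i].size(); j++) {
                t[j][i] = m[i][j];
            }
        }
        return t;
    }
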
@@ -872,7 +872,7 @@ namespace Flux {
872872 struct ggml_tensor * y,
873873 struct ggml_tensor * guidance,
874874 struct ggml_tensor * pe,
875- struct ggml_tensor * arange = NULL,
875+ struct ggml_tensor * mod_index_arange = NULL,
876876 std::vector<int> skip_layers = {}) {
877877 auto img_in = std::dynamic_pointer_cast<Linear>(blocks["img_in"]);
878878 auto txt_in = std::dynamic_pointer_cast<Linear>(blocks["txt_in"]);
@@ -887,9 +887,10 @@ namespace Flux {
887887 auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f);
888888 auto distill_guidance = ggml_nn_timestep_embedding(ctx, guidance, 16, 10000, 1000.f);
889889
890- // auto arange = ggml_arange(ctx, 0, (float)mod_index_length, 1); // Not working on a lot of backends, precomputing it on CPU instead
890+ // auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1);
891+ // ggml_arange is not working on a lot of backends, precomputing it on CPU instead
891- GGML_ASSERT(arange != NULL);
892+ GGML_ASSERT(mod_index_arange != NULL);
892- auto modulation_index = ggml_nn_timestep_embedding(ctx, arange, 32, 10000, 1000.f); // [1, 344, 32]
893+ auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32]
893894
894895 // Batch broadcast (will it ever be useful)
895896 modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32]
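
For context, ggml_nn_timestep_embedding produces the standard sinusoidal timestep embedding used by diffusion models; applied to the 344 precomputed modulation indices with dim 32, it yields the [1, 344, 32] tensor noted above. A rough CPU reference for a single scalar, assuming the usual cos-then-sin channel layout and that the trailing 1000.f acts as a timestep scale factor (both are assumptions, and timestep_embedding_ref is an illustrative name, not the ggml implementation):

    #include <cmath>
    #include <vector>

    // Sketch of a sinusoidal timestep embedding: the first half of the channels is
    // cos(t * freq_i), the second half sin(t * freq_i), with
    // freq_i = exp(-ln(max_period) * i / half).
    std::vector<float> timestep_embedding_ref(float t, int dim, float max_period = 10000.f, float time_factor = 1000.f) {
        t *= time_factor;
        int half = dim / 2;
        std::vector<float> emb(dim, 0.f);
        for (int i = 0; i < half; i++) {
            float freq = std::exp(-std::log(max_period) * (float)i / (float)half);
            emb[i]        = std::cos(t * freq);
            emb[half + i] = std::sin(t * freq);
        }
        return emb;
    }
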
@@ -982,7 +983,7 @@ namespace Flux {
982983 struct ggml_tensor * y,
983984 struct ggml_tensor * guidance,
984985 struct ggml_tensor * pe,
985- struct ggml_tensor * arange = NULL,
986+ struct ggml_tensor * mod_index_arange = NULL,
986987 std::vector<ggml_tensor*> ref_latents = {},
987988 std::vector<int> skip_layers = {}) {
988989 // Forward pass of DiT.
@@ -1024,7 +1025,7 @@ namespace Flux {
10241025 }
10251026 }
10261027
1027- auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
1028+ auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size]
10281029 if (out->ne[1] > img_tokens) {
10291030 out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size]
10301031 out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0);
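
As a worked example of the slicing above (sizes are illustrative, not from this diff): with patch_size 2, a 128x128 latent contributes (128/2) * (128/2) = 4096 image tokens. If one reference latent of the same size is concatenated as extra tokens, out->ne[1] comes back from forward_orig as 8192, the `out->ne[1] > img_tokens` check fires, and the ggml_view_3d keeps only the first 4096 token entries (the image tokens) before un-patchifying.
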
@@ -1044,15 +1045,18 @@ namespace Flux {
10441045 public:
10451046 FluxParams flux_params;
10461047 Flux flux;
1047- std::vector<float> pe_vec, range; // for cache
1048+ std::vector<float> pe_vec;
1049+ std::vector<float> mod_index_arange_vec; // for cache
10481050 SDVersion version;
1051+ bool use_mask = false;
10491052
10501053 FluxRunner(ggml_backend_t backend,
10511054 std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
10521055 const std::string prefix = "",
10531056 SDVersion version = VERSION_FLUX,
1054- bool flash_attn = false)
1055- : GGMLRunner(backend) {
1057+ bool flash_attn = false,
1058+ bool use_mask = false)
1059+ : GGMLRunner(backend), use_mask(use_mask) {
10561060 flux_params.flash_attn = flash_attn;
10571061 flux_params.guidance_embed = false;
10581062 flux_params.depth = 0;
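
With the extra constructor parameter, enabling the Chroma DiT mask from the caller side would look roughly like this sketch; backend and tensor_types come from the caller's setup and the prefix string is a placeholder, not taken from this diff:

    // Illustrative construction only; arguments before the two flags are placeholders.
    Flux::FluxRunner runner(backend,
                            tensor_types,
                            "model.diffusion_model",  // placeholder prefix
                            VERSION_FLUX,
                            /*flash_attn=*/false,
                            /*use_mask=*/true);       // keep y (the DiT mask) instead of nulling it
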
@@ -1116,51 +1120,28 @@ namespace Flux {
11161120 struct ggml_tensor * y,
11171121 struct ggml_tensor * guidance,
11181122 std::vector<ggml_tensor*> ref_latents = {},
1119- std::vector<int> skip_layers = std::vector<int>()) {
1123+ std::vector<int> skip_layers = {}) {
11201124 GGML_ASSERT(x->ne[3] == 1);
11211125 struct ggml_cgraph * gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false);
11221126
1123- struct ggml_tensor * precompute_arange = NULL;
1127+ struct ggml_tensor * mod_index_arange = NULL;
11241128
11251129 x = to_backend(x);
11261130 context = to_backend(context);
11271131 if (c_concat != NULL) {
11281132 c_concat = to_backend(c_concat);
11291133 }
1130-
11311134 if (flux_params.is_chroma) {
1132- const char * SD_CHROMA_ENABLE_GUIDANCE = getenv("SD_CHROMA_ENABLE_GUIDANCE");
1133- bool disable_guidance = true;
1134- if (SD_CHROMA_ENABLE_GUIDANCE != NULL) {
1135- std::string enable_guidance_str = SD_CHROMA_ENABLE_GUIDANCE;
1136- if (enable_guidance_str == "ON" || enable_guidance_str == "TRUE") {
1137- LOG_WARN("Chroma guidance has been enabled. Image might be broken. (SD_CHROMA_ENABLE_GUIDANCE env variable to \"OFF\" to disable)", SD_CHROMA_ENABLE_GUIDANCE);
1138- disable_guidance = false;
1139- } else if (enable_guidance_str != "OFF" && enable_guidance_str != "FALSE") {
1140- LOG_WARN("SD_CHROMA_ENABLE_GUIDANCE environment variable has unexpected value. Assuming default (\"OFF\"). (Expected \"ON\"/\"TRUE\" or \"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_ENABLE_GUIDANCE);
1141- }
1142- }
1143- if (disable_guidance) {
1144- // LOG_DEBUG("Forcing guidance to 0 for chroma model (SD_CHROMA_ENABLE_GUIDANCE env variable to \"ON\" to enable)");
1145- guidance = ggml_set_f32(guidance, 0);
1146- }
1135+ guidance = ggml_set_f32(guidance, 0);
11471136
1148-
1149- const char * SD_CHROMA_USE_DIT_MASK = getenv("SD_CHROMA_USE_DIT_MASK");
1150- if (SD_CHROMA_USE_DIT_MASK != nullptr) {
1151- std::string sd_chroma_use_DiT_mask_str = SD_CHROMA_USE_DIT_MASK;
1152- if (sd_chroma_use_DiT_mask_str == "OFF" || sd_chroma_use_DiT_mask_str == "FALSE") {
1153- y = NULL;
1154- } else if (sd_chroma_use_DiT_mask_str != "ON" && sd_chroma_use_DiT_mask_str != "TRUE") {
1155- LOG_WARN("SD_CHROMA_USE_DIT_MASK environment variable has unexpected value. Assuming default (\"ON\"). (Expected \"ON\"/\"TRUE\" or \"OFF\"/\"FALSE\", got \"%s\")", SD_CHROMA_USE_DIT_MASK);
1156- }
1137+ if (!use_mask) {
1138+ y = NULL;
11571139 }
11581140
1159- // ggml_arrange is not working on some backends, and y isn't used, so let's reuse y to precompute it
1160- range = arange(0, 344);
1161- precompute_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, range.size());
1162- set_backend_tensor_data(precompute_arange, range.data());
1163- // y = NULL;
1141+ // ggml_arange is not working on some backends, precompute it
1142+ mod_index_arange_vec = arange(0, 344);
1143+ mod_index_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, mod_index_arange_vec.size());
1144+ set_backend_tensor_data(mod_index_arange, mod_index_arange_vec.data());
11641145 }
11651146 y = to_backend(y);
11661147
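
The arange(0, 344) call above fills the cached vector with 0, 1, ..., 343 on the CPU, which is then uploaded through set_backend_tensor_data as a workaround for ggml_arange being unsupported on several backends. Assuming numpy-like semantics for that helper, a minimal sketch (arange_ref is an illustrative name, not necessarily the repository's exact implementation):

    #include <vector>

    // Fill {start, start + step, ...} up to (but not including) end, as floats.
    std::vector<float> arange_ref(float start, float end, float step = 1.f) {
        std::vector<float> out;
        for (float v = start; v < end; v += step) {
            out.push_back(v);
        }
        return out;
    }
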
@@ -1189,7 +1170,7 @@ namespace Flux {
11891170 y,
11901171 guidance,
11911172 pe,
1192- precompute_arange,
1173+ mod_index_arange,
11931174 ref_latents,
11941175 skip_layers);
11951176