@@ -338,6 +338,11 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
338338 params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask ;
339339 params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad ;
340340
341+ if (params.chroma_use_dit_mask && params.diffusion_flash_attn ) {
342+ // note we don't know yet if it's a Chroma model
343+ params.chroma_use_dit_mask = false ;
344+ }
345+
341346 sd_ctx = new_sd_ctx (¶ms);
342347
343348 if (sd_ctx == NULL ) {
@@ -346,6 +351,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
346351 return false ;
347352 }
348353
354+ if (!sd_is_quiet) {
355+ if (loaded_model_is_chroma (sd_ctx) && sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask )
356+ {
357+ printf (" Chroma: flash attention is on, disabling DiT mask (this will lower image quality)\n " );
358+ // disabled before loading
359+ }
360+ }
361+
349362 std::filesystem::path mpath (inputs.model_filename );
350363 sdmodelfilename = mpath.filename ().string ();
351364
@@ -528,22 +541,12 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
528541 auto loadedsdver = get_loaded_sd_version (sd_ctx);
529542 if (loadedsdver == SDVersion::VERSION_FLUX)
530543 {
531- if (loaded_model_is_chroma (sd_ctx)) {
532- if (sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask ) {
533- if (!sd_is_quiet && sddebugmode) {
534- printf (" Chroma: flash attention is on, disabling DiT mask\n " );
535- }
536- sd_params->chroma_use_dit_mask = false ;
537- }
538- }
539- else {
540- if (sd_params->cfg_scale != 1 .0f ) {
541- // non chroma clamp cfg scale
542- if (!sd_is_quiet && sddebugmode) {
543- printf (" Flux: clamping CFG Scale to 1\n " );
544- }
545- sd_params->cfg_scale = 1 .0f ;
544+ if (!loaded_model_is_chroma (sd_ctx) && sd_params->cfg_scale != 1 .0f ) {
545+ // non chroma clamp cfg scale
546+ if (!sd_is_quiet && sddebugmode) {
547+ printf (" Flux: clamping CFG Scale to 1\n " );
546548 }
549+ sd_params->cfg_scale = 1 .0f ;
547550 }
548551 if (sampler == " euler a" || sampler == " k_euler_a" || sampler == " euler_a" ) {
549552 // euler a broken on flux
0 commit comments