Skip to content

Commit 9e64a0e

Browse files
committed
Workaround for Chroma with flash attention, debug prints
1 parent c92e14a commit 9e64a0e

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -506,11 +506,29 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
506506
auto loadedsdver = get_loaded_sd_version(sd_ctx);
507507
if (loadedsdver == SDVersion::VERSION_FLUX)
508508
{
509-
if (!loaded_model_is_chroma(sd_ctx)) {
510-
sd_params->cfg_scale = 1; //non chroma clamp cfg scale
509+
if (loaded_model_is_chroma(sd_ctx)) {
510+
if (sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask) {
511+
if (!sd_is_quiet && sddebugmode) {
512+
printf("Chroma: flash attention is on, disabling DiT mask\n");
513+
}
514+
sd_params->chroma_use_dit_mask = false;
515+
}
516+
}
517+
else {
518+
if (sd_params->cfg_scale != 1.0f) {
519+
//non chroma clamp cfg scale
520+
if (!sd_is_quiet && sddebugmode) {
521+
printf("Flux: clamping CFG Scale to 1\n");
522+
}
523+
sd_params->cfg_scale = 1.0f;
524+
}
511525
}
512526
if (sampler == "euler a" || sampler == "k_euler_a" || sampler == "euler_a") {
513-
sampler = "euler"; //euler a broken on flux
527+
//euler a broken on flux
528+
if (!sd_is_quiet && sddebugmode) {
529+
printf("Flux: switching Euler A to Euler\n");
530+
}
531+
sampler = "euler";
514532
}
515533
}
516534

0 commit comments

Comments
 (0)