Skip to content

Commit d7f4622

Browse files
committed
fix Chroma workaround for flash attention
chroma_use_dit_mask is a context parameter, so changing it after creating the context has no effect.
1 parent 126104f commit d7f4622

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

otherarch/sdcpp/sdtype_adapter.cpp

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,11 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
338338
params.chroma_use_t5_mask = sd_params->chroma_use_t5_mask;
339339
params.chroma_t5_mask_pad = sd_params->chroma_t5_mask_pad;
340340

341+
if (params.chroma_use_dit_mask && params.diffusion_flash_attn) {
342+
// note we don't know yet if it's a Chroma model
343+
params.chroma_use_dit_mask = false;
344+
}
345+
341346
sd_ctx = new_sd_ctx(&params);
342347

343348
if (sd_ctx == NULL) {
@@ -346,6 +351,14 @@ bool sdtype_load_model(const sd_load_model_inputs inputs) {
346351
return false;
347352
}
348353

354+
if (!sd_is_quiet) {
355+
if (loaded_model_is_chroma(sd_ctx) && sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask)
356+
{
357+
printf("Chroma: flash attention is on, disabling DiT mask (this will lower image quality)\n");
358+
// disabled before loading
359+
}
360+
}
361+
349362
std::filesystem::path mpath(inputs.model_filename);
350363
sdmodelfilename = mpath.filename().string();
351364

@@ -528,22 +541,12 @@ sd_generation_outputs sdtype_generate(const sd_generation_inputs inputs)
528541
auto loadedsdver = get_loaded_sd_version(sd_ctx);
529542
if (loadedsdver == SDVersion::VERSION_FLUX)
530543
{
531-
if (loaded_model_is_chroma(sd_ctx)) {
532-
if (sd_params->diffusion_flash_attn && sd_params->chroma_use_dit_mask) {
533-
if (!sd_is_quiet && sddebugmode) {
534-
printf("Chroma: flash attention is on, disabling DiT mask\n");
535-
}
536-
sd_params->chroma_use_dit_mask = false;
537-
}
538-
}
539-
else {
540-
if (sd_params->cfg_scale != 1.0f) {
541-
//non chroma clamp cfg scale
542-
if (!sd_is_quiet && sddebugmode) {
543-
printf("Flux: clamping CFG Scale to 1\n");
544-
}
545-
sd_params->cfg_scale = 1.0f;
544+
if (!loaded_model_is_chroma(sd_ctx) && sd_params->cfg_scale != 1.0f) {
545+
//non chroma clamp cfg scale
546+
if (!sd_is_quiet && sddebugmode) {
547+
printf("Flux: clamping CFG Scale to 1\n");
546548
}
549+
sd_params->cfg_scale = 1.0f;
547550
}
548551
if (sampler == "euler a" || sampler == "k_euler_a" || sampler == "euler_a") {
549552
//euler a broken on flux

0 commit comments

Comments
 (0)