Skip to content

Commit 4080c29

Browse files
committed
Refactor: Flexible Flux arch
1 parent e7eabd3 commit 4080c29

File tree

5 files changed: 47 additions and 40 deletions.

diffusion_model.hpp

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -133,9 +133,8 @@ struct FluxModel : public DiffusionModel {
133133

134134
FluxModel(ggml_backend_t backend,
135135
std::map<std::string, enum ggml_type>& tensor_types,
136-
SDVersion version = VERSION_FLUX_DEV,
137136
bool flash_attn = false)
138-
: flux(backend, tensor_types, "model.diffusion_model", version, flash_attn) {
137+
: flux(backend, tensor_types, "model.diffusion_model", flash_attn) {
139138
}
140139

141140
void alloc_params_buffer() {

flux.hpp

Lines changed: 34 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -834,16 +834,43 @@ namespace Flux {
834834
FluxRunner(ggml_backend_t backend,
835835
std::map<std::string, enum ggml_type>& tensor_types = empty_tensor_types,
836836
const std::string prefix = "",
837-
SDVersion version = VERSION_FLUX_DEV,
838-
bool flash_attn = false)
837+
bool flash_attn = false)
839838
: GGMLRunner(backend) {
840-
flux_params.flash_attn = flash_attn;
841-
if (version == VERSION_FLUX_SCHNELL) {
842-
flux_params.guidance_embed = false;
839+
flux_params.flash_attn = flash_attn;
840+
flux_params.guidance_embed = false;
841+
flux_params.depth = 0;
842+
flux_params.depth_single_blocks = 0;
843+
for (auto pair : tensor_types) {
844+
std::string tensor_name = pair.first;
845+
if (tensor_name.find("model.diffusion_model.") == std::string::npos)
846+
continue;
847+
if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) {
848+
// not schnell
849+
flux_params.guidance_embed = true;
850+
}
851+
size_t db = tensor_name.find("double_blocks.");
852+
if (db != std::string::npos) {
853+
tensor_name = tensor_name.substr(db); // remove prefix
854+
int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str());
855+
if (block_depth + 1 > flux_params.depth) {
856+
flux_params.depth = block_depth + 1;
857+
}
858+
}
859+
size_t sb = tensor_name.find("single_blocks.");
860+
if (sb != std::string::npos) {
861+
tensor_name = tensor_name.substr(sb); // remove prefix
862+
int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str());
863+
if (block_depth + 1 > flux_params.depth_single_blocks) {
864+
flux_params.depth_single_blocks = block_depth + 1;
865+
}
866+
}
843867
}
844-
if (version == VERSION_FLUX_LITE) {
845-
flux_params.depth = 8;
868+
869+
LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks);
870+
if (!flux_params.guidance_embed) {
871+
LOG_INFO("Flux guidance is disabled (Schnell mode)");
846872
}
873+
847874
flux = Flux(flux_params);
848875
flux.init(params_ctx, tensor_types, prefix);
849876
}

model.cpp

Lines changed: 2 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1459,18 +1459,9 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
14591459

14601460
SDVersion ModelLoader::get_sd_version() {
14611461
TensorStorage token_embedding_weight;
1462-
bool is_flux = false;
1463-
bool is_schnell = true;
1464-
bool is_lite = true;
14651462
for (auto& tensor_storage : tensor_storages) {
1466-
if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
1467-
is_schnell = false;
1468-
}
14691463
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
1470-
is_flux = true;
1471-
}
1472-
if (tensor_storage.name.find("model.diffusion_model.double_blocks.8") != std::string::npos) {
1473-
is_lite = false;
1464+
return VERSION_FLUX;
14741465
}
14751466
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
14761467
return VERSION_SD3;
@@ -1495,16 +1486,7 @@ SDVersion ModelLoader::get_sd_version() {
14951486
// break;
14961487
}
14971488
}
1498-
if (is_flux) {
1499-
if (is_schnell) {
1500-
GGML_ASSERT(!is_lite);
1501-
return VERSION_FLUX_SCHNELL;
1502-
} else if (is_lite) {
1503-
return VERSION_FLUX_LITE;
1504-
} else {
1505-
return VERSION_FLUX_DEV;
1506-
}
1507-
}
1489+
15081490
if (token_embedding_weight.ne[0] == 768) {
15091491
return VERSION_SD1;
15101492
} else if (token_embedding_weight.ne[0] == 1024) {

model.h

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -23,14 +23,12 @@ enum SDVersion {
2323
VERSION_SDXL,
2424
VERSION_SVD,
2525
VERSION_SD3,
26-
VERSION_FLUX_DEV,
27-
VERSION_FLUX_SCHNELL,
28-
VERSION_FLUX_LITE,
26+
VERSION_FLUX,
2927
VERSION_COUNT,
3028
};
3129

3230
static inline bool sd_version_is_flux(SDVersion version) {
33-
if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL || version == VERSION_FLUX_LITE) {
31+
if (version == VERSION_FLUX) {
3432
return true;
3533
}
3634
return false;

stable-diffusion.cpp

Lines changed: 8 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -30,9 +30,7 @@ const char* model_version_to_str[] = {
3030
"SDXL",
3131
"SVD",
3232
"SD3.x",
33-
"Flux Dev",
34-
"Flux Schnell",
35-
"Flux Lite 8B"};
33+
"Flux"};
3634

3735
const char* sampling_methods_str[] = {
3836
"Euler A",
@@ -331,7 +329,7 @@ class StableDiffusionGGML {
331329
diffusion_model = std::make_shared<MMDiTModel>(backend, model_loader.tensor_storages_types);
332330
} else if (sd_version_is_flux(version)) {
333331
cond_stage_model = std::make_shared<FluxCLIPEmbedder>(clip_backend, model_loader.tensor_storages_types);
334-
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, version, diffusion_flash_attn);
332+
diffusion_model = std::make_shared<FluxModel>(backend, model_loader.tensor_storages_types, diffusion_flash_attn);
335333
} else {
336334
if (id_embeddings_path.find("v2") != std::string::npos) {
337335
cond_stage_model = std::make_shared<FrozenCLIPEmbedderWithCustomWords>(clip_backend, model_loader.tensor_storages_types, embeddings_path, version, PM_VERSION_2);
@@ -533,9 +531,12 @@ class StableDiffusionGGML {
533531
denoiser = std::make_shared<DiscreteFlowDenoiser>();
534532
} else if (sd_version_is_flux(version)) {
535533
LOG_INFO("running in Flux FLOW mode");
536-
float shift = 1.15f;
537-
if (version == VERSION_FLUX_SCHNELL) {
538-
shift = 1.0f; // TODO: validate
534+
float shift = 1.0f; // TODO: validate
535+
for (auto pair : model_loader.tensor_storages_types) {
536+
if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) {
537+
shift = 1.15f;
538+
break;
539+
}
539540
}
540541
denoiser = std::make_shared<FluxFlowDenoiser>(shift);
541542
} else if (is_using_v_parameterization) {

0 commit comments

Comments
 (0)