Skip to content

Commit 350136f

Browse files
committed
Detect Flux fill models
Fix Flux fill detect
1 parent d741e2d commit 350136f

File tree

3 files changed

+44
-22
lines changed

3 files changed

+44
-22
lines changed

model.cpp

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,24 +1459,33 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
14591459

14601460
SDVersion ModelLoader::get_sd_version() {
14611461
TensorStorage token_embedding_weight, input_block_weight;
1462-
bool is_xl = false;
1462+
bool input_block_checked = false;
1463+
1464+
bool is_xl = false;
1465+
bool is_flux = false;
1466+
1467+
#define found_family (is_xl || is_flux)
14631468
for (auto& tensor_storage : tensor_storages) {
1464-
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
1465-
return VERSION_FLUX;
1466-
}
1467-
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
1468-
return VERSION_SD3;
1469-
}
1470-
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) {
1471-
is_xl = true;
1472-
}
1473-
if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
1474-
is_xl = true;
1475-
}
1476-
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
1477-
return VERSION_SVD;
1469+
if (!found_family) {
1470+
if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) {
1471+
is_flux = true;
1472+
if (input_block_checked) {
1473+
break;
1474+
}
1475+
}
1476+
if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) {
1477+
return VERSION_SD3;
1478+
}
1479+
if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos) {
1480+
is_xl = true;
1481+
if (input_block_checked) {
1482+
break;
1483+
}
1484+
}
1485+
if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) {
1486+
return VERSION_SVD;
1487+
}
14781488
}
1479-
14801489
if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
14811490
tensor_storage.name == "cond_stage_model.model.token_embedding.weight" ||
14821491
tensor_storage.name == "text_model.embeddings.token_embedding.weight" ||
@@ -1486,10 +1495,12 @@ SDVersion ModelLoader::get_sd_version() {
14861495
token_embedding_weight = tensor_storage;
14871496
// break;
14881497
}
1489-
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight") {
1490-
input_block_weight = tensor_storage;
1491-
if (is_xl)
1498+
if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight") {
1499+
input_block_weight = tensor_storage;
1500+
input_block_checked = true;
1501+
if (found_family) {
14921502
break;
1503+
}
14931504
}
14941505
}
14951506
bool is_inpaint = input_block_weight.ne[2] == 9;
@@ -1499,6 +1510,15 @@ SDVersion ModelLoader::get_sd_version() {
14991510
}
15001511
return VERSION_SDXL;
15011512
}
1513+
1514+
if (is_flux) {
1515+
is_inpaint = input_block_weight.ne[0] == 384;
1516+
if (is_inpaint) {
1517+
return VERSION_FLUX_INPAINT;
1518+
}
1519+
return VERSION_FLUX;
1520+
}
1521+
15021522
if (token_embedding_weight.ne[0] == 768) {
15031523
if (is_inpaint) {
15041524
return VERSION_SD1_INPAINT;

model.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,12 @@ enum SDVersion {
2727
VERSION_SVD,
2828
VERSION_SD3,
2929
VERSION_FLUX,
30+
VERSION_FLUX_INPAINT,
3031
VERSION_COUNT,
3132
};
3233

3334
static inline bool sd_version_is_flux(SDVersion version) {
34-
if (version == VERSION_FLUX) {
35+
if (version == VERSION_FLUX || version == VERSION_FLUX_INPAINT) {
3536
return true;
3637
}
3738
return false;
@@ -66,7 +67,7 @@ static inline bool sd_version_is_sdxl(SDVersion version) {
6667
}
6768

6869
static inline bool sd_version_is_inpaint(SDVersion version) {
69-
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT) {
70+
if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_INPAINT) {
7071
return true;
7172
}
7273
return false;

stable-diffusion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ const char* model_version_to_str[] = {
3333
"SDXL Inpaint",
3434
"SVD",
3535
"SD3.x",
36-
"Flux"};
36+
"Flux",
37+
"Flux Fill"};
3738

3839
const char* sampling_methods_str[] = {
3940
"Euler A",

0 commit comments

Comments
 (0)