@@ -1459,24 +1459,33 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s
14591459
14601460SDVersion ModelLoader::get_sd_version () {
14611461 TensorStorage token_embedding_weight, input_block_weight;
1462- bool is_xl = false ;
1462+ bool input_block_checked = false ;
1463+
1464+ bool is_xl = false ;
1465+ bool is_flux = false ;
1466+
1467+ #define found_family (is_xl || is_flux)
14631468 for (auto & tensor_storage : tensor_storages) {
1464- if (tensor_storage.name .find (" model.diffusion_model.double_blocks." ) != std::string::npos) {
1465- return VERSION_FLUX;
1466- }
1467- if (tensor_storage.name .find (" model.diffusion_model.joint_blocks." ) != std::string::npos) {
1468- return VERSION_SD3;
1469- }
1470- if (tensor_storage.name .find (" conditioner.embedders.1" ) != std::string::npos) {
1471- is_xl = true ;
1472- }
1473- if (tensor_storage.name .find (" cond_stage_model.1" ) != std::string::npos) {
1474- is_xl = true ;
1475- }
1476- if (tensor_storage.name .find (" model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor" ) != std::string::npos) {
1477- return VERSION_SVD;
1469+ if (!found_family) {
1470+ if (tensor_storage.name .find (" model.diffusion_model.double_blocks." ) != std::string::npos) {
1471+ is_flux = true ;
1472+ if (input_block_checked) {
1473+ break ;
1474+ }
1475+ }
1476+ if (tensor_storage.name .find (" model.diffusion_model.joint_blocks." ) != std::string::npos) {
1477+ return VERSION_SD3;
1478+ }
1479+ if (tensor_storage.name .find (" conditioner.embedders.1" ) != std::string::npos || tensor_storage.name .find (" cond_stage_model.1" ) != std::string::npos) {
1480+ is_xl = true ;
1481+ if (input_block_checked) {
1482+ break ;
1483+ }
1484+ }
1485+ if (tensor_storage.name .find (" model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor" ) != std::string::npos) {
1486+ return VERSION_SVD;
1487+ }
14781488 }
1479-
14801489 if (tensor_storage.name == " cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" ||
14811490 tensor_storage.name == " cond_stage_model.model.token_embedding.weight" ||
14821491 tensor_storage.name == " text_model.embeddings.token_embedding.weight" ||
@@ -1486,10 +1495,12 @@ SDVersion ModelLoader::get_sd_version() {
14861495 token_embedding_weight = tensor_storage;
14871496 // break;
14881497 }
1489- if (tensor_storage.name == " model.diffusion_model.input_blocks.0.0.weight" ) {
1490- input_block_weight = tensor_storage;
1491- if (is_xl)
1498+ if (tensor_storage.name == " model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == " model.diffusion_model.img_in.weight" ) {
1499+ input_block_weight = tensor_storage;
1500+ input_block_checked = true ;
1501+ if (found_family) {
14921502 break ;
1503+ }
14931504 }
14941505 }
14951506 bool is_inpaint = input_block_weight.ne [2 ] == 9 ;
@@ -1499,6 +1510,15 @@ SDVersion ModelLoader::get_sd_version() {
14991510 }
15001511 return VERSION_SDXL;
15011512 }
1513+
1514+ if (is_flux) {
1515+ is_inpaint = input_block_weight.ne [0 ] == 384 ;
1516+ if (is_inpaint) {
1517+ return VERSION_FLUX_INPAINT;
1518+ }
1519+ return VERSION_FLUX;
1520+ }
1521+
15021522 if (token_embedding_weight.ne [0 ] == 768 ) {
15031523 if (is_inpaint) {
15041524 return VERSION_SD1_INPAINT;
0 commit comments