@@ -558,6 +558,26 @@ std::string convert_tensor_name(std::string name) {
558558 return new_name;
559559}
560560
561+ void add_preprocess_tensor_storage_types (std::map<std::string, enum ggml_type>& tensor_storages_types, std::string name, enum ggml_type type) {
562+ std::string new_name = convert_tensor_name (name);
563+
564+ if (new_name.find (" cond_stage_model" ) != std::string::npos && ends_with (new_name, " attn.in_proj_weight" )) {
565+ size_t prefix_size = new_name.find (" attn.in_proj_weight" );
566+ std::string prefix = new_name.substr (0 , prefix_size);
567+ tensor_storages_types[prefix + " self_attn.q_proj.weight" ] = type;
568+ tensor_storages_types[prefix + " self_attn.k_proj.weight" ] = type;
569+ tensor_storages_types[prefix + " self_attn.v_proj.weight" ] = type;
570+ } else if (new_name.find (" cond_stage_model" ) != std::string::npos && ends_with (new_name, " attn.in_proj_bias" )) {
571+ size_t prefix_size = new_name.find (" attn.in_proj_bias" );
572+ std::string prefix = new_name.substr (0 , prefix_size);
573+ tensor_storages_types[prefix + " self_attn.q_proj.bias" ] = type;
574+ tensor_storages_types[prefix + " self_attn.k_proj.bias" ] = type;
575+ tensor_storages_types[prefix + " self_attn.v_proj.bias" ] = type;
576+ } else {
577+ tensor_storages_types[new_name] = type;
578+ }
579+ }
580+
561581void preprocess_tensor (TensorStorage tensor_storage,
562582 std::vector<TensorStorage>& processed_tensor_storages) {
563583 std::vector<TensorStorage> result;
@@ -927,7 +947,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s
927947 GGML_ASSERT (ggml_nbytes (dummy) == tensor_storage.nbytes ());
928948
929949 tensor_storages.push_back (tensor_storage);
930- tensor_storages_types[ tensor_storage.name ] = tensor_storage.type ;
950+ add_preprocess_tensor_storage_types ( tensor_storages_types, tensor_storage.name , tensor_storage.type ) ;
931951 }
932952
933953 gguf_free (ctx_gguf_);
@@ -1072,7 +1092,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
10721092 }
10731093
10741094 tensor_storages.push_back (tensor_storage);
1075- tensor_storages_types[ tensor_storage.name ] = tensor_storage.type ;
1095+ add_preprocess_tensor_storage_types ( tensor_storages_types, tensor_storage.name , tensor_storage.type ) ;
10761096
10771097 // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str());
10781098 }
@@ -1403,7 +1423,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer,
14031423 // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str());
14041424 reader.tensor_storage .name = prefix + reader.tensor_storage .name ;
14051425 tensor_storages.push_back (reader.tensor_storage );
1406- tensor_storages_types[ reader.tensor_storage .name ] = reader.tensor_storage .type ;
1426+ add_preprocess_tensor_storage_types ( tensor_storages_types, reader.tensor_storage .name , reader.tensor_storage .type ) ;
14071427
14081428 // LOG_DEBUG("%s", reader.tensor_storage.name.c_str());
14091429 // reset
@@ -1461,10 +1481,10 @@ SDVersion ModelLoader::get_sd_version() {
14611481 TensorStorage token_embedding_weight, input_block_weight;
14621482 bool input_block_checked = false ;
14631483
1464- bool has_multiple_encoders = false ;
1465- bool is_unet = false ;
1484+ bool has_multiple_encoders = false ;
1485+ bool is_unet = false ;
14661486
1467- bool is_xl = false ;
1487+ bool is_xl = false ;
14681488 bool is_flux = false ;
14691489
14701490#define found_family (is_xl || is_flux)
@@ -1481,7 +1501,7 @@ SDVersion ModelLoader::get_sd_version() {
14811501 }
14821502 if (tensor_storage.name .find (" model.diffusion_model.input_blocks." ) != std::string::npos) {
14831503 is_unet = true ;
1484- if (has_multiple_encoders){
1504+ if (has_multiple_encoders) {
14851505 is_xl = true ;
14861506 if (input_block_checked) {
14871507 break ;
@@ -1490,7 +1510,7 @@ SDVersion ModelLoader::get_sd_version() {
14901510 }
14911511 if (tensor_storage.name .find (" conditioner.embedders.1" ) != std::string::npos || tensor_storage.name .find (" cond_stage_model.1" ) != std::string::npos) {
14921512 has_multiple_encoders = true ;
1493- if (is_unet){
1513+ if (is_unet) {
14941514 is_xl = true ;
14951515 if (input_block_checked) {
14961516 break ;
@@ -1635,11 +1655,20 @@ ggml_type ModelLoader::get_vae_wtype() {
16351655void ModelLoader::set_wtype_override (ggml_type wtype, std::string prefix) {
16361656 for (auto & pair : tensor_storages_types) {
16371657 if (prefix.size () < 1 || pair.first .substr (0 , prefix.size ()) == prefix) {
1658+ bool found = false ;
16381659 for (auto & tensor_storage : tensor_storages) {
1639- if (tensor_storage.name == pair.first ) {
1640- if (tensor_should_be_converted (tensor_storage, wtype)) {
1641- pair.second = wtype;
1660+ std::map<std::string, ggml_type> temp;
1661+ add_preprocess_tensor_storage_types (temp, tensor_storage.name , tensor_storage.type );
1662+ for (auto & preprocessed_name : temp) {
1663+ if (preprocessed_name.first == pair.first ) {
1664+ if (tensor_should_be_converted (tensor_storage, wtype)) {
1665+ pair.second = wtype;
1666+ }
1667+ found = true ;
1668+ break ;
16421669 }
1670+ }
1671+ if (found) {
16431672 break ;
16441673 }
16451674 }
0 commit comments