@@ -52,6 +52,7 @@ class ChatModelAP(IntEnum):
 ModelTypeTagChatAudioIn = ((ChatModelAP.Text.value + ChatModelAP.AudioInput.value) >> 1) << 24
 ModelTypeTagChatImageVideoIn = ((ChatModelAP.Text.value + ChatModelAP.ImageInput.value + ChatModelAP.VideoInput.value) >> 1) << 24
 ModelTypeTagChatImageVideoAudioInAudioOut = ((ChatModelAP.Text.value + ChatModelAP.ImageInput.value + ChatModelAP.VideoInput.value + ChatModelAP.AudioInput.value + ChatModelAP.AudioOutput.value) >> 1) << 24
+ModelTypeTagChatImageInImageOut = ((ChatModelAP.Text.value + ChatModelAP.ImageInput.value + ChatModelAP.ImageOutput.value) >> 1) << 24
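+# Hedged reading of the tag scheme (not stated in this diff): ChatModelAP
+# members look like distinct power-of-two flags, so the sum is a capability
+# bitmask; ">> 1" drops the always-set Text bit and "<< 24" parks the mask in
+# the high byte, e.g. tag = ((Text | ImageInput | ImageOutput) >> 1) << 24,
+# leaving the low 24 bits free for a per-family model index.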

 class ModelType(Enum):
     CHATGLM = 1
@@ -244,6 +245,8 @@ class ModelType(Enum):

     MiniCPM_O = ModelTypeTagChatImageVideoAudioInAudioOut + 0x0000001

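+    # Assumption: "+ 0x0000001" is the first model index within this new
+    # capability family (low 24 bits), following the MiniCPM_O pattern above.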
+    JanusPro = ModelTypeTagChatImageInImageOut + 0x0000001
+

 class TokenType(Enum):
     UNDEFINED = 0
     NORMAL = 1
@@ -7657,6 +7660,283 @@ def get_weight_names(config):
     weight_names.sort()
     return weight_names

+class JanusConverter(BaseConverter):
+    MODEL_TYPE = ModelType.JanusPro
+    lang_config = {}
+
+    @staticmethod
+    def is_proper_config(config):
+        try:
+            return config.aligner_config['cls'] == 'MlpProjector' and \
+                   config.gen_aligner_config['cls'] == 'MlpProjector' and \
+                   config.gen_head_config['cls'] == 'vision_head' and \
+                   config.gen_vision_config['cls'] == 'VQ-16' and \
+                   config.language_config['model_type'] == 'llama' and \
+                   config.model_type == 'multi_modality' and \
+                   config.vision_config['cls'] == 'CLIPVisionTower'
+        except Exception:
+            return False
+
+    @classmethod
+    def state_dict_pp(cls, config, state_dict):
+        r = {}
+        for k in state_dict:
+            name: str = k
+            t: torch.Tensor = state_dict[k]
+
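+            # Routing sketch (my summary, not part of the original diff):
+            # four sub-models share one checkpoint. language_model.* goes
+            # through the DeepSeek renamer, gen_vision_model.* is the VQ
+            # image tokenizer, vision_model.* is the SigLIP ViT encoder,
+            # and the remainder covers the aligners and generation heads.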
+            if name.startswith('language_model.'):
+                name = name.replace('language_model.', '')
+                r[name] = DeepSeekConverter.pp(JanusConverter.lang_config, name, t)
+            elif name.startswith('gen_vision_model'):
+                name = name.replace('.k.', '.k_proj.')
+                name = name.replace('.q.', '.q_proj.')
+                name = name.replace('.v.', '.v_proj.')
+                name = name.replace('.proj_out.', '.o_proj.')
+                r[name] = t
+            elif name.startswith('vision_model'):
+                name = name.replace('.vision_tower.blocks.', '.layers.')
+                name = name.replace('.vision_tower.', '.')
+                if '.attn.' in name:
+                    name = name.replace('.proj.', '.o_proj.')
+                    if '.qkv.' in name:
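+                        # Assumption: the fused qkv tensor stacks [q; k; v]
+                        # along dim 0, so it splits into three equal parts.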
+                        n = t.shape[0] // 3
+                        q, k, v = t.split([n, n, n])
+                        r[name.replace('.qkv.', '.q_proj.')] = q
+                        r[name.replace('.qkv.', '.k_proj.')] = k
+                        r[name.replace('.qkv.', '.v_proj.')] = v
+                    else:
+                        r[name] = t
+                else:
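+                    # The timm ViT names its MLP layers fc1/fc2; this
+                    # converter's layout appears to be 0-based (fc0/fc1),
+                    # hence the shift below.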
+                    if 'mlp.fc1.' in name:
+                        name = name.replace('.fc1.', '.fc0.')
+                    elif 'mlp.fc2.' in name:
+                        name = name.replace('.fc2.', '.fc1.')
+
+                    if name == 'vision_model.pos_embed':
+                        assert t.shape[0] == 1
+                        t = t[0]
+                    r[name] = t
+            elif name.startswith('aligner.') or name.startswith('gen_aligner.'):
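+                # MlpProjector is presumably Sequential(Linear, activation,
+                # Linear), so children 0 and 2 are the two Linear layers,
+                # mapped here to fc0 / fc1.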
+                name = name.replace('.layers.0.', '.fc0.')
+                name = name.replace('.layers.2.', '.fc1.')
+                r[name] = t
+            else:
+                r[name] = t
+
+        return r
+
+    @staticmethod
+    def dump_config(f, config, ggml_type):
+        assert config.vision_config['params']['model_name'] == 'siglip_large_patch16_384'
+        JanusConverter.lang_config = AttributeDict(config.language_config)
+
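+        # Assumption: the llama sub-config may not spell out hidden_act, and
+        # the DeepSeek dumper expects one, so llama's default (silu) is pinned.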
+        JanusConverter.lang_config.hidden_act = 'silu'
+
+        DeepSeekConverter.dump_config(f, JanusConverter.lang_config, ggml_type)
+
+    @staticmethod
+    def get_aligner_weight_names(config, prefix):
+        weight_names = [
+            prefix + 'fc0.bias',
+            prefix + 'fc0.weight',
+            prefix + 'fc1.bias',
+            prefix + 'fc1.weight',
+        ]
+        return weight_names
+
+    @staticmethod
+    def get_vis_model_weight_names(prefix, config):
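+        # ModelArgs mirrors the VQ-16 defaults from the Janus repo; assumption:
+        # config.gen_vision_config carries no overrides that matter here.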
+        class ModelArgs:
+            codebook_size: int = 16384
+            codebook_embed_dim: int = 8
+            codebook_l2_norm: bool = True
+            codebook_show_usage: bool = True
+            commit_loss_beta: float = 0.25
+            entropy_loss_ratio: float = 0.0
+
+            encoder_ch_mult: List[int] = [1, 1, 2, 2, 4]
+            decoder_ch_mult: List[int] = [1, 1, 2, 2, 4]
+            z_channels: int = 256
+            dropout_p: float = 0.0
+
+        args = ModelArgs()
+        weight_names = []
+
+        def get_res_names(name_prefix: str, in_channels, out_channels):
+            weight_names = [
+                f"{name_prefix}.conv1.bias",
+                f"{name_prefix}.conv1.weight",
+                f"{name_prefix}.conv2.bias",
+                f"{name_prefix}.conv2.weight",
+                f"{name_prefix}.norm1.bias",
+                f"{name_prefix}.norm1.weight",
+                f"{name_prefix}.norm2.bias",
+                f"{name_prefix}.norm2.weight",
+            ]
+            if in_channels != out_channels:
+                weight_names += [
+                    f"{name_prefix}.nin_shortcut.weight",
+                    f"{name_prefix}.nin_shortcut.bias",
+                ]
+            return weight_names
+
+        def get_attn_names(name_prefix: str):
+            weight_names = [
+                f"{name_prefix}.q_proj.weight",
+                f"{name_prefix}.q_proj.bias",
+                f"{name_prefix}.k_proj.weight",
+                f"{name_prefix}.k_proj.bias",
+                f"{name_prefix}.v_proj.weight",
+                f"{name_prefix}.v_proj.bias",
+                f"{name_prefix}.o_proj.weight",
+                f"{name_prefix}.o_proj.bias",
+                f"{name_prefix}.norm.weight",
+                f"{name_prefix}.norm.bias",
+            ]
+            return weight_names
+
+        def get_decoder_names(num_res_blocks=2):
+            nonlocal args, weight_names
+            num_resolutions = len(args.decoder_ch_mult)
+
+            weight_names += [
+                f"{prefix}decoder.conv_in.bias",
+                f"{prefix}decoder.conv_in.weight",
+                f"{prefix}decoder.conv_out.bias",
+                f"{prefix}decoder.conv_out.weight",
+                f"{prefix}decoder.norm_out.bias",
+                f"{prefix}decoder.norm_out.weight",
+            ]
+
+            ch_mult = args.decoder_ch_mult
+            block_in = ch_mult[num_resolutions - 1]
+
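+            # The mid blocks keep in == out channels, so the (1, 1) arguments
+            # below just tell get_res_names not to emit nin_shortcut weights.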
+            weight_names += get_res_names(prefix + "decoder.mid.0", 1, 1)
+            weight_names += get_attn_names(prefix + "decoder.mid.1")
+            weight_names += get_res_names(prefix + "decoder.mid.2", 1, 1)
+
+            for i_level in range(num_resolutions):
+                name_prefix = f"{prefix}decoder.conv_blocks.{i_level}"
+                block_out = ch_mult[num_resolutions - i_level - 1]
+
+                for j in range(num_res_blocks + 1):
+                    weight_names += get_res_names(name_prefix + f".res.{j}", block_in, block_out)
+                    block_in = block_out
+                    if i_level == 0:
+                        weight_names += get_attn_names(name_prefix + f'.attn.{j}')
+                if i_level != num_resolutions - 1:
+                    weight_names += [
+                        f"{name_prefix}.upsample.conv.bias",
+                        f"{name_prefix}.upsample.conv.weight",
+                    ]
+
+        def get_encoder_names(num_res_blocks=2):
+            nonlocal args, weight_names
+            num_resolutions = len(args.encoder_ch_mult)
+
+            weight_names += [
+                f"{prefix}encoder.conv_in.bias",
+                f"{prefix}encoder.conv_in.weight",
+                f"{prefix}encoder.conv_out.bias",
+                f"{prefix}encoder.conv_out.weight",
+                f"{prefix}encoder.norm_out.bias",
+                f"{prefix}encoder.norm_out.weight",
+            ]
+
+            weight_names += get_res_names(prefix + "encoder.mid.0", 1, 1)
+            weight_names += get_attn_names(prefix + "encoder.mid.1")
+            weight_names += get_res_names(prefix + "encoder.mid.2", 1, 1)
+
+            ch_mult = args.encoder_ch_mult
+            in_ch_mult = (1,) + tuple(ch_mult)
+
+            for i_level in range(num_resolutions):
+                name_prefix = f"{prefix}encoder.conv_blocks.{i_level}"
+
+                block_in = in_ch_mult[i_level]
+                block_out = ch_mult[i_level]
+
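+                # Note (my reading): the encoder runs num_res_blocks res
+                # blocks per level vs. the decoder's num_res_blocks + 1, and
+                # attention appears only at the lowest-resolution level in
+                # both (i_level == num_resolutions - 1 here, 0 in the decoder).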
+                for j in range(num_res_blocks):
+                    weight_names += get_res_names(name_prefix + f".res.{j}", block_in, block_out)
+                    block_in = block_out
+                    if i_level == num_resolutions - 1:
+                        weight_names += get_attn_names(name_prefix + f'.attn.{j}')
+
+                if i_level != num_resolutions - 1:
+                    weight_names += [
+                        f"{name_prefix}.downsample.conv.bias",
+                        f"{name_prefix}.downsample.conv.weight",
+                    ]
+
+        get_decoder_names()
+        get_encoder_names()
+        return weight_names
+
+    @staticmethod
+    def get_weight_names(config):
+        weight_names = DeepSeekConverter.get_weight_names(JanusConverter.lang_config)
+        weight_names += JanusConverter.get_aligner_weight_names(config.aligner_config, 'aligner.')
+        weight_names += JanusConverter.get_aligner_weight_names(config.gen_aligner_config, 'gen_aligner.')
+        weight_names += ["gen_embed.weight",
+                         "gen_head.output_mlp_projector.bias",
+                         "gen_head.output_mlp_projector.weight",
+                         "gen_head.vision_head.bias",
+                         "gen_head.vision_head.weight",
+                         "gen_vision_model.post_quant_conv.bias",
+                         "gen_vision_model.post_quant_conv.weight",
+                         "gen_vision_model.quant_conv.bias",
+                         "gen_vision_model.quant_conv.weight",
+                         "gen_vision_model.quantize.codebook_used",
+                         "gen_vision_model.quantize.embedding.weight",
+                         "vision_model.norm.bias",
+                         "vision_model.norm.weight",
+                         "vision_model.patch_embed.proj.bias",
+                         "vision_model.patch_embed.proj.weight",
+                         "vision_model.pos_embed",
+                         # attn_pool is not used, see
+                         # https://github.com/deepseek-ai/Janus/blob/1daa72fa409002d40931bd7b36a9280362469ead/janus/models/siglip_vit.py#L667
+                         ]
+
+        vision_cfg = AttributeDict({
+            "image_size": 384,
+            "patch_size": 16,
+            "width": 1024,
+            "layers": 24,
+            "heads": 16,
+            "mlp_ratio": 4,
+            "global_pool": "map",
+            "use_checkpoint": False,
+        })
+        select_layer = config.vision_config['params']['select_layer']
+        if select_layer <= 0:
+            layers = min(vision_cfg.layers, vision_cfg.layers + select_layer + 1)
+        else:
+            layers = min(vision_cfg.layers, select_layer)
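+        # Worked example (assuming the shipped config uses select_layer == -1):
+        # 24 + (-1) + 1 == 24, so all 24 transformer blocks are exported.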
+
+        for i in range(layers):
+            weight_names += [
+                f"vision_model.layers.{i}.attn.q_proj.bias",
+                f"vision_model.layers.{i}.attn.q_proj.weight",
+                f"vision_model.layers.{i}.attn.k_proj.bias",
+                f"vision_model.layers.{i}.attn.k_proj.weight",
+                f"vision_model.layers.{i}.attn.v_proj.bias",
+                f"vision_model.layers.{i}.attn.v_proj.weight",
+                f"vision_model.layers.{i}.attn.o_proj.bias",
+                f"vision_model.layers.{i}.attn.o_proj.weight",
+                f"vision_model.layers.{i}.mlp.fc0.bias",
+                f"vision_model.layers.{i}.mlp.fc0.weight",
+                f"vision_model.layers.{i}.mlp.fc1.bias",
+                f"vision_model.layers.{i}.mlp.fc1.weight",
+                f"vision_model.layers.{i}.norm1.bias",
+                f"vision_model.layers.{i}.norm1.weight",
+                f"vision_model.layers.{i}.norm2.bias",
+                f"vision_model.layers.{i}.norm2.weight",
+            ]
+
+        weight_names += JanusConverter.get_vis_model_weight_names('gen_vision_model.', config.gen_vision_config)
+
+        return weight_names
+
 def convert_grok_1_base(args, vocab, ggml_type):
     def ffn_size(emb_size, widening_factor):
         _ffn_size = int(widening_factor * emb_size) * 2 // 3
@@ -8262,6 +8542,9 @@ def main():
         ApertusConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch.endswith('GroveMoeForCausalLM'):
         GroveMoEConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
+    elif arch == 'MultiModalityCausalLM':
+        assert JanusConverter.is_proper_config(config)
+        JanusConverter.convert(config, model_files, vocab, ggml_type, args.save_path)
     elif arch == 'deepseek-r1-distill-qwen3':
         QWen3Converter.MODEL_TYPE = ModelType.DeepSeek_R1_Distill_QWen3
         QWen3Converter.convert(config, model_files, vocab, ggml_type, args.save_path)