@@ -432,6 +432,9 @@ def load_hparams(dir_model: Path):
432432 if "llm_config" in config :
433433 # rename for InternVL
434434 config ["text_config" ] = config ["llm_config" ]
435+ if "language_config" in config :
436+ # rename for Janus Pro
437+ config ["text_config" ] = config ["language_config" ]
435438 return config
436439
437440 @classmethod
@@ -1975,6 +1978,31 @@ def prepare_tensors(self):
19751978 raise ValueError (f"Unprocessed experts: { experts } " )
19761979
19771980
@ModelBase.register("JanusProForCausalLM")
class JanusProModel(TextModel):
    """Convert the language-model part of a Janus Pro checkpoint.

    The text decoder is exported with the LLAMA architecture. Only tensors
    whose name contains ``language_model.`` are converted; every other
    tensor in the checkpoint is dropped by :meth:`modify_tensors`.
    """

    model_arch = gguf.MODEL_ARCH.LLAMA
    undo_permute = True

    def __init__(self, *args, **kwargs):
        """Initialise the converter and fill in missing hyper-parameters.

        Values already present in ``self.hparams`` are never overwritten;
        only absent keys receive a fallback.
        """
        super().__init__(*args, **kwargs)
        hparams = self.hparams
        # NOTE(review): the numeric fallbacks look like Llama-7B-sized
        # defaults for configs that omit these fields -- confirm against the
        # published Janus Pro configs.
        hparams.setdefault("num_attention_heads", 32)
        # A config without an explicit KV-head count means plain multi-head
        # attention, so the fallback must follow the (possibly explicit)
        # attention-head count instead of a fixed constant. With the default
        # head count this still evaluates to 32.
        hparams.setdefault("num_key_value_heads", hparams["num_attention_heads"])
        hparams.setdefault("hidden_size", 4096)
        hparams.setdefault("intermediate_size", 11008)
        hparams.setdefault("rms_norm_eps", 1e-6)

    def set_gguf_parameters(self):
        """Write the inherited GGUF parameters plus the chat template entry."""
        super().set_gguf_parameters()
        self.gguf_writer.add_chat_template("janus-pro")

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one checkpoint tensor to its output tensor(s).

        Args:
            data_torch: The tensor data read from the checkpoint.
            name: The tensor name as stored in the checkpoint.
            bid: Block index forwarded unchanged to the parent class
                (may be ``None``).

        Returns:
            The parent class' result for the name with ``language_model.``
            removed, or an empty list when the name does not contain
            ``language_model.``.
        """
        if "language_model." not in name:
            # Not part of the text decoder: skip it.
            return []
        stripped_name = name.replace("language_model.", "")
        return super().modify_tensors(data_torch, stripped_name, bid)


2005+
19782006@ModelBase .register (
19792007 "LlavaForConditionalGeneration" , # pixtral
19802008 "Mistral3ForConditionalGeneration" , # mistral small 3.1
@@ -6222,6 +6250,9 @@ def split_str_to_n_bytes(split_str: str) -> int:
62226250
62236251
62246252def get_model_architecture (hparams : dict [str , Any ], model_type : ModelType ) -> str :
6253+ # exception: Janus Pro
6254+ if "aligner_config" in hparams :
6255+ return "JanusProForCausalLM"
62256256 # TODO @ngxson : this won't work correctly if the model has both audio & vision encoders
62266257 # maybe we should fallback to text model's arch in that case, since not many models have both
62276258 text_config = hparams .get ("text_config" , {})
0 commit comments