@@ -9802,6 +9802,113 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
         return [(self.map_tensor_name(name), data_torch)]
 
+
+@ModelBase.register("JanusForConditionalGeneration")
+class JanusProModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA  # reuse Llama arch
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Skip vision, aligner, and generation tensors
+        skip_prefixes = (
+            'model.vision_model.',
+            'model.aligner.',
+            'model.vqmodel.',
+            'model.generation_embeddings.',
+            'model.generation_aligner.',
+            'model.generation_head.',
+        )
+        if name.startswith(skip_prefixes):
+            return []
+
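+        # HF checkpoints nest the text model under `model.language_model.` /
+        # `language_model.`; strip that prefix so LlamaModel's tensor mapping
+        # applies unchanged.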
+        if name.startswith('model.language_model.'):
+            name = name.replace('model.language_model.', 'model.')
+        elif name.startswith('language_model.'):
+            name = name.replace('language_model.', '')
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("JanusForConditionalGeneration")
+class JanusProVisionModel(MmprojModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams_vision is not None
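+        # The vision config may provide `mlp_ratio` instead of
+        # `intermediate_size`; derive the latter (hidden_size * mlp_ratio)
+        # so the usual vision hparams can be written below.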
+        if "intermediate_size" not in self.hparams_vision:
+            mlp_ratio = self.hparams_vision.get("mlp_ratio")
+            hidden_size = self.hparams_vision.get("hidden_size")
+            if mlp_ratio is not None and hidden_size is not None:
+                self.hparams_vision["intermediate_size"] = int(round(hidden_size * mlp_ratio))
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        assert self.hparams_vision is not None
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JANUS_PRO)
+
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("layer_norm_eps", 1e-6))
+
+        hidden_act = str(self.hparams_vision.get("hidden_act", "")).lower()
+        if hidden_act == "gelu":
+            self.gguf_writer.add_vision_use_gelu(True)
+        elif hidden_act == "silu":
+            self.gguf_writer.add_vision_use_silu(True)
+
+    def _map_aligner_tensor(self, data_torch: Tensor, name: str) -> Iterable[tuple[str, Tensor]]:
+        """Map aligner tensors to projector format"""
+        suffix = ".bias" if name.endswith(".bias") else ".weight"
+
+        if name.startswith("model.aligner."):
+            local_name = name[len("model.aligner."):]
+        elif name.startswith("aligner."):
+            local_name = name[len("aligner."):]
+        else:
+            raise ValueError(f"Unsupported Janus aligner prefix: {name}")
+
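+        # The aligner is a small MLP: `fc1` maps to mm.0 and
+        # `hidden_layers.N` maps to mm.(N+1) in the projector numbering.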
+        if local_name.startswith("fc1."):
+            mm_index = 0
+        elif local_name.startswith("hidden_layers."):
+            parts = local_name.split(".", 2)
+            if len(parts) < 3:
+                raise ValueError(f"Unexpected Janus aligner tensor name: {name}")
+            mm_index = int(parts[1]) + 1
+        else:
+            raise ValueError(f"Unsupported Janus aligner tensor: {name}")
+
+        tensor_name = self.format_tensor_name(gguf.MODEL_TENSOR.V_MMPROJ, mm_index, suffix=suffix)
+        return [(tensor_name, data_torch)]
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        # Skip language model tensors as they will be handled by `JanusProModel`
+        if name.startswith(('model.language_model.', 'language_model.')):
+            return []
+
+        # Skip generation-related components
+        skip_generation_prefixes = (
+            'model.vqmodel.',
+            'vqmodel.',
+            'model.generation_embeddings.',
+            'generation_embeddings.',
+            'model.generation_aligner.',
+            'generation_aligner.',
+            'model.generation_head.',
+            'generation_head.',
+        )
+        if name.startswith(skip_generation_prefixes):
+            return []
+
+        # Handle aligner tensors
+        if name.startswith(('model.aligner.', 'aligner.')):
+            return list(self._map_aligner_tensor(data_torch, name))
+
+        # Handle vision tensors
+        if name.startswith(('model.vision_model.', 'vision_model.')):
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return []
+
+
 ###### CONVERSION LOGIC ######
 
 