@@ -489,7 +489,7 @@ def prepare_tensors(self):
             old_dtype = data_torch.dtype
 
             # convert any unsupported data types to float32
-            if data_torch.dtype not in (torch.float16, torch.float32):
+            if data_torch.dtype not in (torch.float16, torch.float32, torch.int32):
                 data_torch = data_torch.to(torch.float32)
 
             # use the first number-like part of the tensor name as the block id
@@ -7093,6 +7093,7 @@ def set_gguf_parameters(self):
             self.gguf_writer.add_rope_scaling_yarn_log_mul(0.1 * rope_scaling["mscale_all_dim"])
 
     _experts: list[dict[str, Tensor]] | None = None
+    _experts_s: list[dict[str, Tensor]] | None = None  # scale (for quantized experts)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # skip vision tensors and remove "language_model." for Kimi-VL
@@ -7120,28 +7121,41 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 
             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]
 
-            self._experts[bid][name] = data_torch
+            if self._experts_s is None:
+                self._experts_s = [{} for _ in range(self.block_count)]
 
-            if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
+            if name.endswith(".weight_packed"):
+                self._experts[bid][name] = data_torch
 
+            if name.endswith(".weight_scale"):
+                self._experts_s[bid][name] = data_torch
+
+            # TODO @ngxson : this is demo, won't compat with other models
+            if len(self._experts[bid]) + len(self._experts_s[bid]) >= n_experts * 3 * 2:
                 # merge the experts into a single 3d tensor
                 for w_name in ["down_proj", "gate_proj", "up_proj"]:
                     datas: list[Tensor] = []
+                    datas_s: list[Tensor] = []
 
                     for xid in range(n_experts):
-                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight_packed"
                         datas.append(self._experts[bid][ename])
                         del self._experts[bid][ename]
 
-                    data_torch = torch.stack(datas, dim=0)
+                        ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight_scale"
+                        datas_s.append(self._experts_s[bid][ename])
+                        del self._experts_s[bid][ename]
+
+                    data_packed = torch.stack(datas, dim=0)
+                    data_scale = torch.stack(datas_s, dim=0)
 
                     merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
 
                     new_name = self.map_tensor_name(merged_name)
 
-                    tensors.append((new_name, data_torch))
-                return tensors
+                    target_shape = (n_experts, data_packed.shape[1], data_packed.shape[2] * 8)  # each int32 packs eight 4-bit weights
+                    self.repack_compressed_tensor(new_name, data_packed, data_scale, target_shape)
+                return []
             else:
                 return []
 
@@ -7176,6 +7191,27 @@ def prepare_tensors(self):
         if len(experts) > 0:
             raise ValueError(f"Unprocessed experts: {experts}")
 
+    def repack_compressed_tensor(self, new_name: str, blocks: Tensor, scales: Tensor, shape: Sequence[int]):
+        assert blocks.dtype == torch.int32
+        assert len(blocks.shape) == 3
+        assert len(scales.shape) == 3
+        logger.info(f"Repacking compressed_tensor {new_name} with shape {shape}")
+        # flatten the first two dimensions
+        blocks = blocks.reshape(-1, blocks.shape[2])
+        scales = scales.reshape(-1, scales.shape[2])
+        # TODO: for kimi-k2, casting the bf16 scales to f16 may reduce the accuracy of the model;
+        # we have to do this because Q4_0 in GGUF only supports f16 scales
+        scales = scales.to(torch.float16)
+        scales = scales.view(torch.uint16).reshape(-1, 1)
+        repacked = blocks.reshape((blocks.shape[0] * blocks.shape[1]) // 4, 4)  # 4 int32 = one block of 32 4-bit weights
+        repacked = repacked.view(torch.uint16)
+        assert repacked.shape[0] == scales.shape[0]  # should have the same number of blocks
+        repacked = torch.concat([scales, repacked], dim=1)
+        repacked = repacked.view(torch.uint8)
+        shape_list = list(shape)
+        shape_list[-1] = (shape_list[-1] // 32) * 18  # 18 bytes per Q4_0 block of 32 weights
+        self.gguf_writer.add_tensor(new_name, repacked.numpy(), raw_dtype=gguf.GGMLQuantizationType.Q4_0, raw_shape=shape_list)
+
 
 @ModelBase.register("MiniMaxM2ForCausalLM")
 class MiniMaxM2Model(TextModel):
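
For context on the repack above: ggml's Q4_0 format stores each run of 32 weights as an 18-byte block, a float16 scale d followed by 16 bytes of 4-bit quants decoded as (nibble - 8) * d, with the low nibbles of the 16 bytes holding weights 0..15 and the high nibbles weights 16..31. Since each int32 of weight_packed carries eight 4-bit values, four int32 plus one scale form exactly one block, and the logical width of a row is weight_packed.shape[-1] * 8. The sketch below is a minimal standalone illustration of that block layout in plain NumPy; it is not part of this commit, and q4_0_pack_block / q4_0_unpack_block are hypothetical helper names.

import numpy as np

def q4_0_pack_block(weights: np.ndarray) -> bytes:
    # pack 32 floats into one Q4_0 block: 2-byte f16 scale + 16 bytes of nibbles
    assert weights.shape == (32,)
    amax = weights[np.argmax(np.abs(weights))]      # signed value with the largest magnitude
    d = amax / -8.0 if amax != 0 else 1.0           # ggml's Q4_0 scale convention
    q = np.clip(np.round(weights / d) + 8, 0, 15).astype(np.uint8)
    qs = (q[:16] | (q[16:] << 4)).astype(np.uint8)  # low nibbles: weights 0..15, high: 16..31
    return np.float16(d).tobytes() + qs.tobytes()

def q4_0_unpack_block(blk: bytes) -> np.ndarray:
    # decode one 18-byte block back to 32 floats, mirroring ggml's dequantization
    d = np.frombuffer(blk[:2], dtype=np.float16)[0].astype(np.float32)
    qs = np.frombuffer(blk[2:], dtype=np.uint8)
    lo = (qs & 0x0F).astype(np.int32) - 8
    hi = (qs >> 4).astype(np.int32) - 8
    return np.concatenate([lo, hi]).astype(np.float32) * d

w = np.random.randn(32).astype(np.float32)
blk = q4_0_pack_block(w)
assert len(blk) == 18                               # matches the (ne // 32) * 18 byte math in the diff
print(np.abs(q4_0_unpack_block(blk) - w).max())     # small quantization error

This also explains why shape_list[-1] is passed in bytes: as far as I can tell from gguf-py's writer, when add_tensor receives uint8 data with a quantized raw_dtype it converts the byte shape back to an element shape, so (elements // 32) * 18 round-trips to the true row width.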