@@ -102,7 +102,7 @@ def _replace_linear_with_linear_8da4w_for_spin_quant(
 ):
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
         # Only replace linear layers where the checkpoint contains explicit scales
-        scales_key = f"{cur_fqn}.scale"
+        scales_key = f"{cur_fqn}.scales"
         if isinstance(child, nn.Linear) and scales_key in checkpoint:
             assert _check_linear_int4_k(child.in_features, group_size)
             assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
@@ -155,7 +155,7 @@ def _replace_output_linear_with_linear_int8_for_spinquant(
     dtype: torch.dtype,
 ):
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
-        scales_key = f"{cur_fqn}.scale"
+        scales_key = f"{cur_fqn}.scales"
         if (
             isinstance(child, nn.Linear)
             and scales_key in checkpoint
@@ -205,7 +205,7 @@ def _replace_embedding_with_quantized_group_embedding_for_spinquant(
 ):
     def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
         # Only replace embedding layers where the checkpoint contains explicit scales
-        scales_key = f"{cur_fqn}.scale"
+        scales_key = f"{cur_fqn}.scales"
         if isinstance(child, nn.Embedding) and scales_key in checkpoint:
             assert checkpoint[f"{cur_fqn}.weight"].dtype == torch.int8
             assert checkpoint[scales_key].dtype == torch.float32
@@ -250,59 +250,12 @@ def transform_embedding_for_spinquant(
 
 
 def sanitize_checkpoint_from_spinquant(
-    module: torch.nn.Module,
     checkpoint: Any,
-    linear_group_size: int,
-    embedding_group_size: Optional[int] = None,
 ):
     """
     Sanitize the SpinQuant checkpoint.
-        - Renames 'scale' to 'scales'
-        - Groups scales
-        - Removes 'o_weight'
         - Converts all tensors to contiguous format
+        - Squeeze all tensors
     """
-    keys_to_rename = []
-    keys_to_remove = []
-    for k, _ in checkpoint.items():
-        if k.endswith(".scale"):
-            new_key = k + "s"
-            keys_to_rename.append((k, new_key))
-        if k.endswith(".o_weight"):
-            keys_to_remove.append(k)
-
-    for old_key, new_key in keys_to_rename:
-        old_val = checkpoint.pop(old_key)
-        module_name = new_key[0 : new_key.rfind(".")]
-        sub_module = module.get_submodule(module_name)
-        assert sub_module is not None
-        assert (
-            isinstance(sub_module, Int8DynActInt4WeightLinear)
-            or isinstance(sub_module, QuantizedGroupEmbedding)
-            or isinstance(sub_module, Int8DynActInt8WeightLinear)
-        )
-        # Checkpoints with SpinQuant could come with two formats for scales:
-        # 1. scales is grouped by group size
-        # 2. scales is not grouped by group size
-        # We need to handle both cases here.
-        # TODO(lunwenh): remove this once we have a unified format for scales.
-        if isinstance(sub_module, Int8DynActInt4WeightLinear):
-            checkpoint[new_key] = (
-                old_val if linear_group_size == -1 else old_val[:, ::linear_group_size]
-            )
-        elif isinstance(sub_module, Int8DynActInt8WeightLinear):
-            checkpoint[new_key] = old_val[:, 0]
-        elif isinstance(sub_module, QuantizedGroupEmbedding):
-            if (
-                embedding_group_size is None or embedding_group_size == 0
-            ):  # Scales are not grouped
-                checkpoint[new_key] = old_val[:, 0]
-            elif embedding_group_size == -1:  # Scales are grouped by group size
-                checkpoint[new_key] = old_val
-            else:
-                checkpoint[new_key] = old_val[:, ::embedding_group_size]
-
-    for k in keys_to_remove:
-        checkpoint.pop(k)
     for k, v in checkpoint.items():
-        checkpoint[k] = v.contiguous()
+        checkpoint[k] = torch.squeeze(v.contiguous())
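For context, here is a minimal sketch of what the simplified sanitization loop now does to a checkpoint entry. The keys and tensor shapes are purely illustrative (not taken from a real SpinQuant export); the assumption is that per-channel scales arrive with a trailing singleton dimension that the new torch.squeeze call drops.

import torch

# Hypothetical checkpoint entries: an int8 weight plus per-channel scales
# stored as [out_features, 1] (illustrative shapes only).
checkpoint = {
    "layers.0.attention.wq.weight": torch.randint(-128, 127, (64, 64), dtype=torch.int8),
    "layers.0.attention.wq.scales": torch.rand(64, 1, dtype=torch.float32),
}

# Same transformation as the new loop: make each tensor contiguous,
# then remove singleton dimensions.
for k, v in checkpoint.items():
    checkpoint[k] = torch.squeeze(v.contiguous())

assert checkpoint["layers.0.attention.wq.scales"].shape == (64,)     # squeezed
assert checkpoint["layers.0.attention.wq.weight"].shape == (64, 64)  # unchanged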