@@ -121,20 +121,76 @@ def add_weight_and_scale_mapping(
121121 if weight_scale_mlc_name in named_parameters :
122122 weight_scale_hf_names = [f"{ name } _scale_inv" for name in weight_hf_names ]
123123 weight_scale_param = named_parameters [weight_scale_mlc_name ]
124+ expected_weight_scale_shape = tuple (int (dim ) for dim in weight_scale_param .shape )
125+
def _weight_scale_transform(*arrays, dtype: str, _transform=weight_transform_func):
    """Run the weight transform on scale tensors and fit the result to the parameter shape.

    Each input is converted with ``np.asarray``; 0-d inputs are promoted to
    shape ``(1,)`` before ``_transform`` is called.  The transform output is
    then cast to ``dtype`` and adapted to ``expected_weight_scale_shape``:

    * already the expected shape -> returned as is;
    * 0-d -> the single value fills the expected shape;
    * shape ``(1,)`` -> broadcast to the expected shape;
    * 1-D with ``k > 1`` entries, when the expected shape has rank >= 2 and
      its leading dimension is a multiple of ``k`` -> each entry is repeated
      over ``expected[0] // k`` consecutive rows and broadcast over every
      trailing dimension.

    ``_transform`` is bound as a default argument so the closure keeps the
    transform that was current when this helper was defined.

    Raises:
        ValueError: if the transform output cannot be adapted by any of the
            rules above.
    """
    processed = []
    for arr in arrays:
        arr_np = np.asarray(arr)
        # NOTE(review): presumably scalar scales are promoted so transforms
        # that concatenate/stack their inputs still work -- confirm.
        if arr_np.ndim == 0:
            arr_np = arr_np.reshape((1,))
        processed.append(arr_np)

    result = np.asarray(_transform(*processed, dtype=dtype), dtype=dtype)
    if result.shape == expected_weight_scale_shape:
        return result
    if result.shape == ():
        return np.full(expected_weight_scale_shape, result.item(), dtype=dtype)
    if result.shape == (1,) and expected_weight_scale_shape != (1,):
        # broadcast_to yields a read-only view; astype materialises a copy.
        return np.broadcast_to(result, expected_weight_scale_shape).astype(dtype)
    if (
        result.ndim == 1
        and result.size > 1
        and len(expected_weight_scale_shape) >= 2
        and expected_weight_scale_shape[0] % result.size == 0
    ):
        rows_per_segment = expected_weight_scale_shape[0] // result.size
        tiled = np.repeat(result, rows_per_segment)
        # Keep the per-row values on axis 0 for targets of any rank.  A plain
        # (rows, 1) shape would be right-aligned by NumPy broadcasting and, for
        # rank >= 3 targets, either fail or land on the wrong axis.
        trailing_ones = (1,) * (len(expected_weight_scale_shape) - 1)
        tiled = tiled.reshape((expected_weight_scale_shape[0],) + trailing_ones)
        return np.broadcast_to(tiled, expected_weight_scale_shape).astype(dtype)
    raise ValueError(
        f"Unexpected weight scale shape {result.shape} for "
        f"{weight_scale_mlc_name}, expected {expected_weight_scale_shape}"
    )
124155 mapping .add_mapping (
125156 weight_scale_mlc_name ,
126157 weight_scale_hf_names ,
127- functools .partial (weight_transform_func , dtype = weight_scale_param .dtype ),
158+ functools .partial (_weight_scale_transform , dtype = weight_scale_param .dtype ),
128159 )
129160 activation_scale_mlc_name = f"{ weight_mlc_name [: - len ('.weight' )]} .activation_scale"
130161 if activation_scale_mlc_name in named_parameters :
131162 activation_scale_hf_names = [f"{ name [: - len ('.weight' )]} .activation_scale" for name in weight_hf_names ]
132163 activation_scale_param = named_parameters [activation_scale_mlc_name ]
133164 transform = activation_transform_func or weight_transform_func
165+ expected_shape = tuple (int (dim ) for dim in activation_scale_param .shape )
166+
def _activation_scale_transform(*arrays, dtype: str, _transform=transform):
    """Run the activation transform and fit the result to the parameter shape.

    The output of ``_transform`` is cast to ``dtype`` and adapted to
    ``expected_shape``:

    * already the expected shape -> returned as is;
    * 0-d -> the single value fills the expected shape;
    * shape ``(1,)`` -> broadcast to the expected shape;
    * 1-D with ``k > 1`` entries, when the expected leading dimension is a
      multiple of ``k`` and every trailing dimension is 1 -> each entry is
      repeated ``expected[0] // k`` times and reshaped to the expected shape.

    ``_transform`` is bound as a default argument so the closure keeps the
    transform that was current when this helper was defined.

    Raises:
        ValueError: if the transform output cannot be adapted by any of the
            rules above.
    """
    result = np.asarray(_transform(*arrays, dtype=dtype), dtype=dtype)
    if result.shape == expected_shape:
        return result
    if result.shape == ():
        # HF checkpoint stores a single scale; broadcast across the expected dimension.
        return np.full(expected_shape, result.item(), dtype=dtype)
    if result.shape == (1,) and expected_shape != (1,):
        # broadcast_to yields a read-only view; astype materialises a copy.
        return np.broadcast_to(result, expected_shape).astype(dtype)
    if (
        result.ndim == 1
        and result.size > 1
        and len(expected_shape) >= 1
        and expected_shape[0] % result.size == 0
        # The tiled vector has exactly expected_shape[0] entries, so the
        # reshape below can only succeed when all trailing dims are 1.
        # Anything else falls through to the descriptive error instead of
        # surfacing NumPy's generic reshape failure.
        and all(dim == 1 for dim in expected_shape[1:])
    ):
        rows_per_segment = expected_shape[0] // result.size
        tiled = np.repeat(result, rows_per_segment)
        return tiled.reshape(expected_shape).astype(dtype)
    raise ValueError(
        f"Unexpected activation scale shape {result.shape} for "
        f"{activation_scale_mlc_name}, expected {expected_shape}"
    )
134190 mapping .add_mapping (
135191 activation_scale_mlc_name ,
136192 activation_scale_hf_names ,
137- functools .partial (transform , dtype = activation_scale_param .dtype ),
193+ functools .partial (_activation_scale_transform , dtype = activation_scale_param .dtype ),
138194 )
139195
140196 def identity_transform (param : np .ndarray , dtype : str ):
0 commit comments