Skip to content

Commit e51af61

Browse files
committed
Update scale transforms
1 parent 0f7b763 commit e51af61

File tree

4 files changed

+66
-11
lines changed

4 files changed

+66
-11
lines changed

python/mlc_llm/conversation_template/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
hermes,
1818
llama,
1919
llava,
20+
ministral3,
2021
mistral,
2122
nemotron,
2223
oasst,

python/mlc_llm/model/ministral3/ministral3_loader.py

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,20 +121,76 @@ def add_weight_and_scale_mapping(
121121
if weight_scale_mlc_name in named_parameters:
122122
weight_scale_hf_names = [f"{name}_scale_inv" for name in weight_hf_names]
123123
weight_scale_param = named_parameters[weight_scale_mlc_name]
124+
expected_weight_scale_shape = tuple(int(dim) for dim in weight_scale_param.shape)
125+
126+
def _weight_scale_transform(*arrays, dtype: str, _transform=weight_transform_func):
    """Run the weight transform on the HF scale tensors and fit the result to the MLC shape.

    Parameters
    ----------
    *arrays
        Scale tensors loaded from the checkpoint. Zero-dimensional inputs are
        promoted to shape ``(1,)`` before ``_transform`` is called.
    dtype : str
        Target dtype, forwarded to ``_transform`` and used for the returned array.
    _transform
        The underlying transform, bound as a default so the closure keeps the
        function that was current when this helper was defined.

    Returns
    -------
    np.ndarray
        An array of shape ``expected_weight_scale_shape`` and dtype ``dtype``.

    Raises
    ------
    ValueError
        If the transformed scale cannot be expanded to the expected shape.
    """
    # Promote 0-d scalars to 1-D so that every input has at least one axis.
    # NOTE(review): presumably needed because the transform concatenates its inputs — confirm.
    promoted = []
    for arr in arrays:
        arr_np = np.asarray(arr)
        promoted.append(arr_np.reshape((1,)) if arr_np.ndim == 0 else arr_np)

    result = np.asarray(_transform(*promoted, dtype=dtype), dtype=dtype)
    if result.shape == expected_weight_scale_shape:
        return result

    # A single scale value fills the whole target. Using np.full for both the
    # 0-d and the (1,) case also covers a 0-d target, which np.broadcast_to
    # cannot produce from a (1,) source.
    if result.shape in ((), (1,)):
        return np.full(expected_weight_scale_shape, result.item(), dtype=dtype)

    rank = len(expected_weight_scale_shape)
    if (
        result.ndim == 1
        and result.size > 1
        and rank >= 2
        and expected_weight_scale_shape[0] % result.size == 0
    ):
        # Spread each value over an equal share of the leading axis.
        rows_per_segment = expected_weight_scale_shape[0] // result.size
        per_row = np.repeat(result, rows_per_segment)
        # Keep the values on axis 0 and add one singleton axis per remaining
        # dimension; a fixed (rows, 1) shape would align against the trailing
        # axes and fail or mis-broadcast for targets with more than two dimensions.
        per_row = per_row.reshape((expected_weight_scale_shape[0],) + (1,) * (rank - 1))
        return np.broadcast_to(per_row, expected_weight_scale_shape).astype(dtype)

    raise ValueError(
        f"Unexpected weight scale shape {result.shape} for "
        f"{weight_scale_mlc_name}, expected {expected_weight_scale_shape}"
    )
124155
mapping.add_mapping(
125156
weight_scale_mlc_name,
126157
weight_scale_hf_names,
127-
functools.partial(weight_transform_func, dtype=weight_scale_param.dtype),
158+
functools.partial(_weight_scale_transform, dtype=weight_scale_param.dtype),
128159
)
129160
activation_scale_mlc_name = f"{weight_mlc_name[: -len('.weight')]}.activation_scale"
130161
if activation_scale_mlc_name in named_parameters:
131162
activation_scale_hf_names = [f"{name[: -len('.weight')]}.activation_scale" for name in weight_hf_names]
132163
activation_scale_param = named_parameters[activation_scale_mlc_name]
133164
transform = activation_transform_func or weight_transform_func
165+
expected_shape = tuple(int(dim) for dim in activation_scale_param.shape)
166+
167+
def _activation_scale_transform(*arrays, dtype: str, _transform=transform):
    """Run the activation transform on the HF scales and fit the result to the MLC shape.

    Parameters
    ----------
    *arrays
        Activation-scale tensors loaded from the checkpoint, passed to
        ``_transform`` unchanged.
    dtype : str
        Target dtype, forwarded to ``_transform`` and used for the returned array.
    _transform
        The underlying transform, bound as a default so the closure keeps the
        function that was current when this helper was defined.

    Returns
    -------
    np.ndarray
        An array of shape ``expected_shape`` and dtype ``dtype``.

    Raises
    ------
    ValueError
        If the transformed scale cannot be expanded to the expected shape.
    """
    result = np.asarray(_transform(*arrays, dtype=dtype), dtype=dtype)
    if result.shape == expected_shape:
        return result

    # HF checkpoint stores a single scale; fill the expected shape with it.
    # Using np.full for both the 0-d and the (1,) case also covers a 0-d
    # target, which np.broadcast_to cannot produce from a (1,) source.
    if result.shape in ((), (1,)):
        return np.full(expected_shape, result.item(), dtype=dtype)

    if (
        result.ndim == 1
        and result.size > 1
        and len(expected_shape) >= 1
        and expected_shape[0] % result.size == 0
        # The repeated vector has exactly expected_shape[0] elements, so the
        # reshape below is only valid when the target holds that many; other
        # targets fall through to the descriptive error instead of an opaque
        # NumPy reshape failure.
        and int(np.prod(expected_shape)) == expected_shape[0]
    ):
        # Spread each value over an equal share of the leading axis.
        rows_per_segment = expected_shape[0] // result.size
        tiled = np.repeat(result, rows_per_segment)
        return tiled.reshape(expected_shape).astype(dtype)

    raise ValueError(
        f"Unexpected activation scale shape {result.shape} for "
        f"{activation_scale_mlc_name}, expected {expected_shape}"
    )
134190
mapping.add_mapping(
135191
activation_scale_mlc_name,
136192
activation_scale_hf_names,
137-
functools.partial(transform, dtype=activation_scale_param.dtype),
193+
functools.partial(_activation_scale_transform, dtype=activation_scale_param.dtype),
138194
)
139195

140196
def identity_transform(param: np.ndarray, dtype: str):

python/mlc_llm/model/ministral3/ministral3_model.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,14 +71,12 @@ def __post_init__(self): # pylint: disable=too-many-branches
7171
if isinstance(quantization_config, dict):
7272
activation_scheme = quantization_config.get("activation_scheme", "")
7373
quant_method = quantization_config.get("quant_method", "")
74-
fmt = quantization_config.get("fmt", "")
7574
weight_block_size = quantization_config.get("weight_block_size")
7675
modules_to_not_convert = quantization_config.get("modules_to_not_convert", [])
7776
if isinstance(modules_to_not_convert, list):
7877
self.modules_to_not_convert = tuple(modules_to_not_convert)
7978
if (
8079
quant_method == "fp8"
81-
and fmt == "e4m3"
8280
and activation_scheme == "static"
8381
and weight_block_size is not None
8482
):
@@ -95,16 +93,17 @@ def __post_init__(self): # pylint: disable=too-many-branches
9593
else:
9694
self.weight_block_size = [128, 128]
9795
logger.info(
98-
"Setting default weight_block_size since quantization_config does not provide "
99-
"FP8 block-scale details required by MLC (activation_scheme=%s, quant_method=%s, "
100-
"fmt=%s, weight_block_size=%s)",
96+
"Setting default weight_block_size=%s since quantization_config does not provide "
97+
"FP8 block-scale details required by MLC (activation_scheme=%s, quant_method=%s)",
98+
self.weight_block_size,
10199
activation_scheme,
102100
quant_method,
103-
fmt,
104-
weight_block_size,
105101
)
106102
else:
107-
logger.info("Ignoring non-dict quantization_config: %s", quantization_config)
103+
raise ValueError(
104+
"Invalid Ministral 3 model quantization config: unrecognized quantization config: "
105+
f"{quantization_config}"
106+
)
108107

109108
if self.position_embedding_base == 0:
110109
if self.rope_parameters is not None and "rope_theta" in self.rope_parameters:

python/mlc_llm/model/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ class Model:
145145
source={
146146
"huggingface-torch": ministral3_loader.huggingface,
147147
"huggingface-safetensor": ministral3_loader.huggingface,
148-
"awq": ministral3_loader.awq,
149148
},
150149
quantize={
151150
"group-quant": ministral3_quantization.group_quant,

0 commit comments

Comments
 (0)