
Commit 0f7b763

Add modules_to_not_convert and fix activation scale name
1 parent ce48d38 commit 0f7b763

File tree

4 files changed
+33 -17 lines changed

python/mlc_llm/model/ministral3/ministral3_loader.py

Lines changed: 4 additions & 7 deletions
@@ -67,7 +67,7 @@ def huggingface(model_config: Ministral3Config, quantization: Quantization) -> E
         model = quantization.quantize_model(model, QuantizeMapping({}, {}), "")
         if model_config.weight_block_size is None:
             raise ValueError(
-                "The input DeepSeek model is not fp8 block quantized. "
+                "The input Ministral 3 model is not fp8 block quantized. "
                 "Thus BlockScaleQuantize is not supported."
             )

@@ -98,7 +98,7 @@ def hf(name: str) -> str:
             and model_config.weight_block_size is not None
         ):
             raise ValueError(
-                "The input DeepSeek model is fp8 block quantized. "
+                "The input Ministral 3 model is fp8 block quantized. "
                 "Please use BlockScaleQuantize for the model."
             )

@@ -126,9 +126,9 @@ def add_weight_and_scale_mapping(
                 weight_scale_hf_names,
                 functools.partial(weight_transform_func, dtype=weight_scale_param.dtype),
             )
-            activation_scale_mlc_name = f"{weight_mlc_name}_activation_scale"
+            activation_scale_mlc_name = f"{weight_mlc_name[: -len('.weight')]}.activation_scale"
             if activation_scale_mlc_name in named_parameters:
-                activation_scale_hf_names = [f"{name}_activation_scale" for name in weight_hf_names]
+                activation_scale_hf_names = [f"{name[: -len('.weight')]}.activation_scale" for name in weight_hf_names]
                 activation_scale_param = named_parameters[activation_scale_mlc_name]
                 transform = activation_transform_func or weight_transform_func
                 mapping.add_mapping(
@@ -140,9 +140,6 @@ def add_weight_and_scale_mapping(
     def identity_transform(param: np.ndarray, dtype: str):
         return param.astype(dtype)

-    def concat_along_dim0(*arrays: np.ndarray, dtype: str):
-        return np.concatenate(arrays, axis=0).astype(dtype)
-
     def make_shared_activation_transform(target_name: str):
         def func(first: np.ndarray, *rest: np.ndarray, dtype: str):
             for idx, arr in enumerate(rest, start=1):
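
What the activation-scale rename amounts to, as a minimal standalone sketch (the parameter name below is made up for illustration; only the string transform is taken from the diff):

```python
# Minimal sketch of the renaming fix. The parameter name is hypothetical;
# only the slicing/formatting mirrors the loader change above.
weight_mlc_name = "model.layers.0.mlp.down_proj.weight"

# Old scheme: append a suffix to the full weight name.
old_name = f"{weight_mlc_name}_activation_scale"
# -> "model.layers.0.mlp.down_proj.weight_activation_scale"

# New scheme: strip the trailing ".weight" and attach ".activation_scale",
# which is the form the loader now looks up in named_parameters.
new_name = f"{weight_mlc_name[: -len('.weight')]}.activation_scale"
# -> "model.layers.0.mlp.down_proj.activation_scale"

print(old_name)
print(new_name)
```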

python/mlc_llm/model/ministral3/ministral3_model.py

Lines changed: 19 additions & 0 deletions
@@ -46,6 +46,7 @@ class Ministral3Config(ConfigBase):  # pylint: disable=too-many-instance-attribu
     tie_word_embeddings: bool = False
     weight_block_size: Optional[Tuple[int, int]] = None
     kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+    modules_to_not_convert: Tuple[str, ...] = dataclasses.field(default_factory=tuple)

     @classmethod
     def from_dict(  # type: ignore[override]
@@ -72,6 +73,9 @@ def __post_init__(self):  # pylint: disable=too-many-branches
         quant_method = quantization_config.get("quant_method", "")
         fmt = quantization_config.get("fmt", "")
         weight_block_size = quantization_config.get("weight_block_size")
+        modules_to_not_convert = quantization_config.get("modules_to_not_convert", [])
+        if isinstance(modules_to_not_convert, list):
+            self.modules_to_not_convert = tuple(modules_to_not_convert)
         if (
             quant_method == "fp8"
             and fmt == "e4m3"
@@ -317,6 +321,7 @@ def __init__(self, config: Ministral3Config):
         self.tie_word_embeddings = config.tie_word_embeddings
         if not config.tie_word_embeddings:
             self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # "vocab_size"
+        self._mark_modules_no_quant(config.modules_to_not_convert)
         self.num_hidden_layers = config.num_hidden_layers
         self.num_attention_heads = config.num_attention_heads
         self.num_key_value_heads = config.num_key_value_heads
@@ -330,6 +335,20 @@ def __init__(self, config: Ministral3Config):
         self.dtype = config.dtype
         self.weight_block_size = config.weight_block_size

+    def _mark_modules_no_quant(self, modules: Tuple[str, ...]):
+        for path in modules:
+            if not path:
+                continue
+            parts = path.split(".")
+            target = self
+            for part in parts:
+                if not hasattr(target, part):
+                    target = None
+                    break
+                target = getattr(target, part)
+            if target is not None:
+                setattr(target, "no_quantization", True)
+
     def to(self, dtype: Optional[str] = None):
         super().to(dtype=dtype)
         if dtype is not None:
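
For clarity, here is a standalone sketch of the dotted-path walk that `_mark_modules_no_quant` performs, using plain placeholder objects instead of the real nn.Module tree (the attribute names are illustrative):

```python
# Standalone sketch of the dotted-path resolution in _mark_modules_no_quant.
# Dummy stands in for the real module tree; the attribute names are illustrative.
class Dummy:
    pass

model = Dummy()
model.lm_head = Dummy()
model.model = Dummy()
model.model.embed_tokens = Dummy()


def mark_modules_no_quant(root, modules):
    """Walk each dotted attribute path and tag the leaf with no_quantization=True."""
    for path in modules:
        if not path:
            continue
        target = root
        for part in path.split("."):
            if not hasattr(target, part):
                target = None  # unknown path: silently skipped, as in the diff
                break
            target = getattr(target, part)
        if target is not None:
            target.no_quantization = True


mark_modules_no_quant(model, ("lm_head", "model.embed_tokens", "model.does_not_exist"))
print(getattr(model.lm_head, "no_quantization", False))             # True
print(getattr(model.model.embed_tokens, "no_quantization", False))  # True
```

The `no_quantization` flag set here is presumably what the quantizer consults so that modules listed in `modules_to_not_convert` (e.g. `lm_head` in HF fp8 configs) stay in the model dtype.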

python/mlc_llm/quantization/block_scale_quantization.py

Lines changed: 8 additions & 8 deletions
@@ -159,7 +159,7 @@ def visit_module(self, name: str, node: nn.Module) -> Any:
             and not is_moe_gate(name, node)
         ):
             if self.config.use_activation_scale:
-                return BlockScaleQuantizeLinearMinistral3.from_linear(
+                return BlockScaleQuantizeLinearStaticActivation.from_linear(
                     node, self.config, weight_block_size
                 )
             return BlockScaleQuantizeLinear.from_linear(
@@ -329,8 +329,8 @@ def to(self, dtype: Optional[str] = None) -> None:
         self.dtype = dtype  # pylint: disable=attribute-defined-outside-init


-class BlockScaleQuantizeLinearMinistral3(BlockScaleQuantizeLinear):
-    """Block-scale quantization for Ministral3 static activation FP8."""
+class BlockScaleQuantizeLinearStaticActivation(BlockScaleQuantizeLinear):
+    """Block-scale quantization for static activation FP8."""

     def __init__(  # pylint: disable=too-many-arguments
         self,
@@ -357,9 +357,9 @@ def __init__(  # pylint: disable=too-many-arguments
     @staticmethod
     def from_linear(
         src: nn.Linear, config: BlockScaleQuantize, weight_block_size: Optional[Tuple[int, int]]
-    ) -> "BlockScaleQuantizeLinearMinistral3":
+    ) -> "BlockScaleQuantizeLinearStaticActivation":
         """
-        Convert a non-quantized nn.Linear to a block-scale quantized BlockScaleQuantizeLinearMinistral3.
+        Convert a non-quantized nn.Linear to a block-scale quantized BlockScaleQuantizeLinearStaticActivation.

         Parameters
         ----------
@@ -374,12 +374,12 @@ def from_linear(

         Returns
         -------
-        ret : BlockScaleQuantizeLinearMinistral3
-            The block-scale quantized BlockScaleQuantizeLinearMinistral3
+        ret : BlockScaleQuantizeLinearStaticActivation
+            The block-scale quantized BlockScaleQuantizeLinearStaticActivation
         """
         assert weight_block_size is not None
         out_features, in_features = src.weight.shape
-        quantized_linear = BlockScaleQuantizeLinearMinistral3(
+        quantized_linear = BlockScaleQuantizeLinearStaticActivation(
             in_features=in_features,
             out_features=out_features,
             weight_dtype=config.weight_dtype,
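
The rename does not change behavior; the mutator still picks the static-activation subclass only when the config asks for activation scales. A reduced sketch of that dispatch, using stand-in classes rather than the real MLC modules (the default block size is illustrative):

```python
# Reduced sketch of the visit_module dispatch after the rename.
# FakeConfig and both classes here are stand-ins, not the real MLC types.
from dataclasses import dataclass


@dataclass
class FakeConfig:
    use_activation_scale: bool


class BlockScaleQuantizeLinear:
    @classmethod
    def from_linear(cls, node, config, weight_block_size):
        # The real method builds a block-scale quantized module from the nn.Linear.
        return cls.__name__


class BlockScaleQuantizeLinearStaticActivation(BlockScaleQuantizeLinear):
    """Variant that additionally carries a static activation scale."""


def convert(node, config, weight_block_size=(128, 128)):
    if config.use_activation_scale:
        return BlockScaleQuantizeLinearStaticActivation.from_linear(node, config, weight_block_size)
    return BlockScaleQuantizeLinear.from_linear(node, config, weight_block_size)


print(convert(None, FakeConfig(use_activation_scale=True)))   # BlockScaleQuantizeLinearStaticActivation
print(convert(None, FakeConfig(use_activation_scale=False)))  # BlockScaleQuantizeLinear
```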

python/mlc_llm/quantization/quantization.py

Lines changed: 2 additions & 2 deletions
@@ -191,8 +191,8 @@ def quantize_weight(self, weight: tvm.runtime.Tensor) -> List[tvm.runtime.Tensor
         weight_dtype="float8_e4m3fn",
         model_dtype="bfloat16",
     ),
-    "fp8_e4m3fn_bf16_block_scale_ministral3": BlockScaleQuantize(
-        name="fp8_e4m3fn_bf16_block_scale_ministral3",
+    "fp8_e4m3fn_bf16_block_scale_static_activation": BlockScaleQuantize(
+        name="fp8_e4m3fn_bf16_block_scale_static_activation",
         kind="block-scale-quant",
         weight_dtype="float8_e4m3fn",
         model_dtype="bfloat16",
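
A hedged usage sketch of the renamed preset, assuming the registry dict defined in this file is exposed as `QUANTIZATION`; the field names come straight from the entry above:

```python
# Hedged sketch: look up the renamed preset. Assumes the registry dict in
# python/mlc_llm/quantization/quantization.py is named QUANTIZATION.
from mlc_llm.quantization.quantization import QUANTIZATION

quant = QUANTIZATION["fp8_e4m3fn_bf16_block_scale_static_activation"]
print(quant.name)          # fp8_e4m3fn_bf16_block_scale_static_activation
print(quant.kind)          # block-scale-quant
print(quant.weight_dtype)  # float8_e4m3fn
print(quant.model_dtype)   # bfloat16
```

The same string is what would be passed as the quantization option when converting weights or compiling the model.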
