Skip to content

Commit acecdb5

Browse files
Fix svdquant restore
1 parent bf2b0d0 commit acecdb5

File tree

4 files changed

+6
-2
lines changed

4 files changed

+6
-2
lines changed

examples/llm_ptq/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ ChatGLM2, 3 6B | No | No | Yes | No | -
117117
Bloom | Yes | Yes | Yes | Yes | -
118118
Phi-1,2,3,4 | Yes | Yes | Yes | Yes<sup>3</sup> |
119119
Phi-3.5 MOE | Yes | No | No | No | -
120-
Llama-Nemotron Super/Ultra | Yes | No | No | No | Yes
120+
Llama-Nemotron Super | Yes | No | No | No | Yes
121+
Llama-Nemotron Ultra | Yes | No | No | No | No
121122
Nemotron 8B | Yes | No | Yes | No | -
122123
Gemma 2B, 7B | Yes | No | Yes | Yes | -
123124
Gemma 2 9B, 27B | Yes<sup>2</sup> | No | Yes | No | -

modelopt/torch/quantization/conversion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def create_and_replace_svdquant_linear_on_the_fly(model):
143143
def restore_svdquant_model(model: nn.Module, config: QuantizeConfig, metadata: MetadataDict):
144144
"""Restore the svdquant states from the given state dict."""
145145
create_and_replace_svdquant_linear_on_the_fly(model)
146+
restore_quantizer_state(model, config, metadata)
146147
return model
147148

148149

modelopt/torch/speculative/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ class EagleConfig(ModeloptBaseConfig):
134134
ffn_hidden_size: int = ModeloptField(
135135
default=0,
136136
description=(
137-
"ffn_hidden_size of the eagle module. Using base model's ffn_hidden_size is set to None."
137+
"ffn_hidden_size of the eagle module. The base model's ffn_hidden_size is used if this is set to 0."
138138
),
139139
)
140140

tests/_test_utils/torch_quantization/checkpointing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,12 @@ def format_modelopt_checkpoint_by_version(modelopt_state: dict, version: str):
2222
if Version(version) >= Version("0.29"):
2323
return modelopt_state
2424
modelopt_state = copy.deepcopy(modelopt_state)
25+
modelopt_state["modelopt_version"] = version
2526
for mode, state in modelopt_state["modelopt_state_dict"]:
2627
if "quantizer_state" not in state["metadata"]:
2728
continue
2829
for quantizer_name, quantizer_state in state["metadata"]["quantizer_state"].items():
30+
quantizer_state["_mopt_ckpt_versn"] = version
2931
pyt_states = quantizer_state.pop("_pytorch_state_metadata", None)
3032
if pyt_states is None:
3133
continue

0 commit comments

Comments
 (0)