Skip to content

Commit 7036599

Browse files
author
Your Name
committed
address comments
1 parent f529d43 commit 7036599

File tree

4 files changed

+12
-8
lines changed

4 files changed

+12
-8
lines changed

modelopt/torch/opt/dynamic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1273,7 +1273,8 @@ def config(self, configurable: bool | None = None) -> dict[str, Any]:
12731273
A dict of ``(parameter_name, choice)`` that specifies an active subnet.
12741274
"""
12751275
return {
1276-
get_unwrapped_name(name): hp.active for name, hp in self.named_hparams(configurable)
1276+
get_unwrapped_name(name, self): hp.active
1277+
for name, hp in self.named_hparams(configurable)
12771278
}
12781279

12791280
def select(self, config: dict[str, Any], strict: bool = True) -> None:

modelopt/torch/opt/plugins/peft.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def _new_load_adapter(self, model_id, adapter_name, *args, **kwargs):
9090
)
9191
for name, module in self.named_modules():
9292
if isinstance(module, TensorQuantizer):
93-
module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name)])
93+
module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name, self)])
9494

9595
return outputs
9696

modelopt/torch/quantization/conversion.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,12 @@ def restore_quantizer_state(model: nn.Module, config: QuantizeConfig, metadata:
123123

124124
for name, module in model.named_modules():
125125
if isinstance(module, TensorQuantizer):
126-
name = get_unwrapped_name(name)
126+
name = get_unwrapped_name(name, model)
127127
module.set_from_modelopt_state(quantizer_state_dict[name])
128128

129129
for name, module in model.named_modules():
130130
if isinstance(module, QuantModule):
131-
name = get_unwrapped_name(name)
131+
name = get_unwrapped_name(name, model)
132132
module.modelopt_post_restore(name)
133133

134134
return model
@@ -166,7 +166,7 @@ def update_quantize_metadata(
166166
def quantizer_state(model: nn.Module) -> dict[str, Any]:
167167
"""Returns the quantizer state dict describing the quantizer states in the model."""
168168
return {
169-
get_unwrapped_name(n): m.get_modelopt_state()
169+
get_unwrapped_name(n, model): m.get_modelopt_state()
170170
for n, m in model.named_modules()
171171
if isinstance(m, (TensorQuantizer, SequentialQuantizer))
172172
}

modelopt/torch/quantization/utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def get_quantizer_state_dict(model: nn.Module):
452452
quantizer_state_dict = {}
453453
for name, module in model.named_modules():
454454
if isinstance(module, TensorQuantizer):
455-
quantizer_state_dict[get_unwrapped_name(name)] = module.state_dict()
455+
quantizer_state_dict[get_unwrapped_name(name, model)] = module.state_dict()
456456
return quantizer_state_dict
457457

458458

@@ -461,5 +461,8 @@ def set_quantizer_state_dict(model: nn.Module, quantizer_state_dict: dict):
461461
from .nn import TensorQuantizer
462462

463463
for name, module in model.named_modules():
464-
if isinstance(module, TensorQuantizer) and get_unwrapped_name(name) in quantizer_state_dict:
465-
module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name)])
464+
if (
465+
isinstance(module, TensorQuantizer)
466+
and get_unwrapped_name(name, model) in quantizer_state_dict
467+
):
468+
module.load_state_dict(quantizer_state_dict[get_unwrapped_name(name, model)])

0 commit comments

Comments
 (0)