Commit 8cf516e

[Bug-fix] Remove force_assign in sort_parameters to avoid re-sorting (#514)
## What does this PR do?

**Type of change:** Bug-fix

**Overview:**

- For FastNAS/GradNAS sorting, we previously `force_assign`-ed the parameters to their sorted values while still retaining `hp.active_slice`. The next time `mod.weight` was fetched, the sort was applied again to the already-sorted weight tensor, producing an incorrect ordering. This bug only affects FastNAS/GradNAS modules; Minitron already had a workaround that reset the order to None during export.

## Testing

- Tests pass
- Manually verified

Signed-off-by: Keval Morabia <[email protected]>
1 parent fcbdc31 commit 8cf516e
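
To make the failure mode in the overview concrete, here is a minimal, self-contained sketch in plain PyTorch (no ModelOpt classes involved): a permutation that sorts channels by importance gives the right result only when applied to the original tensor; applying it a second time to the already-sorted tensor scrambles the channels again.

```python
import torch

# Toy "weights" for 4 output channels, with per-channel importance = abs value.
weight = torch.tensor([0.1, 0.9, 0.3, 0.7])

# Order that sorts channels by descending importance (analogous to hp.active_slice).
order = torch.argsort(weight.abs(), descending=True)  # tensor([1, 3, 2, 0])

sorted_once = weight[order]        # tensor([0.9, 0.7, 0.3, 0.1])  -> correct
sorted_twice = sorted_once[order]  # tensor([0.7, 0.1, 0.3, 0.9])  -> scrambled

# This mirrors the bug: force_assign stored `sorted_once` back into the module,
# but hp.active_slice kept `order`, so the next access re-applied it.
```

This is exactly the interaction between the force-assigned weights and the retained `hp.active_slice` described in the overview.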

7 files changed: +12, -38 lines


CHANGELOG.rst

Lines changed: 8 additions & 1 deletion
@@ -1,7 +1,14 @@
 Model Optimizer Changelog (Linux)
 =================================
 
-0.39 (2025-11-07)
+0.40 (2025-12-xx)
+^^^^^^^^^^^^^^^^^
+
+**Bug Fixes**
+
+- Fix a bug in FastNAS pruning (computer vision models) where the model parameters were sorted twice messing up the ordering.
+
+0.39 (2025-11-14)
 ^^^^^^^^^^^^^^^^^
 
 **Deprecations**

modelopt/torch/nas/modules/conv.py

Lines changed: 2 additions & 2 deletions
@@ -137,7 +137,7 @@ def _estimate_importance(self) -> TracedHp.Importance:
         # for group > 1, we do not know how to handle it yet
         if self.groups > 1:
             return None
-        weight = self._parameters["weight"]  # retrieve full weight tensor
+        weight = self.weight
         c_in = weight.shape[1]
         return torch.linalg.vector_norm(
             torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1
@@ -249,6 +249,6 @@ def _estimate_importance(self) -> TracedHp.Importance:
         # for group > 1, we do not know how to handle it yet
         if self.groups > 1:
             return None
-        weight = self._parameters["weight"]  # retrieve full weight tensor
+        weight = self.weight
         c_in = weight.shape[0]
         return torch.linalg.vector_norm(torch.reshape(weight.detach(), (c_in, -1)), dim=1)
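
The importance metric in both hunks is just a per-channel L2 norm; the behavioral change is that the weight is now read through the dynamic `self.weight` (which applies `hp.active_slice` on access) instead of the raw `_parameters["weight"]`. A standalone sketch of the computation with a dummy tensor (shapes and names are illustrative only, not taken from the module):

```python
import torch

# Dummy 2D-conv-style weight: (out_channels, in_channels, kH, kW).
weight = torch.randn(8, 4, 3, 3)

# First hunk: importance per input channel.
# Move in_channels to dim 0, flatten the rest, and take the L2 norm per row.
c_in = weight.shape[1]
imp_in = torch.linalg.vector_norm(
    torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1
)
print(imp_in.shape)  # torch.Size([4]) -> one score per input channel

# Second hunk: same idea, but the channel of interest is dim 0 of the weight.
c0 = weight.shape[0]
imp0 = torch.linalg.vector_norm(torch.reshape(weight.detach(), (c0, -1)), dim=1)
print(imp0.shape)  # torch.Size([8])
```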

modelopt/torch/nas/modules/linear.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def _get_bias(mod: "_DynamicLinear", bias: torch.Tensor | None) -> torch.Tensor
         return get_sliced_tensor(mod, bias, "out_features")
 
     def _estimate_importance(self) -> TracedHp.Importance:
-        return torch.linalg.vector_norm(self._parameters["weight"].detach(), dim=0)
+        return torch.linalg.vector_norm(self.weight.detach(), dim=0)
 
     def _setup(self):
         # register hyperparameters

modelopt/torch/nas/plugins/megatron.py

Lines changed: 0 additions & 7 deletions
@@ -51,7 +51,6 @@
 from modelopt.torch.opt.dynamic import DynamicModule
 from modelopt.torch.opt.hparam import HPType
 from modelopt.torch.opt.searcher import ConstraintsDict
-from modelopt.torch.opt.utils import named_hparams
 from modelopt.torch.trace import Symbol
 from modelopt.torch.utils import distributed as dist
 from modelopt.torch.utils import (
@@ -1322,12 +1321,6 @@ def _export_drop_layers(self) -> None:
 
     def export(self) -> torch.nn.Module:
         """Export the dynamic module to a torch.nn.Module."""
-        # TODO: Improve this!
-        # Slice order needs to be reset before exporting since weights are already
-        # force assigned and we dont want to sort them again (losing the correct order)
-        for n, hp in named_hparams(self, configurable=True):
-            hp.enforce_order(None)
-
         for handle in self.hook_handles:
             handle.remove()
         self._export_drop_layers()
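
The removed block was the Minitron-side workaround mentioned in the overview: it reset each hparam's order before export so that weights that had already been force-assigned would not be permuted a second time. With `force_assign` gone, the sort lives only in the slice applied at access time, so export can simply materialize the sliced weight once and no reset is needed. The sketch below illustrates that access-time pattern with a toy module; the class and attribute names are invented for illustration and are not ModelOpt's API.

```python
import torch
from torch import nn


class LazySortedLinear(nn.Module):
    """Toy module: keeps the original weight and applies a sort order on access."""

    def __init__(self, out_features: int, in_features: int):
        super().__init__()
        self.weight_raw = nn.Parameter(torch.randn(out_features, in_features))
        self.order = None  # LongTensor of row indices, analogous to hp.active_slice

    @property
    def weight(self) -> torch.Tensor:
        # The sort is applied exactly once, at read time, on the original storage.
        return self.weight_raw if self.order is None else self.weight_raw[self.order]

    def export(self) -> nn.Linear:
        out_features, in_features = self.weight.shape
        exported = nn.Linear(in_features, out_features, bias=False)
        with torch.no_grad():
            exported.weight.copy_(self.weight)  # materialize the sorted weight once
        return exported


m = LazySortedLinear(4, 8)
m.order = torch.argsort(
    torch.linalg.vector_norm(m.weight_raw.detach(), dim=1), descending=True
)
exported = m.export()  # rows come out in importance order; no "reset order" step needed
```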

modelopt/torch/nas/plugins/transformers.py

Lines changed: 1 addition & 1 deletion
@@ -123,7 +123,7 @@ def configure_qkv_out(self, q_name: str, k_name: str, v_name: str, out_name: str
 
         assert isinstance(out, nn.Linear)
         hp_hidden_dim.register_importance(
-            lambda: torch.linalg.vector_norm(out._parameters["weight"].detach(), dim=0)
+            lambda: torch.linalg.vector_norm(out.weight.detach(), dim=0)
         )
 
     def modify(

modelopt/torch/nas/search_space.py

Lines changed: 0 additions & 4 deletions
@@ -162,10 +162,6 @@ def sort_parameters(self, hps_to_sort: set[str] | None = None, verbose: bool = F
                 f"{'order' if hp._importance_is_order else 'importance'}={importance}"
             )
 
-        # now that we have enforced an order we can force reassign all parameters/buffers!
-        for _, mod in self.named_dynamic_modules():
-            mod.force_assign()
-
         # go back to old config
         self.select(config)

modelopt/torch/opt/dynamic.py

Lines changed: 0 additions & 22 deletions
@@ -586,28 +586,6 @@ def export(self) -> nn.Module:
 
         return self
 
-    @torch.no_grad()
-    def force_assign(self):
-        """Force re-assign all dynamic attributes to their current values.
-
-        .. warning::
-
-            Note that this method overwrites the actual buffers and parameters! Only use in
-            specific circumstances!!
-        """
-        # force-reassign all dynamic attributes
-        for name in self._get_dm_attribute_manager().da_keys():
-            val = getattr(self, name)
-            if isinstance(val, torch.Tensor):
-                val = val.detach().clone()
-            if name in self._parameters:
-                val = val if val is None else Parameter(val)
-                self.register_parameter(name, val)
-            elif name in self._buffers:
-                self.register_buffer(name, val)
-            else:
-                setattr(self, name, val)
-
     @classmethod
     @torch.no_grad()
     def convert(cls, module: nn.Module) -> "DynamicModule":