6 changes: 4 additions & 2 deletions modelopt/torch/nas/modules/conv.py
@@ -139,7 +139,9 @@ def _estimate_importance(self) -> TracedHp.Importance:
return None
weight = self._parameters["weight"] # retrieve full weight tensor
c_in = weight.shape[1]
return torch.norm(torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1)
return torch.linalg.vector_norm(
torch.reshape(weight.detach().transpose(0, 1), (c_in, -1)), dim=1
)

def _setup(self):
# only support ungrouped conv or grouped conv with in_channels == out_channels
@@ -249,4 +251,4 @@ def _estimate_importance(self) -> TracedHp.Importance:
return None
weight = self._parameters["weight"] # retrieve full weight tensor
c_in = weight.shape[0]
return torch.norm(torch.reshape(weight.detach(), (c_in, -1)), dim=1)
return torch.linalg.vector_norm(torch.reshape(weight.detach(), (c_in, -1)), dim=1)
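For context (not part of the diff), a minimal sketch checking that the replacement reproduces the deprecated torch.norm call for the per-input-channel importance computed above; the weight shape is an arbitrary example:

import torch

weight = torch.randn(16, 8, 3, 3)  # (c_out, c_in, kH, kW), illustrative shape
c_in = weight.shape[1]
flat = torch.reshape(weight.detach().transpose(0, 1), (c_in, -1))

old = torch.norm(flat, dim=1)                # deprecated API
new = torch.linalg.vector_norm(flat, dim=1)  # replacement used in this PR
assert torch.allclose(old, new)              # same per-channel L2 norms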
2 changes: 1 addition & 1 deletion modelopt/torch/nas/modules/linear.py
@@ -41,7 +41,7 @@ def _get_bias(mod: "_DynamicLinear", bias: torch.Tensor | None) -> torch.Tensor
return get_sliced_tensor(mod, bias, "out_features")

def _estimate_importance(self) -> TracedHp.Importance:
return self._parameters["weight"].detach().norm(dim=0)
return torch.linalg.vector_norm(self._parameters["weight"].detach(), dim=0)

def _setup(self):
# register hyperparameters
14 changes: 9 additions & 5 deletions modelopt/torch/nas/plugins/megatron.py
@@ -622,11 +622,15 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:
def _estimate_query_group_importance(self) -> TracedHp.Importance:
"""Return the importance of the ``num_query_groups`` hparam."""
assert self._activations is not None, "No activations collected for importance estimation."
group_importance = self._activations.view(
self.get_hparam("num_heads_per_group").max,
self.get_hparam("num_query_groups").max,
self.config.kv_channels,
).norm(p=2, dim=(0, 2))
group_importance = torch.linalg.norm(
self._activations.view(
self.get_hparam("num_heads_per_group").max,
self.get_hparam("num_query_groups").max,
self.config.kv_channels,
),
ord=2,
dim=(0, 2),
)
return group_importance

def export(self) -> torch.nn.Module:
4 changes: 3 additions & 1 deletion modelopt/torch/nas/plugins/transformers.py
@@ -122,7 +122,9 @@ def configure_qkv_out(self, q_name: str, k_name: str, v_name: str, out_name: str
out.in_features = hp_hidden_dim

assert isinstance(out, nn.Linear)
hp_hidden_dim.register_importance(lambda: out._parameters["weight"].detach().norm(dim=0))
hp_hidden_dim.register_importance(
lambda: torch.linalg.norm(out._parameters["weight"].detach(), dim=0)
)

Collaborator:
Should we use vector_norm here as well?

@namgyu-youn (Contributor, Author), Jun 19, 2025:
This case computes a vector norm because dim is an int. Please check the comment above, and let me know if anything is wrong in internal CI.

Collaborator:
Let's use vector_norm explicitly to avoid any ambiguity and for consistency with the rest of the changes.

def modify(
self, *, n_heads_ratio: tuple[float, ...] | None = None, n_heads_divisor: int = 1
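For context (not part of the diff), a small sketch of the behavior discussed in the review thread above; the tensor shape is an arbitrary example:

import torch

w = torch.randn(64, 32)  # illustrative weight

# With an int dim, torch.linalg.norm computes a vector norm, so it matches vector_norm.
assert torch.allclose(torch.linalg.norm(w, dim=0), torch.linalg.vector_norm(w, dim=0))

# With a 2-tuple dim, torch.linalg.norm(ord=2) is the matrix spectral norm, while
# vector_norm(ord=2) is the L2 norm over the flattened dims; being explicit avoids
# this ambiguity.
spectral = torch.linalg.norm(w, ord=2, dim=(0, 1))
flat_l2 = torch.linalg.vector_norm(w, ord=2, dim=(0, 1))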