Skip to content

Commit 0d02d18

Browse files
committed
update Ln-norm logic for upcoming PyTorch update (#206)
1 parent 7ccd58e commit 0d02d18

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

modelopt/torch/nas/plugins/megatron.py

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -622,11 +622,15 @@ def _estimate_all_head_importance(self) -> TracedHp.Importance:
622622
def _estimate_query_group_importance(self) -> TracedHp.Importance:
623623
"""Return the importance of the ``num_query_groups`` hparam."""
624624
assert self._activations is not None, "No activations collected for importance estimation."
625-
group_importance = self._activations.view(
626-
self.get_hparam("num_heads_per_group").max,
627-
self.get_hparam("num_query_groups").max,
628-
self.config.kv_channels,
629-
).norm(p=2, dim=(0, 2))
625+
group_importance = torch.linalg.norm(
626+
self._activations.view(
627+
self.get_hparam("num_heads_per_group").max,
628+
self.get_hparam("num_query_groups").max,
629+
self.config.kv_channels,
630+
),
631+
ord=2,
632+
dim=(0, 2),
633+
)
630634
return group_importance
631635

632636
def export(self) -> torch.nn.Module:

modelopt/torch/nas/plugins/transformers.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,9 @@ def configure_qkv_out(self, q_name: str, k_name: str, v_name: str, out_name: str
122122
out.in_features = hp_hidden_dim
123123

124124
assert isinstance(out, nn.Linear)
125-
hp_hidden_dim.register_importance(lambda: out._parameters["weight"].detach().norm(dim=0))
125+
hp_hidden_dim.register_importance(
126+
lambda: torch.linalg.norm(out._parameters["weight"].detach(), dim=0)
127+
)
126128

127129
def modify(
128130
self, *, n_heads_ratio: tuple[float, ...] | None = None, n_heads_divisor: int = 1

0 commit comments

Comments (0)