 import wandb


-_FSDP_WRAPPED_MODULE = ["_forward_module.", "_fsdp_wrapped_module."]
+_WRAPPED_NAME_TO_REMOVE = ["_forward_module.", "_fsdp_wrapped_module.", "_orig_mod."]


 def _remove_fsdp_prefix(name: str) -> str:
-    for prefix in _FSDP_WRAPPED_MODULE:
+    for prefix in _WRAPPED_NAME_TO_REMOVE:
         if prefix in name:
-            return name.replace(prefix, "")
+            name = name.replace(prefix, "")
     return name


+@torch.compiler.disable()
 @torch.no_grad()
 def log_activations_hook(
     _mod: torch.nn.Module,
     _inp: torch.Tensor,
     outp: torch.Tensor | tuple[torch.Tensor, ...],
     mod_name: str,
+    gradient_accumulation_steps: int,
     log_activations: dict[str, float],
 ) -> None:
     if isinstance(outp, tuple):
         outp = outp[0]
-
-    norm = outp.norm(p=2)
-
+    norm = outp.norm(p=2) / gradient_accumulation_steps
     name = _remove_fsdp_prefix(mod_name)
-
     if f"activation/{name}" not in log_activations:
         log_activations[f"activation/{name}"] = norm
     else:
         log_activations[f"activation/{name}"] += norm


-class ActivationNormMetric:
+def register_metrics_hooks(
+    model: torch.nn.Module,
+    target_layers: list[str],
+    log_activations: dict[str, torch.Tensor],
+    gradient_accumulation_steps: int,
+) -> list[RemovableHandle]:
     """
-    This class is used to monitor the norm of the activation of the target layers.
-    It attached hook to the forward of each layer that will log the output, and remove them after.
+    This function takes a torch module and a list of layer names and applies a hook function that
+    monitors the output norm of those layers.
     """
-
-    def __init__(self, target_layers: list[str], gradient_accumulation_steps: int):
-        self.target_layers = target_layers
-        self.handles: list[RemovableHandle] = []
-        self._log_activations: dict[str, torch.Tensor] = {}
-        self.gradient_accumulation_steps = gradient_accumulation_steps
-
-    def register_metrics_hooks(self, model: torch.nn.Module):
-        """
-        this function take a torch module, a list of layer name and apply a hook function that
-        monitor the output norm of the layers.
-        """
-        handles = []
-        for name, mod in model.named_modules():
-            for layer in self.target_layers:
-                if name.endswith(layer):
-                    handle = mod.register_forward_hook(
-                        partial(log_activations_hook, log_activations=self._log_activations, mod_name=name)
+    handles = []
+    for name, mod in model.named_modules():
+        for layer in target_layers:
+            if name.endswith(layer):
+                handle = mod.register_forward_hook(
+                    partial(
+                        log_activations_hook,
+                        log_activations=log_activations,
+                        mod_name=name,
+                        gradient_accumulation_steps=gradient_accumulation_steps,
                     )
-                    handles.append(handle)
-                    break
-
-        self.handles = handles
-
-    def remove_hooks(self) -> None:
-        for handle in self.handles:
-            handle.remove()
+                )
+                handles.append(handle)

-    @property
-    def log_activations(self) -> dict[str, torch.Tensor]:
-        return {k: v / self.gradient_accumulation_steps for k, v in self._log_activations.items()}
+    return handles


 def _round_str(x: float):
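
For context, a minimal usage sketch of the new function-based API. The toy Sequential model, the layer-name suffixes, and the step count below are illustrative assumptions, not part of this change; it assumes register_metrics_hooks, log_activations_hook, and partial from this file are in scope.

import torch

grad_acc_steps = 4
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8))  # toy stand-in model
log_activations: dict[str, torch.Tensor] = {}

# monitor every module whose name ends in "0" or "1" (both Linear layers of the toy model)
handles = register_metrics_hooks(model, ["0", "1"], log_activations, grad_acc_steps)

for _ in range(grad_acc_steps):   # one optimizer step worth of micro-batches
    model(torch.randn(2, 8))      # each forward pass fires the hooks

# each entry is the L2 norm of a layer's output, averaged over the grad-acc micro-batches;
# in the training loop these are the values pushed to wandb, e.g. wandb.log(metrics)
print({k: v.item() for k, v in log_activations.items()})

for handle in handles:            # detach the hooks once the metrics are consumed
    handle.remove()

Compared to the removed ActivationNormMetric class, the caller now owns the log_activations dict and the list of handles, and the division by gradient_accumulation_steps happens inside the hook rather than at read time.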