
Commit 8ce08c0

fix torch compile log act (#23)
* fix renaming logic for key
* fix stuff
* fix exploding norm
* remove print
1 parent 35cd120 commit 8ce08c0

File tree

2 files changed: +38 -48 lines changed


open_diloco/train_fsdp.py

Lines changed: 11 additions & 9 deletions
@@ -52,10 +52,10 @@
 
 
 from open_diloco.utils import (
-    ActivationNormMetric,
     FakeTokenizedDataset,
     get_compression_kwargs,
     get_sharding_strategy,
+    register_metrics_hooks,
 )
 
 
@@ -353,6 +353,8 @@ def scheduler_fn(opt):
     if world_messenger_hv:
         max_num_peers = 0
 
+    log_activations = {}
+
     for step, batch in enumerate(iterable=train_dataloader, start=start_step * gradient_accumulation_steps):
         real_step = (step + 1) // gradient_accumulation_steps
         is_accumulating = bool((step + 1) % gradient_accumulation_steps)
@@ -362,11 +364,9 @@ def scheduler_fn(opt):
         )
 
         if logging_activations_steps:
-            activation_monitor = ActivationNormMetric(
-                target_layers=TARGET_LAYER_ACTIVATIONS,
-                gradient_accumulation_steps=gradient_accumulation_steps,
+            handles = register_metrics_hooks(
+                model, TARGET_LAYER_ACTIVATIONS, log_activations, gradient_accumulation_steps
             )
-            activation_monitor.register_metrics_hooks(model)
 
         for key in batch.keys():
             batch[key] = batch[key].to("cuda")
@@ -379,6 +379,10 @@ def scheduler_fn(opt):
 
         scaler.scale(loss).backward()
 
+        if logging_activations_steps:
+            for handle in handles:
+                handle.remove()
+
         if not is_accumulating:
             if world_messenger_hv:
                 scaler.unscale_(optimizer=optimizer.inner_optimizer)
@@ -400,9 +404,6 @@ def scheduler_fn(opt):
             scheduler.step()
             optimizer.zero_grad()
 
-            if logging_activations_steps:
-                activation_monitor.remove_hooks()
-
             if config.hv is not None:
                 if int(real_step) % config.hv.local_steps == 0:
                     for param in model.parameters():
@@ -442,7 +443,8 @@ def scheduler_fn(opt):
                 metrics["num_peers"] = num_peers
 
                 if logging_activations_steps:
-                    metrics.update(activation_monitor.log_activations)
+                    metrics.update(log_activations)
+                    log_activations = {}
 
                 if world_messenger_hv and num_peers < max_num_peers:
                     log(message=f"Lost a diloco worker, num_peers: {num_peers}, galaxy_size: {config.hv.galaxy_size}")
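
For context, here is a minimal self-contained sketch (not part of the repository) of the pattern the training loop now follows: attach hooks with register_metrics_hooks just before the forward pass, remove them right after backward, then merge the accumulated norms into the metrics dict and reset it at logging time. The toy model, optimizer, and layer names are invented for illustration; only register_metrics_hooks and its argument order come from this commit.

import torch

from open_diloco.utils import register_metrics_hooks

# Toy stand-ins for the real model and TARGET_LAYER_ACTIVATIONS.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
gradient_accumulation_steps = 2
target_layers = ["0", "1"]  # matched against module names via name.endswith(layer)

log_activations: dict[str, torch.Tensor] = {}
for step in range(4):
    # Hooks live only for this forward/backward pass.
    handles = register_metrics_hooks(model, target_layers, log_activations, gradient_accumulation_steps)

    loss = model(torch.randn(4, 8)).mean()
    loss.backward()

    # Detach the hooks right after backward, as the commit now does.
    for handle in handles:
        handle.remove()

    if (step + 1) % gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        metrics = {"loss": loss.item(), **log_activations}
        log_activations = {}  # reset so norms do not leak into the next logging window
        print(metrics)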

open_diloco/utils.py

Lines changed: 27 additions & 39 deletions
@@ -10,73 +10,61 @@
 import wandb
 
 
-_FSDP_WRAPPED_MODULE = ["_forward_module.", "_fsdp_wrapped_module."]
+_WRAPPED_NAME_TO_REMOVE = ["_forward_module.", "_fsdp_wrapped_module.", "_orig_mod."]
 
 
 def _remove_fsdp_prefix(name: str) -> str:
-    for prefix in _FSDP_WRAPPED_MODULE:
+    for prefix in _WRAPPED_NAME_TO_REMOVE:
         if prefix in name:
-            return name.replace(prefix, "")
+            name = name.replace(prefix, "")
     return name
 
 
+@torch.compiler.disable()
 @torch.no_grad()
 def log_activations_hook(
     _mod: torch.nn.Module,
     _inp: torch.Tensor,
     outp: torch.Tensor | tuple[torch.Tensor, ...],
     mod_name: str,
+    gradient_accumulation_steps: int,
     log_activations: dict[str, float],
 ) -> None:
     if isinstance(outp, tuple):
         outp = outp[0]
-
-    norm = outp.norm(p=2)
-
+    norm = outp.norm(p=2) / gradient_accumulation_steps
     name = _remove_fsdp_prefix(mod_name)
-
     if f"activation/{name}" not in log_activations:
         log_activations[f"activation/{name}"] = norm
     else:
         log_activations[f"activation/{name}"] += norm
 
 
-class ActivationNormMetric:
+def register_metrics_hooks(
+    model: torch.nn.Module,
+    target_layers: list[str],
+    log_activations: dict[str, torch.Tensor],
+    gradient_accumulation_steps: int,
+) -> list[RemovableHandle]:
     """
-    This class is used to monitor the norm of the activation of the target layers.
-    It attached hook to the forward of each layer that will log the output, and remove them after.
+    this function take a torch module, a list of layer name and apply a hook function that
+    monitor the output norm of the layers.
     """
-
-    def __init__(self, target_layers: list[str], gradient_accumulation_steps: int):
-        self.target_layers = target_layers
-        self.handles: list[RemovableHandle] = []
-        self._log_activations: dict[str, torch.Tensor] = {}
-        self.gradient_accumulation_steps = gradient_accumulation_steps
-
-    def register_metrics_hooks(self, model: torch.nn.Module):
-        """
-        this function take a torch module, a list of layer name and apply a hook function that
-        monitor the output norm of the layers.
-        """
-        handles = []
-        for name, mod in model.named_modules():
-            for layer in self.target_layers:
-                if name.endswith(layer):
-                    handle = mod.register_forward_hook(
-                        partial(log_activations_hook, log_activations=self._log_activations, mod_name=name)
+    handles = []
+    for name, mod in model.named_modules():
+        for layer in target_layers:
+            if name.endswith(layer):
+                handle = mod.register_forward_hook(
+                    partial(
+                        log_activations_hook,
+                        log_activations=log_activations,
+                        mod_name=name,
+                        gradient_accumulation_steps=gradient_accumulation_steps,
                     )
-                    handles.append(handle)
-                    break
-
-        self.handles = handles
-
-    def remove_hooks(self) -> None:
-        for handle in self.handles:
-            handle.remove()
+                )
+                handles.append(handle)
 
-    @property
-    def log_activations(self) -> dict[str, torch.Tensor]:
-        return {k: v / self.gradient_accumulation_steps for k, v in self._log_activations.items()}
+    return handles
 
 
 def _round_str(x: float):
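
Two details in open_diloco/utils.py carry the fixes listed in the commit message: the hook is now wrapped in @torch.compiler.disable() so the logging path stays out of the graph that torch.compile traces, and the norm is divided by gradient_accumulation_steps inside the hook while the training loop resets log_activations after each logging step, which keeps the reported norms from growing without bound. The renaming fix is easiest to see on its own; the standalone sketch below (the module name is invented for illustration) shows why the loop now keeps stripping prefixes instead of returning after the first match.

# Standalone sketch of the fixed renaming logic; the example name is hypothetical.
_WRAPPED_NAME_TO_REMOVE = ["_forward_module.", "_fsdp_wrapped_module.", "_orig_mod."]


def _remove_fsdp_prefix(name: str) -> str:
    for prefix in _WRAPPED_NAME_TO_REMOVE:
        if prefix in name:
            name = name.replace(prefix, "")  # keep going: a name can carry several prefixes
    return name


# With torch.compile on top of FSDP, a module name may carry more than one wrapper prefix:
name = "_fsdp_wrapped_module._orig_mod.model.layers.0.mlp"
print(_remove_fsdp_prefix(name))  # -> model.layers.0.mlp
# The previous version returned after the first matching prefix, so "_orig_mod."
# (which was not in the list at all) survived in the metric key.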
