Skip to content

Commit 123ce19

Browse files
committed
Rename source file; standardize naming and fix a bug
1 parent 313fb50 commit 123ce19

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

trainer/train_from_cached.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,9 @@ def main():
238238
else:
239239
pipe.to(device)
240240

241+
if args.gradient_topk:
242+
print("Gradient sparsification(gradient_topk) set to", args.gradient_topk)
243+
241244
if args.gradient_checkpointing:
242245
print("Enabling gradient checkpointing in UNet")
243246
pipe.unet.enable_gradient_checkpointing()
@@ -457,7 +460,7 @@ def main():
457460
lr_sched = accelerator.prepare(lr_sched)
458461

459462
if args.gradient_topk:
460-
from grad_topk import apply_global_topk_gradients
463+
from train_grad_topk import apply_global_topk_gradients
461464

462465
global_step = 0
463466
batch_count = 0

trainer/train_grad_topk.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
import torch
3+
4+
def apply_global_topk_gradients(model, keep_frac: float) -> None:
    """Sparsify gradients in-place, keeping only the globally largest entries.

    Treats all parameter gradients of ``model`` as one flat vector and zeroes
    every entry except the ``k = int(total * keep_frac)`` entries with the
    largest absolute magnitude. Assumes ``loss.backward()`` has already been
    called. Purpose is to reduce "catastrophic forgetting" from overtraining,
    and potentially allow more knowledge to be stored.

    Args:
        model: torch.nn.Module with gradients already computed (loss.backward()).
        keep_frac: Fraction of gradient entries (by absolute magnitude) to keep.
            0 < keep_frac < 1. Values outside this range are a no-op, as is a
            model with no populated gradients.
    """
    if not 0.0 < keep_frac < 1.0:
        return

    # Pure bookkeeping on .grad tensors — never part of an autograd graph.
    with torch.no_grad():
        # Collect |grad| parts directly (rather than cat-then-abs) so only one
        # flat copy is materialized instead of two. Cast to float32 so
        # mixed-precision (e.g. fp16) gradients can be concatenated and
        # ranked together.
        abs_parts = [
            p.grad.detach().abs().flatten().float()
            for p in model.parameters()
            if p.grad is not None
        ]
        if not abs_parts:
            return

        abs_flat = torch.cat(abs_parts)
        total = abs_flat.numel()

        k = int(total * keep_frac)
        if k <= 0 or k >= total:
            return

        # The k-th LARGEST value is the (total - k + 1)-th smallest. The
        # previous code used kthvalue(total - k), whose threshold is the
        # (k + 1)-th largest value, so the ``>=`` mask kept k + 1 entries —
        # an off-by-one.
        thresh = abs_flat.kthvalue(total - k + 1).values

        # Zero out sub-threshold entries in-place. Ties exactly at `thresh`
        # are all kept, so slightly more than k entries may survive when
        # magnitudes repeat.
        # NOTE(review): assumes all grads live on one device, since `thresh`
        # is compared against each grad directly — confirm for
        # model-parallel setups.
        for p in model.parameters():
            if p.grad is not None:
                p.grad.masked_fill_(p.grad.abs() < thresh, 0)

0 commit comments

Comments
 (0)