File tree Expand file tree Collapse file tree 2 files changed +49
-1
lines changed
Expand file tree Collapse file tree 2 files changed +49
-1
lines changed Original file line number Diff line number Diff line change @@ -238,6 +238,9 @@ def main():
238238 else :
239239 pipe .to (device )
240240
241+ if args .gradient_topk :
242+ print ("Gradient sparsification(gradient_topk) set to" , args .gradient_topk )
243+
241244 if args .gradient_checkpointing :
242245 print ("Enabling gradient checkpointing in UNet" )
243246 pipe .unet .enable_gradient_checkpointing ()
@@ -457,7 +460,7 @@ def main():
457460 lr_sched = accelerator .prepare (lr_sched )
458461
459462 if args .gradient_topk :
460- from grad_topk import apply_global_topk_gradients
463+ from train_grad_topk import apply_global_topk_gradients
461464
462465 global_step = 0
463466 batch_count = 0
Original file line number Diff line number Diff line change 1+
2+ import torch
3+
def apply_global_topk_gradients(model, keep_frac: float) -> None:
    """Sparsify `model`'s gradients in place, keeping the global top-k by |g|.

    All existing gradients are ranked together (globally, not per-parameter)
    by absolute magnitude; the largest ``k = int(N * keep_frac)`` entries are
    kept and every other entry is zeroed. This function assumes
    ``loss.backward()`` has already been called.

    The stated purpose is to reduce "catastrophic forgetting" from
    overtraining, and potentially allow more knowledge to be stored.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``)
            whose parameters already have gradients computed. Parameters
            whose ``grad`` is ``None`` are ignored.
        keep_frac: Fraction of gradient entries (by absolute magnitude) to
            keep. Must satisfy ``0 < keep_frac < 1``; values outside this
            range are a no-op, as is a fraction so small that ``k`` rounds
            down to zero.

    Returns:
        None. Gradients are modified in place.

    Note:
        When several entries tie with the threshold magnitude, all of them
        are kept, so slightly more than ``k`` entries may survive.
    """
    if keep_frac <= 0.0 or keep_frac >= 1.0:
        return

    grads = [p.grad for p in model.parameters() if p.grad is not None]
    if not grads:
        return

    # Gradient bookkeeping only; nothing here should be recorded by autograd.
    with torch.no_grad():
        # reshape (rather than view) also flattens non-contiguous gradients.
        abs_flat = torch.cat([g.reshape(-1).abs() for g in grads])
        total = abs_flat.numel()

        k = int(total * keep_frac)
        if k <= 0 or k >= total:
            return

        # The k largest magnitudes are exactly those >= the (N - k + 1)-th
        # smallest one. Using the (N - k)-th smallest would retain k + 1.
        thresh = abs_flat.kthvalue(total - k + 1).values

        # Zero out the small gradients in place.
        for g in grads:
            g.mul_(g.abs() >= thresh)
You can’t perform that action at this time.
0 commit comments