File tree Expand file tree Collapse file tree 3 files changed +28
-5
lines changed Expand file tree Collapse file tree 3 files changed +28
-5
lines changed Original file line number Diff line number Diff line change @@ -73,7 +73,7 @@ def filter_gradients_by_agreement(
7373 else :
7474 raise ValueError (f'unknown strategy { strategy } ' )
7575
76- if not accept_mask .any () :
76+ if accept_mask .sum (). item () <= 1 :
7777 return torch .zeros_like (grads )
7878
7979 if accept_mask .all ():
Original file line number Diff line number Diff line change 11[project ]
22name = " GAF-microbatch-pytorch"
3- version = " 0.0.4 "
3+ version = " 0.0.5 "
44description = " Gradient Agreement Filtering"
55authors = [
66 {
name =
" Phil Wang" ,
email =
" [email protected] " }
Original file line number Diff line number Diff line change 77
88from GAF_microbatch_pytorch import GAFWrapper , set_filter_gradients_
99
10- def test_gaf ():
10+ def test_unfiltered_gaf ():
1111
1212 net = nn .Sequential (
1313 nn .Linear (512 , 256 ),
@@ -47,7 +47,7 @@ def test_gaf():
4747
4848 gaf_net = GAFWrapper (
4949 deepcopy (net ),
50- filter_distance_thres = 0.
50+ filter_distance_thres = 0.7
5151 )
5252
5353 x = torch .randn (8 , 1024 , 512 )
@@ -65,4 +65,27 @@ def test_gaf():
6565 grad = net [0 ].weight .grad
6666 grad_filtered = gaf_net .net [0 ].weight .grad
6767
68- assert not torch .allclose (grad , grad_filtered , atol = 1e-6 )
68+ assert not (grad_filtered == 0. ).all () and not torch .allclose (grad , grad_filtered , atol = 1e-6 )
69+
70+ def test_all_filtered_gaf ():
71+
72+ net = nn .Sequential (
73+ nn .Linear (512 , 256 ),
74+ nn .SiLU (),
75+ nn .Linear (256 , 128 )
76+ )
77+
78+ gaf_net = GAFWrapper (
79+ deepcopy (net ),
80+ filter_distance_thres = 0.
81+ )
82+
83+ x = torch .randn (8 , 1024 , 512 )
84+ x .requires_grad_ ()
85+
86+ out = gaf_net (x )
87+ out .sum ().backward ()
88+
89+ grad_filtered = gaf_net .net [0 ].weight .grad
90+
91+ assert (grad_filtered == 0. ).all ()
You can’t perform that action at this time.
0 commit comments