File tree Expand file tree Collapse file tree 2 files changed +92
-0
lines changed Expand file tree Collapse file tree 2 files changed +92
-0
lines changed Original file line number Diff line number Diff line change 1+ name : Tests the examples in README
2+ on : [push, pull_request]
3+
4+ env :
5+ TYPECHECK : True
6+
7+ jobs :
8+ test :
9+ runs-on : ubuntu-latest
10+ steps :
11+ - uses : actions/checkout@v4
12+ - name : Install Python
13+ uses : actions/setup-python@v5
14+ with :
15+ python-version : " 3.11"
16+ - name : Install dependencies
17+ run : |
18+ python -m pip install uv
19+ python -m uv pip install --upgrade pip
20+ python -m uv pip install torch --index-url https://download.pytorch.org/whl/nightly/cpu
21+ python -m uv pip install -e .[test]
22+ - name : Test with pytest
23+ run : |
24+ python -m pytest tests/
Original file line number Diff line number Diff line change 1+ import pytest
2+ from copy import deepcopy
3+
4+ import torch
5+ from torch import nn
6+ torch .set_default_dtype (torch .float64 )
7+
8+ from GAF_microbatch_pytorch import GAFWrapper , set_filter_gradients_
9+
10+ def test_gaf ():
11+
12+ net = nn .Sequential (
13+ nn .Linear (512 , 256 ),
14+ nn .SiLU (),
15+ nn .Linear (256 , 128 )
16+ )
17+
18+ gaf_net = GAFWrapper (
19+ deepcopy (net ),
20+ filter_distance_thres = 2.
21+ )
22+
23+ x = torch .randn (8 , 1024 , 512 )
24+ y = x .clone ()
25+
26+ x .requires_grad_ ()
27+ y .requires_grad_ ()
28+
29+ out1 = net (x )
30+ out2 = gaf_net (y )
31+
32+ out1 .sum ().backward ()
33+ out2 .sum ().backward ()
34+
35+ grad = net [0 ].weight .grad
36+ grad_filtered = gaf_net .net [0 ].weight .grad
37+
38+ assert torch .allclose (grad , grad_filtered , atol = 1e-6 )
39+
40+ def test_gaf ():
41+
42+ net = nn .Sequential (
43+ nn .Linear (512 , 256 ),
44+ nn .SiLU (),
45+ nn .Linear (256 , 128 )
46+ )
47+
48+ gaf_net = GAFWrapper (
49+ deepcopy (net ),
50+ filter_distance_thres = 0.
51+ )
52+
53+ x = torch .randn (8 , 1024 , 512 )
54+ y = x .clone ()
55+
56+ x .requires_grad_ ()
57+ y .requires_grad_ ()
58+
59+ out1 = net (x )
60+ out2 = gaf_net (y )
61+
62+ out1 .sum ().backward ()
63+ out2 .sum ().backward ()
64+
65+ grad = net [0 ].weight .grad
66+ grad_filtered = gaf_net .net [0 ].weight .grad
67+
68+ assert not torch .allclose (grad , grad_filtered , atol = 1e-6 )
You can’t perform that action at this time.
0 commit comments