
Commit c501fd9 ("first pass")

Parent: b47eee7
File tree: 5 files changed (+353, -2 lines)
GitHub Actions workflow: "Upload Python Package"

Lines changed: 36 additions & 0 deletions
```yaml
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries

# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.

name: Upload Python Package

on:
  release:
    types: [published]

jobs:
  deploy:

    runs-on: ubuntu-latest

    steps:
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: '3.x'
    - name: Install dependencies
      run: |
        python -m pip install --upgrade pip
        pip install build
    - name: Build package
      run: python -m build
    - name: Publish package
      uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
      with:
        user: __token__
        password: ${{ secrets.PYPI_API_TOKEN }}
```

GAF_microbatch_pytorch/GAF.py

Lines changed: 194 additions & 0 deletions
```python
from __future__ import annotations

from functools import partial
from typing import Literal, Callable

import torch
from torch import nn
from torch.nn import Module, Linear
from torch.autograd import Function
import torch.nn.functional as F

from torch.utils._pytree import tree_flatten, tree_unflatten, tree_map
from torch.func import functional_call, vjp, vmap

from einops import einsum, rearrange, repeat, reduce

# helper functions

def exists(v):
    return v is not None

def default(v, d):
    return v if exists(v) else d

# distance used for gradient agreement
# they found cosine distance to work best, at a threshold of ~0.96

def l2norm(t):
    return F.normalize(t, p = 2, dim = -1)

def cosine_sim_distance(grads):
    grads = rearrange(grads, 'b ... -> b (...)')
    normed = l2norm(grads)
    dist = einsum(normed, normed, 'i d, j d -> i j')
    return 1. - dist

def filter_gradients_by_agreement(
    grads,
    threshold,
    strategy: Literal[
        'accept_max_neighbors',
        'accept_min_neighbors'
    ] = 'accept_max_neighbors',
    accept_batch_frac = 0.2
):
    """ main gradient filtering function """

    batch = grads.shape[0]

    dist = cosine_sim_distance(grads) # (batch, batch) pairwise cosine distance between per-sample gradients

    accept_mask = dist < threshold

    num_neighbors_within_dist = accept_mask.sum(dim = -1)

    # every gradient agrees only with itself - reject the entire batch

    if (num_neighbors_within_dist == 1).all():
        return torch.zeros_like(grads)

    # take the most naive approach

    if strategy == 'accept_max_neighbors':
        # accept the gradient with the most neighbors, together with those neighbors

        center_ind = num_neighbors_within_dist.argmax(dim = -1)

        accept_mask = accept_mask[center_ind]

    elif strategy == 'accept_min_neighbors':
        # reject any gradient that does not have at least `batch * accept_batch_frac` similar gradients within the same batch

        accept_mask = num_neighbors_within_dist >= max(batch * accept_batch_frac, 2)
    else:
        raise ValueError(f'unknown strategy {strategy}')

    if not accept_mask.any():
        return torch.zeros_like(grads)

    if accept_mask.all():
        return grads

    renorm_scale = batch / accept_mask.sum().item()

    # zero out the rejected gradients (note: mutates the input in place)

    grads[~accept_mask] = 0.

    # renormalize based on how many were accepted

    grads *= renorm_scale

    return grads
```
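To make the acceptance logic concrete, here is a small sketch (not from the repo's tests) that runs `filter_gradients_by_agreement` on toy per-sample gradients; the shapes, noise scale, and the deliberately flipped outlier are illustrative assumptions:

```python
import torch
from GAF_microbatch_pytorch.GAF import filter_gradients_by_agreement

# eight per-sample gradients that roughly agree, plus one pointing the opposite way

base = torch.randn(128)
agreeing = base + 0.05 * torch.randn(8, 128)
outlier = -base + 0.05 * torch.randn(1, 128)

grads = torch.cat((agreeing, outlier), dim = 0) # (9, 128)

filtered = filter_gradients_by_agreement(grads, threshold = 0.97)

# the outlier was zeroed out, the accepted gradients rescaled

assert torch.all(filtered[-1] == 0)
assert torch.any(filtered[0] != 0)
```

With the default `accept_max_neighbors` strategy, the eight agreeing rows form the majority cluster, so the ninth row (cosine distance ~2 from the rest) is zeroed and the survivors are rescaled by `9 / 8`.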
The remainder of the file wires the filter into a custom autograd `Function`:

```python
# custom autograd function

class GAF(Function):

    @classmethod
    def forward(self, ctx, tree_spec, *tree_nodes):

        package = tree_unflatten(tree_nodes, tree_spec)

        net = package['net']
        params, buffers = package['params_buffers']
        filter_gradients_fn = package['filter_gradients_fn']
        inp_tensor, args, kwargs = package['inputs']

        batch = inp_tensor.shape[0]

        def fn(params, buffers, inp_tensor):
            return functional_call(net, (params, buffers), (inp_tensor, *args), kwargs)

        # vmap over per-sample copies of the parameters, so the backward pass yields one gradient per sample

        fn = vmap(fn, in_dims = (0, None, 0))

        params = {name: repeat(t, '... -> b ...', b = batch) for name, t in params.items()}

        output, vjpfunc = vjp(fn, params, buffers, inp_tensor)

        ctx._saved_info_for_backwards = (vjpfunc, filter_gradients_fn, args, kwargs)
        return output

    @classmethod
    def backward(self, ctx, do):

        vjp_func, filter_gradients_fn, args, kwargs = ctx._saved_info_for_backwards

        dparams, dbuffers, dinp = vjp_func(do)

        # filter the per-sample gradients, then sum them back down to the original parameter shapes

        filtered_dparams = {name: reduce(filter_gradients_fn(dparam), 'b ... -> ...', 'sum') for name, dparam in dparams.items()}

        # the returned package must flatten to the same leaves as the one received by forward

        package = dict(
            net = None,
            params_buffers = (filtered_dparams, dbuffers),
            filter_gradients_fn = None,
            inputs = (dinp, tree_map(lambda _: None, args), tree_map(lambda _: None, kwargs))
        )

        tree_nodes, _ = tree_flatten(package)

        output = (None, *tree_nodes)
        return output

gaf_function = GAF.apply

# main wrapper

class GAFWrapper(Module):
    """
    a wrapper around a neural network that automatically filters all gradients by their intra-batch agreement - not across machines as in the paper
    """
    def __init__(
        self,
        net: Module,
        filter_distance_thres = 0.97,
        filter_gradients = True,
        filter_gradients_fn: Callable | None = None
    ):
        super().__init__()

        self.net = net

        # gradient agreement filtering related

        self.filter_gradients = filter_gradients
        self.filter_distance_thres = filter_distance_thres

        if not exists(filter_gradients_fn):
            filter_gradients_fn = partial(filter_gradients_by_agreement, threshold = filter_distance_thres)

        self.filter_gradients_fn = filter_gradients_fn

    def forward(
        self,
        inp_tensor,
        *args,
        **kwargs
    ):
        only_one_dim_or_no_batch = inp_tensor.ndim == 1 or inp_tensor.shape[0] == 1

        if not self.filter_gradients or only_one_dim_or_no_batch:
            return self.net(inp_tensor, *args, **kwargs)

        params = dict(self.net.named_parameters())
        buffers = dict(self.net.named_buffers())

        package = dict(
            net = self.net,
            params_buffers = (params, buffers),
            inputs = (inp_tensor, args, kwargs),
            filter_gradients_fn = self.filter_gradients_fn
        )

        tree_nodes, tree_spec = tree_flatten(package)

        out = gaf_function(tree_spec, *tree_nodes)
        return out
```
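For intuition on the forward pass above: per-sample parameter gradients come from replicating the parameters along the batch dimension, `vmap`-ing a `functional_call`, then pulling back through `vjp`. A minimal standalone sketch of just that idiom, with a tiny assumed `Linear` for illustration:

```python
import torch
from torch import nn
from torch.func import functional_call, vjp, vmap
from einops import repeat

net = nn.Linear(4, 3, bias = False)
x = torch.randn(8, 4)

# give every sample its own copy of the parameters

params = {name: repeat(p, '... -> b ...', b = 8) for name, p in net.named_parameters()}

def fn(params, x):
    return functional_call(net, params, (x,))

# vmap maps over both the replicated params and the batch of inputs

out, vjpfunc = vjp(vmap(fn), params, x)

# pulling back a cotangent yields one parameter gradient per sample

dparams, dx = vjpfunc(torch.ones_like(out))

print(dparams['weight'].shape) # torch.Size([8, 3, 4])
```

Each slice `dparams['weight'][i]` is the gradient contributed by sample `i` alone, which is exactly what the agreement filter compares.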

GAF_microbatch_pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
```python
from GAF_microbatch_pytorch.GAF import GAFWrapper
```

README.md

Lines changed: 64 additions & 2 deletions
The two-line stub ("# GAF-microbatch-pytorch" plus a one-sentence description) is replaced with:

````markdown
## Gradient Agreement Filtering (microbatch) - Pytorch

Implementation of [Gradient Agreement Filtering](https://arxiv.org/abs/2412.18052), from Chaubard et al. of Stanford, but for single machine microbatches, in Pytorch.

Whether it is just a means of filtering out outlier label noise, or something with real ties to better generalization, it seemed worth exploring either way.

The official repository, which does the filtering across macrobatches, is [here](https://github.com/Fchaubard/gradient_agreement_filtering)

## Install

```bash
$ pip install GAF-microbatch-pytorch
```

## Usage

```python
import torch

# mock network

from torch import nn

net = nn.Sequential(
    nn.Linear(512, 256),
    nn.SiLU(),
    nn.Linear(256, 128)
)

# import the gradient agreement filtering (GAF) wrapper

from GAF_microbatch_pytorch import GAFWrapper

# just wrap your neural net

gaf_net = GAFWrapper(
    net,
    filter_distance_thres = 0.97
)

# your batch of data

x = torch.randn(16, 1024, 512)

# forward and backward as usual

out = gaf_net(x)

out.sum().backward()

# gradients are filtered by the set threshold, comparing per-sample gradients within the batch, as in the paper
```

## Citations

```bibtex
@inproceedings{Chaubard2024BeyondGA,
    title = {Beyond Gradient Averaging in Parallel Optimization: Improved Robustness through Gradient Agreement Filtering},
    author = {Francois Chaubard and Duncan Eddy and Mykel J. Kochenderfer},
    year = {2024},
    url = {https://api.semanticscholar.org/CorpusID:274992650}
}
```
````
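A possible variation, assuming you want the stricter `accept_min_neighbors` strategy defined in `GAF.py` rather than the default majority-cluster one: pass a custom `filter_gradients_fn` built with `functools.partial`. The threshold and batch fraction below are assumed values for illustration, not recommendations from the paper:

```python
from functools import partial

import torch
from torch import nn

from GAF_microbatch_pytorch import GAFWrapper
from GAF_microbatch_pytorch.GAF import filter_gradients_by_agreement

# keep a per-sample gradient only if at least ~20% of the batch agrees with it

filter_fn = partial(
    filter_gradients_by_agreement,
    threshold = 0.97,
    strategy = 'accept_min_neighbors',
    accept_batch_frac = 0.2
)

gaf_net = GAFWrapper(
    nn.Linear(512, 128),
    filter_gradients_fn = filter_fn
)

out = gaf_net(torch.randn(16, 512))
out.sum().backward()
```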

pyproject.toml

Lines changed: 58 additions & 0 deletions
```toml
[project]
name = "GAF-microbatch-pytorch"
version = "0.0.1"
description = "Gradient Agreement Filtering"
authors = [
    { name = "Phil Wang", email = "[email protected]" }
]
readme = "README.md"
requires-python = ">= 3.9"
license = { file = "LICENSE" }
keywords = [
    'artificial intelligence',
    'deep learning',
    'label noise',
    'gradient filtering'
]

classifiers = [
    'Development Status :: 4 - Beta',
    'Intended Audience :: Developers',
    'Topic :: Scientific/Engineering :: Artificial Intelligence',
    'License :: OSI Approved :: MIT License',
    'Programming Language :: Python :: 3.9',
]

dependencies = [
    "torch>=2.4",
    "einops>=0.8.0"
]

[project.urls]
Homepage = "https://pypi.org/project/GAF-microbatch-pytorch/"
Repository = "https://github.com/lucidrains/GAF-microbatch-pytorch"

[project.optional-dependencies]
examples = []
test = [
    "pytest"
]

[tool.pytest.ini_options]
pythonpath = [
    "."
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.rye]
managed = true
dev-dependencies = []

[tool.hatch.metadata]
allow-direct-references = true

[tool.hatch.build.targets.wheel]
packages = ["GAF_microbatch_pytorch"]
```
