Commit 0c38547

separate serial from parallel (probably an error in the paper) in diffusion transformers and add checkpointing for serial
1 parent 2c8328b · commit 0c38547
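
For context, the sketch below (plain PyTorch, not the repository's actual conditioned attention and transition classes; `attn` and `ff` are stand-in modules) contrasts the two block formulations the commit now separates: the serial path applies attention and the feedforward one after the other, each with its own residual, while the parallel path (as written in the paper's algorithm, and as used in GPT-J / PaLM) computes both from the same input and sums them, with the residual optional.

import torch
from torch import nn

dim = 64

# stand-in modules; in the repository these are the conditioned attention and transition blocks
attn = nn.MultiheadAttention(dim, num_heads = 4, batch_first = True)
ff = nn.Sequential(nn.LayerNorm(dim), nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))

def serial_block(x):
    # serial: attention with its own residual, then feedforward with its own residual
    attn_out, _ = attn(x, x, x)
    x = attn_out + x
    x = ff(x) + x
    return x

def parallel_block(x, add_residual = True):
    # parallel: attention and feedforward both read the same input and are summed,
    # with the residual optional (add_residual = False recovers the paper's exact formulation)
    attn_out, _ = attn(x, x, x)
    ff_out = ff(x)
    residual = x if add_residual else 0.
    return ff_out + attn_out + residual

x = torch.randn(2, 16, dim)
assert serial_block(x).shape == parallel_block(x).shape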

3 files changed: +158 additions, −33 deletions

alphafold3_pytorch/alphafold3.py

Lines changed: 139 additions & 27 deletions
@@ -1241,7 +1241,7 @@ def inner(inputs, *args, **kwargs):
             wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn))
             wrapped_layers.append(single_transition_wrapper(single_transition))
 
-        single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs)
+        single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False)
 
         return single_repr, pairwise_repr
 
@@ -1615,6 +1615,8 @@ def __init__(
         serial = False,
         add_residual = True,
         use_linear_attn = False,
+        checkpoint = False,
+        checkpoint_segments = 1,
         linear_attn_kwargs = dict(
             heads = 8,
             dim_head = 16
@@ -1689,6 +1691,9 @@ def __init__(
                 conditionable_transition
             ]))
 
+        self.checkpoint = checkpoint
+        self.checkpoint_segments = checkpoint_segments
+
         self.layers = layers
 
         self.serial = serial
@@ -1703,7 +1708,7 @@ def __init__(
             self.registers = nn.Parameter(torch.zeros(num_register_tokens, dim))
 
     @typecheck
-    def forward(
+    def to_checkpointed_serial_layers(
         self,
         noised_repr: Float['b n d'],
         *,
@@ -1712,32 +1717,92 @@ def forward(
         mask: Bool['b n'] | None = None,
         windowed_mask: Bool['b nw w (w*2)'] | None = None
     ):
-        w = self.attn_window_size
-        has_windows = exists(w)
 
-        serial = self.serial
+        inputs = (noised_repr, single_repr, pairwise_repr, mask, windowed_mask)
 
-        # handle windowing
+        wrapped_layers = []
 
-        pairwise_is_windowed = pairwise_repr.ndim == 5
+        def efficient_attn_wrapper(fn):
+            def inner(inputs):
+                noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
+                noised_repr = fn(noised_repr, mask = mask) + noised_repr
+                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
+            return inner
 
-        if has_windows and not pairwise_is_windowed:
-            pairwise_repr = full_pairwise_repr_to_windowed(pairwise_repr, window_size = w)
+        def attn_wrapper(fn):
+            def inner(inputs):
+                noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
+                noised_repr = fn(noised_repr, cond = single_repr, pairwise_repr = pairwise_repr, mask = mask, windowed_mask = windowed_mask) + noised_repr
+                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
+            return inner
 
-        # register tokens
+        def transition_wrapper(fn):
+            def inner(inputs):
+                noised_repr, single_repr, pairwise_repr, mask, windowed_mask = inputs
+                noised_repr = fn(noised_repr, cond = single_repr) + noised_repr
+                return noised_repr, single_repr, pairwise_repr, mask, windowed_mask
+            return inner
 
-        if self.has_registers:
-            num_registers = self.num_registers
-            registers = repeat(self.registers, 'r d -> b r d', b = noised_repr.shape[0])
-            noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')
+        for linear_attn, colt5_attn, attn, transition in self.layers:
 
-            single_repr = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
-            pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)
+            if exists(linear_attn):
+                wrapped_layers.append(efficient_attn_wrapper(linear_attn))
 
-            if exists(mask):
-                mask = F.pad(mask, (num_registers, 0), value = True)
+            if exists(colt5_attn):
+                wrapped_layers.append(efficient_attn_wrapper(colt5_attn))
 
-        # main transformer
+            wrapped_layers.append(attn_wrapper(attn))
+            wrapped_layers.append(transition_wrapper(transition))
+
+        out = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False)
+
+        noised_repr, *_ = out
+        return noised_repr
+
+    @typecheck
+    def to_serial_layers(
+        self,
+        noised_repr: Float['b n d'],
+        *,
+        single_repr: Float['b n ds'],
+        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
+        mask: Bool['b n'] | None = None,
+        windowed_mask: Bool['b nw w (w*2)'] | None = None
+    ):
+
+        for linear_attn, colt5_attn, attn, transition in self.layers:
+
+            if exists(linear_attn):
+                noised_repr = linear_attn(noised_repr, mask = mask) + noised_repr
+
+            if exists(colt5_attn):
+                noised_repr = colt5_attn(noised_repr, mask = mask) + noised_repr
+
+            noised_repr = attn(
+                noised_repr,
+                cond = single_repr,
+                pairwise_repr = pairwise_repr,
+                mask = mask,
+                windowed_mask = windowed_mask
+            ) + noised_repr
+
+            noised_repr = transition(
+                noised_repr,
+                cond = single_repr
+            ) + noised_repr
+
+        return noised_repr
+
+    @typecheck
+    def to_parallel_layers(
+        self,
+        noised_repr: Float['b n d'],
+        *,
+        single_repr: Float['b n ds'],
+        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
+        mask: Bool['b n'] | None = None,
+        windowed_mask: Bool['b nw w (w*2)'] | None = None
+    ):
 
         for linear_attn, colt5_attn, attn, transition in self.layers:
 
@@ -1755,25 +1820,72 @@ def forward(
                 windowed_mask = windowed_mask
             )
 
-            if serial:
-                noised_repr = attn_out + noised_repr
-
             ff_out = transition(
                 noised_repr,
                 cond = single_repr
             )
 
-            if serial:
-                noised_repr = ff_out + noised_repr
-
             # in the algorithm, they omitted the residual, but it could be an error
             # attn + ff + residual was used in GPT-J and PaLM, but later found to be unstable configuration, so it seems unlikely attn + ff would work
             # but in the case they figured out something we have not, you can use their exact formulation by setting `serial = False` and `add_residual = False`
 
             residual = noised_repr if self.add_residual else 0.
 
-            if not serial:
-                noised_repr = ff_out + attn_out + residual
+            noised_repr = ff_out + attn_out + residual
+
+        return noised_repr
+
+    @typecheck
+    def forward(
+        self,
+        noised_repr: Float['b n d'],
+        *,
+        single_repr: Float['b n ds'],
+        pairwise_repr: Float['b n n dp'] | Float['b nw w (w*2) dp'],
+        mask: Bool['b n'] | None = None,
+        windowed_mask: Bool['b nw w (w*2)'] | None = None
+    ):
+        w = self.attn_window_size
+        has_windows = exists(w)
+
+        serial = self.serial
+
+        # handle windowing
+
+        pairwise_is_windowed = pairwise_repr.ndim == 5
+
+        if has_windows and not pairwise_is_windowed:
+            pairwise_repr = full_pairwise_repr_to_windowed(pairwise_repr, window_size = w)
+
+        # register tokens
+
+        if self.has_registers:
+            num_registers = self.num_registers
+            registers = repeat(self.registers, 'r d -> b r d', b = noised_repr.shape[0])
+            noised_repr, registers_ps = pack((registers, noised_repr), 'b * d')
+
+            single_repr = F.pad(single_repr, (0, 0, num_registers, 0), value = 0.)
+            pairwise_repr = F.pad(pairwise_repr, (0, 0, num_registers, 0, num_registers, 0), value = 0.)
+
+            if exists(mask):
+                mask = F.pad(mask, (num_registers, 0), value = True)
+
+        # main transformer
+
+        if self.serial and should_checkpoint(self, (noised_repr, single_repr, pairwise_repr)):
+            to_layers_fn = self.to_checkpointed_serial_layers
+        elif self.serial:
+            to_layers_fn = self.to_serial_layers
+        else:
+            to_layers_fn = self.to_parallel_layers
+
+        noised_repr = to_layers_fn(
+            noised_repr,
+            single_repr = single_repr,
+            pairwise_repr = pairwise_repr,
+            mask = mask,
+            windowed_mask = windowed_mask,
+        )
 
         # splice out registers
 
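The new `to_checkpointed_serial_layers` above leans on one pattern: every block is wrapped in a closure that takes and returns the full tuple of streams, so the list of closures can be handed to `torch.utils.checkpoint.checkpoint_sequential` with `use_reentrant = False`. Below is a minimal, self-contained sketch of that pattern; the toy `nn.Linear` layers and the two-stream tuple are stand-ins, not the repository's actual attention and transition blocks or tensor shapes.

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

dim = 32

# stand-in blocks; in the real module these would be the conditioned attention and transition layers
layers = [nn.Linear(dim, dim) for _ in range(4)]

def layer_wrapper(fn):
    # closure that threads the whole tuple of streams through, applying the block with a residual
    def inner(inputs):
        noised_repr, cond = inputs
        noised_repr = fn(noised_repr) + noised_repr
        return noised_repr, cond
    return inner

wrapped_layers = [layer_wrapper(layer) for layer in layers]

noised_repr = torch.randn(2, 16, dim, requires_grad = True)
cond = torch.randn(2, 16, dim)

# split the 4 wrapped closures into 2 segments; activations of the earlier segment are
# recomputed during backward instead of being stored, trading compute for memory
out, _ = checkpoint_sequential(wrapped_layers, 2, (noised_repr, cond), use_reentrant = False)

out.sum().backward()
print(noised_repr.grad.shape)   # torch.Size([2, 16, 32])
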
pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [project]
 name = "alphafold3-pytorch"
-version = "0.2.63"
+version = "0.2.64"
 description = "Alphafold 3 - Pytorch"
 authors = [
     { name = "Phil Wang", email = "[email protected]" }

tests/test_af3.py

Lines changed: 18 additions & 5 deletions
@@ -181,8 +181,8 @@ def test_pairformer(
     recurrent_depth,
     enable_attn_softclamp
 ):
-    single = torch.randn(2, 16, 384)
-    pairwise = torch.randn(2, 16, 16, 128)
+    single = torch.randn(2, 16, 384).requires_grad_()
+    pairwise = torch.randn(2, 16, 16, 128).requires_grad_()
     mask = torch.randint(0, 2, (2, 16)).bool()
 
     pairformer = PairformerStack(
@@ -228,17 +228,26 @@ def test_msa_module():
 
     assert pairwise.shape == pairwise_out.shape
 
+@pytest.mark.parametrize('checkpoint', (False, True))
+@pytest.mark.parametrize('serial', (False, True))
 @pytest.mark.parametrize('use_linear_attn', (False, True))
 @pytest.mark.parametrize('use_colt5_attn', (False, True))
-def test_diffusion_transformer(use_linear_attn, use_colt5_attn):
+def test_diffusion_transformer(
+    checkpoint,
+    serial,
+    use_linear_attn,
+    use_colt5_attn
+):
 
-    single = torch.randn(2, 16, 384)
-    pairwise = torch.randn(2, 16, 16, 128)
+    single = torch.randn(2, 16, 384).requires_grad_()
+    pairwise = torch.randn(2, 16, 16, 128).requires_grad_()
     mask = torch.randint(0, 2, (2, 16)).bool()
 
     diffusion_transformer = DiffusionTransformer(
         depth = 2,
         heads = 16,
+        serial = serial,
+        checkpoint = checkpoint,
         use_linear_attn = use_linear_attn,
         use_colt5_attn = use_colt5_attn
     )
@@ -252,6 +261,10 @@ def test_diffusion_transformer(use_linear_attn, use_colt5_attn):
 
     assert single.shape == single_out.shape
 
+    if checkpoint:
+        loss = single_out.sum()
+        loss.backward()
+
 def test_sequence_local_attn():
     atoms = torch.randn(2, 17, 32)
     attn_bias = torch.randn(2, 17, 17)
