Skip to content

Commit 98b6d93

Browse files
committed
small cleanup
1 parent e5edfa7 commit 98b6d93

File tree

3 files changed

+8
-6
lines changed

3 files changed

+8
-6
lines changed

alphafold3_pytorch/alphafold3.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,9 @@ def pack_one(t, pattern):
124124
def unpack_one(t, ps, pattern):
125125
return unpack(t, ps, pattern)[0]
126126

127+
def exclusive_cumsum(t, dim = -1):
128+
return t.cumsum(dim = dim) - t
129+
127130
# decorators
128131

129132
def maybe(fn):
@@ -262,8 +265,7 @@ def repeat_consecutive_with_lens(
262265
window_size = mask.shape[-1]
263266
arange = torch.arange(window_size, device = device)
264267

265-
cumsum_len = lens.cumsum(dim = -1)
266-
offsets = F.pad(cumsum_len, (1, -1), value = 0)
268+
offsets = exclusive_cumsum(lens)
267269
indices = einx.add('w, b n -> b n w', arange, offsets)
268270

269271
# create output tensor + a sink position on the very right (index max_len)
@@ -3270,7 +3272,7 @@ def forward(
32703272
# handle offsets for molecule atom indices
32713273

32723274
if exists(molecule_atom_indices):
3273-
molecule_atom_indices = molecule_atom_indices + F.pad(molecule_atom_lens.cumsum(dim = -1), (1, -1), value = 0)
3275+
molecule_atom_indices = molecule_atom_indices + exclusive_cumsum(molecule_atom_lens)
32743276

32753277
# get atom sequence length and molecule sequence length depending on whether using packed atomic seq
32763278

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.57"
3+
version = "0.1.58"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,11 +174,11 @@ def test_trainer():
174174
# saving and loading from trainer
175175

176176
trainer.save('./some/nested/folder2/training.pt', overwrite = True)
177-
trainer.load('./some/nested/folder2/training.pt')
177+
trainer.load('./some/nested/folder2/training.pt', strict = False)
178178

179179
# allow for only loading model, needed for fine-tuning logic
180180

181-
trainer.load('./some/nested/folder2/training.pt', only_model = True)
181+
trainer.load('./some/nested/folder2/training.pt', only_model = True, strict = False)
182182

183183
# also allow for loading Alphafold3 directly from training ckpt
184184

0 commit comments

Comments
 (0)