Skip to content

Commit f37f88d

Browse files
committed
handle a mix of None and Tensor being passed by dataset
1 parent 161ae1a commit f37f88d

File tree

3 files changed

+31
-7
lines changed

3 files changed

+31
-7
lines changed

alphafold3_pytorch/trainer.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,25 +90,46 @@ def collate_af3_inputs(
9090
for grouped in zip(*inputs):
9191
# if all None, just return None
9292

93-
if not any([*map(exists, grouped)]):
93+
not_none_grouped = [*filter(exists, grouped)]
94+
95+
if len(not_none_grouped) == 0:
9496
outputs.append(None)
9597
continue
9698

99+
# default to empty tensor for any Nones
100+
101+
one_tensor = not_none_grouped[0]
102+
103+
dtype = one_tensor.dtype
104+
ndim = one_tensor.ndim
105+
97106
# use -1 for padding int values, for assuming int are labels - if not, handle within alphafold3
98107

99-
pad_value = int_pad_value if grouped[0].dtype in (torch.int, torch.long) else 0
108+
if dtype in (torch.int, torch.long):
109+
pad_value = int_pad_value
110+
elif dtype == torch.bool:
111+
pad_value = False
112+
else:
113+
pad_value = 0.
100114

101115
# get the max lengths across all dimensions
102116

103-
shapes_as_tensor = torch.stack([Tensor(tuple(g.shape)) for g in grouped], dim = -1)
117+
shapes_as_tensor = torch.stack([Tensor(tuple(g.shape) if exists(g) else ((0,) * ndim)).int() for g in grouped], dim = -1)
104118

105-
max_lengths = shapes_as_tensor.int().amax(dim = -1)
119+
max_lengths = shapes_as_tensor.amax(dim = -1)
120+
121+
default_tensor = torch.full(max_lengths.tolist(), pad_value, dtype = dtype)
106122

107123
# pad across all dimensions
108124

109125
padded_inputs = []
110126

111127
for inp in grouped:
128+
129+
if not exists(inp):
130+
padded_inputs.append(default_tensor)
131+
continue
132+
112133
for dim, max_length in enumerate(max_lengths.tolist()):
113134
inp = pad_at_dim(inp, (0, max_length - inp.shape[dim]), value = pad_value, dim = dim)
114135

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "alphafold3-pytorch"
3-
version = "0.1.7"
3+
version = "0.1.8"
44
description = "Alphafold 3 - Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

tests/test_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
os.environ['TYPECHECK'] = 'True'
33

44
from pathlib import Path
5-
from random import randrange
5+
from random import randrange, random
66

77
import pytest
88
import torch
@@ -45,7 +45,10 @@ def __getitem__(self, idx):
4545
template_mask = torch.ones((2,)).bool()
4646

4747
msa = torch.randn(7, seq_len, 64)
48-
msa_mask = torch.ones((7,)).bool()
48+
49+
msa_mask = None
50+
if random() > 0.5:
51+
msa_mask = torch.ones((7,)).bool()
4952

5053
# required for training, but omitted on inference
5154

0 commit comments

Comments
 (0)