149 | 149 |
150 | 150 | LinearNoBias = partial(Linear, bias = False) |
151 | 151 |
| 152 | +# always use non reentrant checkpointing |
| 153 | + |
| 154 | +checkpoint = partial(checkpoint, use_reentrant = False) |
| 155 | +checkpoint_sequential = partial(checkpoint_sequential, use_reentrant = False) |
| 156 | + |
152 | 157 | # helper functions |
153 | 158 |
154 | 159 | def exists(v): |
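The two `partial` lines above rebind `checkpoint` and `checkpoint_sequential` so every later call in the file is non-reentrant by default, which is what lets the explicit `use_reentrant = False` argument be dropped at each call site in the hunks below. A minimal sketch of the same pattern, with a toy block standing in for the repository's modules:

```python
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# bake the non-reentrant flag in once, so call sites stay clean
checkpoint = partial(checkpoint, use_reentrant = False)

# illustrative block, not from the repository
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))

x = torch.randn(2, 16, 64, requires_grad = True)

# activations inside `block` are recomputed during backward instead of stored
out = checkpoint(block, x)
out.sum().backward()
```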
@@ -1179,7 +1184,7 @@ def inner(inputs): |
1179 | 1184 | wrapped_layers.append(msa_transition_wrapper(msa_transition)) |
1180 | 1185 | wrapped_layers.append(pairwise_block_wrapper(pairwise_block)) |
1181 | 1186 |
1182 | | - pairwise_repr, *_ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False) |
| 1187 | + pairwise_repr, *_ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs) |
1183 | 1188 |
1184 | 1189 | return pairwise_repr |
1185 | 1190 |
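With the non-reentrant default baked in, this hunk (and the similar ones at 1399, 1618, and 1971 below) can call `checkpoint_sequential` without repeating `use_reentrant = False`. Each wrapper consumes and returns the packed `inputs` tuple, so heterogeneous layers can be chained and split into `checkpoint_segments` segments. A hedged sketch of that wrapping pattern; the tiny layers, names, and shapes below are illustrative, not from the repository:

```python
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

checkpoint_sequential = partial(checkpoint_sequential, use_reentrant = False)

def wrapper(fn):
    # every wrapped layer takes and returns the whole packed state,
    # so checkpoint_sequential can thread one object through the chain
    def inner(inputs):
        pairwise_repr, mask = inputs
        pairwise_repr = fn(pairwise_repr) + pairwise_repr
        return pairwise_repr, mask
    return inner

wrapped_layers = [
    wrapper(nn.Sequential(nn.LayerNorm(32), nn.Linear(32, 32))),
    wrapper(nn.Linear(32, 32)),
]

pairwise_repr = torch.randn(2, 16, 32, requires_grad = True)
mask = torch.ones(2, 16, dtype = torch.bool)

checkpoint_segments = 2  # one checkpointed segment per layer in this toy case

pairwise_repr, *_ = checkpoint_sequential(wrapped_layers, checkpoint_segments, (pairwise_repr, mask))
pairwise_repr.sum().backward()
```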
@@ -1399,7 +1404,7 @@ def inner(inputs, *args, **kwargs): |
1399 | 1404 | wrapped_layers.append(pair_bias_attn_wrapper(pair_bias_attn)) |
1400 | 1405 | wrapped_layers.append(single_transition_wrapper(single_transition)) |
1401 | 1406 |
1402 | | - single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False) |
| 1407 | + single_repr, pairwise_repr, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs) |
1403 | 1408 |
1404 | 1409 | return single_repr, pairwise_repr |
1405 | 1410 |
@@ -1618,7 +1623,7 @@ def inner(inputs): |
1618 | 1623 | for block in self.pairformer_stack: |
1619 | 1624 | wrapped_layers.append(block_wrapper(block)) |
1620 | 1625 |
1621 | | - templates, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False) |
| 1626 | + templates, _ = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs) |
1622 | 1627 |
1623 | 1628 | return templates |
1624 | 1629 |
@@ -1858,6 +1863,7 @@ def __init__( |
1858 | 1863 | dim = dim, |
1859 | 1864 | prenorm = True, |
1860 | 1865 | gate_value_heads = True, |
| 1866 | + remove_even_power_dups = True, |
1861 | 1867 | **linear_attn_kwargs |
1862 | 1868 | ) |
1863 | 1869 |
@@ -1971,7 +1977,7 @@ def inner(inputs): |
1971 | 1977 | wrapped_layers.append(attn_wrapper(attn)) |
1972 | 1978 | wrapped_layers.append(transition_wrapper(transition)) |
1973 | 1979 |
1974 | | - out = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs, use_reentrant = False) |
| 1980 | + out = checkpoint_sequential(wrapped_layers, self.checkpoint_segments, inputs) |
1975 | 1981 |
1976 | 1982 | noised_repr, *_ = out |
1977 | 1983 | return noised_repr |
@@ -2439,7 +2445,7 @@ def forward( |
2439 | 2445 | token_transformer = self.token_transformer |
2440 | 2446 |
2441 | 2447 | if should_checkpoint(self, tokens, 'checkpoint_token_transformer'): |
2442 | | - token_transformer = partial(checkpoint, token_transformer, use_reentrant = False) |
| 2448 | + token_transformer = partial(checkpoint, token_transformer) |
2443 | 2449 |
2444 | 2450 | # token transformer |
2445 | 2451 |
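The `should_checkpoint` gate swaps the module for a `partial(checkpoint, module)` when checkpointing is active, so the call site stays identical either way; the same pattern is reused for the trunk pairformer and, further down, the distogram head. The helper's exact logic isn't shown in this diff, so the sketch below assumes it checks training mode, a grad-requiring input, and the named flag, and the toy module is illustrative:

```python
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

checkpoint = partial(checkpoint, use_reentrant = False)

def should_checkpoint_sketch(module: nn.Module, inputs, attr_name: str) -> bool:
    # assumed stand-in for the repository's `should_checkpoint` helper
    tensors = inputs if isinstance(inputs, tuple) else (inputs,)
    return (
        module.training
        and any(t.requires_grad for t in tensors)
        and getattr(module, attr_name, False)
    )

class TinyDiffusionModule(nn.Module):
    # illustrative stand-in, not the repository's diffusion module
    def __init__(self):
        super().__init__()
        self.token_transformer = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
        self.checkpoint_token_transformer = True

    def forward(self, tokens):
        token_transformer = self.token_transformer

        # identical call signature whether or not checkpointing is applied
        if should_checkpoint_sketch(self, tokens, 'checkpoint_token_transformer'):
            token_transformer = partial(checkpoint, token_transformer)

        return token_transformer(tokens)

model = TinyDiffusionModule().train()
tokens = torch.randn(2, 16, 64, requires_grad = True)
model(tokens).sum().backward()
```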
@@ -2917,9 +2923,7 @@ def forward( |
2917 | 2923 | mask = mask & paired_coords_mask |
2918 | 2924 |
2919 | 2925 | # Calculate masked averaging |
2920 | | - lddt_sum = (eps * mask).sum(dim=(-1, -2)) |
2921 | | - lddt_count = mask.sum(dim=(-1, -2)) |
2922 | | - lddt = lddt_sum / lddt_count.clamp(min=1) |
| 2926 | + lddt = masked_average(eps, mask = mask, dim = (-1, -2), eps = 1) |
2923 | 2927 |
2924 | 2928 | return 1. - lddt.mean() |
2925 | 2929 |
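The three-line masked reduction in the smooth lDDT loss is folded into a `masked_average` helper. Note the naming at the call site: the first positional argument is the per-pair term that the surrounding code calls `eps`, while the keyword `eps = 1` clamps the denominator. The helper itself isn't shown in this hunk, so the version below is a minimal sketch consistent with that call, not necessarily the repository's implementation:

```python
import torch

def masked_average(t, mask, *, dim, eps = 1.):
    # sketch only: masked sum over `dim`, divided by the clamped count of valid entries
    num = (t * mask).sum(dim = dim)
    den = mask.sum(dim = dim).type_as(num)
    return num / den.clamp(min = eps)

scores = torch.rand(2, 8, 8)          # illustrative per-pair terms
mask = torch.rand(2, 8, 8) > 0.5      # boolean validity mask

avg = masked_average(scores, mask = mask, dim = (-1, -2), eps = 1)
print(avg.shape)  # torch.Size([2])
```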
@@ -4994,6 +4998,7 @@ def __init__( |
4994 | 4998 | distogram_atom_resolution = False, |
4995 | 4999 | checkpoint_input_embedding = False, |
4996 | 5000 | checkpoint_trunk_pairformer = False, |
| 5001 | + checkpoint_distogram_head = True, |
4997 | 5002 | checkpoint_diffusion_token_transformer = False, |
4998 | 5003 | detach_when_recycling = True, |
4999 | 5004 | pdb_training_set=True, |
@@ -5213,6 +5218,7 @@ def __init__( |
5213 | 5218 |
5214 | 5219 | self.checkpoint_trunk_pairformer = checkpoint_trunk_pairformer |
5215 | 5220 | self.checkpoint_diffusion_token_transformer = checkpoint_diffusion_token_transformer |
| 5221 | + self.checkpoint_distogram_head = checkpoint_distogram_head |
5216 | 5222 |
5217 | 5223 | # loss related |
5218 | 5224 |
@@ -5566,7 +5572,7 @@ def forward( |
5566 | 5572 | pairformer = self.pairformer |
5567 | 5573 |
5568 | 5574 | if should_checkpoint(self, (single, pairwise), 'checkpoint_trunk_pairformer'): |
5569 | | - pairformer = partial(checkpoint, pairformer, use_reentrant = False) |
| 5575 | + pairformer = partial(checkpoint, pairformer) |
5570 | 5576 |
5571 | 5577 | # main attention trunk (pairformer) |
5572 | 5578 |
@@ -5693,7 +5699,12 @@ def forward( |
5693 | 5699 |
5694 | 5700 | distance_labels = torch.where(distogram_mask, distance_labels, ignore) |
5695 | 5701 |
5696 | | - distogram_logits = self.distogram_head( |
| 5702 | + distogram_head_fn = self.distogram_head |
| 5703 | + |
| 5704 | + if should_checkpoint(self, pairwise, 'checkpoint_distogram_head'): |
| 5705 | + distogram_head_fn = partial(checkpoint, distogram_head_fn) |
| 5706 | + |
| 5707 | + distogram_logits = distogram_head_fn( |
5697 | 5708 | pairwise, |
5698 | 5709 | molecule_atom_lens = molecule_atom_lens, |
5699 | 5710 | atom_feats = atom_feats |
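The final hunk threads the new `checkpoint_distogram_head` flag (default `True`, added in the `__init__` hunks above) through to the distogram head by swapping `self.distogram_head` for a `partial(checkpoint, ...)` wrapper. One detail the non-reentrant mode makes possible: the head is called with keyword arguments (`molecule_atom_lens`, `atom_feats`), and `torch.utils.checkpoint.checkpoint` only forwards extra keyword arguments when `use_reentrant = False`. A hedged sketch with a stand-in head; the class, shapes, and gating condition below are illustrative, not the repository's distogram head or `should_checkpoint` logic:

```python
from functools import partial

import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

checkpoint = partial(checkpoint, use_reentrant = False)

class ToyDistogramHead(nn.Module):
    # illustrative stand-in, not the repository's distogram head
    def __init__(self, dim = 32, num_bins = 64):
        super().__init__()
        self.to_logits = nn.Linear(dim, num_bins)

    def forward(self, pairwise, *, scale = 1.):
        return self.to_logits(pairwise) * scale

head = ToyDistogramHead().train()
pairwise = torch.randn(1, 8, 8, 32, requires_grad = True)

# stands in for: should_checkpoint(self, pairwise, 'checkpoint_distogram_head')
head_fn = head
if head.training and pairwise.requires_grad:
    head_fn = partial(checkpoint, head_fn)

# keyword arguments pass straight through the non-reentrant checkpoint wrapper
logits = head_fn(pairwise, scale = 0.5)
logits.sum().backward()
```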