@@ -935,6 +935,8 @@ def forward(
         mask: Bool['b n'] | None = None,
         msa_mask: Bool['b s'] | None = None
     ) -> Float['b n n dp']:
+
+        dtype = msa.dtype

         msa = self.norm(msa)

@@ -945,12 +947,12 @@ def forward(

         # maybe masked mean for outer product

         if exists(msa_mask):
-            a = einx.multiply('b s i d, b s -> b s i d', a, msa_mask.float())
-            b = einx.multiply('b s j e, b s -> b s j e', b, msa_mask.float())
+            a = einx.multiply('b s i d, b s -> b s i d', a, msa_mask.type(dtype))
+            b = einx.multiply('b s j e, b s -> b s j e', b, msa_mask.type(dtype))

             outer_product = einsum(a, b, 'b s i d, b s j e -> b i j d e')

-            num_msa = reduce(msa_mask.float(), '... s -> ...', 'sum')
+            num_msa = reduce(msa_mask.type(dtype), '... s -> ...', 'sum')

             outer_product_mean = einx.divide('b i j d e, b', outer_product, num_msa.clamp(min = self.eps))
         else:
@@ -966,7 +968,9 @@ def forward(

         if exists(mask):
             mask = to_pairwise_mask(mask)
-            outer_product_mean = einx.multiply('b i j d, b i j', outer_product_mean, mask.float())
+            outer_product_mean = einx.multiply(
+                'b i j d, b i j', outer_product_mean, mask.type(dtype)
+            )

         pairwise_repr = self.to_pairwise_repr(outer_product_mean)
         return pairwise_repr
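Reviewer note: the hunks above show the recurring pattern of this commit. Casting a boolean mask with `.type(dtype)` instead of a hard-coded `.float()` keeps bf16/fp16 activations from being silently promoted back to float32 by type promotion. A minimal sketch of the difference in plain torch (shapes illustrative, einx broadcasting written out by hand):

```python
import torch

b, s, n, d = 2, 4, 8, 16
msa = torch.randn(b, s, n, d, dtype = torch.bfloat16)
msa_mask = torch.randint(0, 2, (b, s)).bool()

# a hard-coded .float() promotes the bf16 activation to float32
promoted = msa * msa_mask.float()[..., None, None]
assert promoted.dtype == torch.float32

# casting the mask to the activation's own dtype keeps everything in bf16
kept = msa * msa_mask.type(msa.dtype)[..., None, None]
assert kept.dtype == torch.bfloat16
```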
@@ -1520,6 +1524,7 @@ def forward(
         additional_molecule_feats: Int[f'b n {ADDITIONAL_MOLECULE_FEATS}']
     ) -> Float['b n n dp']:

+        dtype = self.out_embedder.weight.dtype
         device = additional_molecule_feats.device

         res_idx, token_idx, asym_id, entity_id, sym_id = additional_molecule_feats.unbind(dim = -1)
@@ -1554,7 +1559,7 @@ def onehot(x, bins):
             dist_from_bins = einx.subtract('... i, j -> ... i j', x, bins)
             indices = dist_from_bins.abs().min(dim = -1, keepdim = True).indices
             one_hots = F.one_hot(indices.long(), num_classes = len(bins))
-            return one_hots.float()
+            return one_hots.type(dtype)

         r_arange = torch.arange(2 * self.r_max + 2, device = device)
         s_arange = torch.arange(2 * self.s_max + 2, device = device)
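`F.one_hot` always returns an integer tensor, so a cast is mandatory before these features reach a linear layer; returning `dtype` rather than float32 keeps the relative position encoding in the working precision. A standalone sketch of the nearest-bin one-hot (simplified relative to the repo's version: no `keepdim`, plain broadcasting instead of einx):

```python
import torch
import torch.nn.functional as F

def onehot_nearest_bin(x, bins, dtype = torch.bfloat16):
    dist_from_bins = x[..., None] - bins              # '... i, j -> ... i j'
    indices = dist_from_bins.abs().argmin(dim = -1)   # nearest bin per value
    one_hots = F.one_hot(indices, num_classes = len(bins))
    return one_hots.type(dtype)                       # F.one_hot yields int64

bins = torch.linspace(-2., 2., steps = 5)
feats = onehot_nearest_bin(torch.randn(3, 7), bins)
assert feats.dtype == torch.bfloat16 and feats.shape == (3, 7, 5)
```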
@@ -1675,6 +1680,7 @@ def forward(
         mask: Bool['b n'] | None = None,
     ) -> Float['b n n dp']:

+        dtype = templates.dtype
         num_templates = templates.shape[1]

         pairwise_repr = self.pairwise_to_embed_input(pairwise_repr)
@@ -1714,7 +1720,7 @@ def forward(
         )

         num = reduce(templates, 'b t i j d -> b i j d', 'sum')
-        den = reduce(template_mask.float(), 'b t -> b', 'sum')
+        den = reduce(template_mask.type(dtype), 'b t -> b', 'sum')

         avg_template_repr = einx.divide('b i j d, b -> b i j d', num, den.clamp(min = self.eps))

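Same rationale in the template branch: `num` is already in the working dtype, so the template count in the denominator must be cast to match, or the division re-promotes the result. A plain-torch sketch of the masked template average (assumes `templates` was already zeroed where masked; `eps` guards empty template sets):

```python
import torch

b, t, n, d, eps = 2, 3, 8, 16, 1e-5
templates = torch.randn(b, t, n, n, d, dtype = torch.bfloat16)
template_mask = torch.tensor([[True, True, False], [True, False, False]])

num = templates.sum(dim = 1)                              # 'b t i j d -> b i j d'
den = template_mask.type(templates.dtype).sum(dim = -1)   # 'b t -> b'
avg_template_repr = num / den.clamp(min = eps)[:, None, None, None]
assert avg_template_repr.dtype == torch.bfloat16
```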
@@ -2612,6 +2618,10 @@ def __init__(
     def device(self):
         return next(self.net.parameters()).device

+    @property
+    def dtype(self):
+        return next(self.net.parameters()).dtype
+
     # derived preconditioning params - Table 1

     def c_skip(self, sigma):
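The new `dtype` property mirrors the existing `device` property: both read the attribute off the first parameter of the wrapped network. A minimal sketch of the pattern (assumes the whole network shares one dtype, which holds after `.half()` or `.bfloat16()`):

```python
import torch
from torch import nn

class Wrapper(nn.Module):
    def __init__(self, net: nn.Module):
        super().__init__()
        self.net = net

    @property
    def device(self):
        return next(self.net.parameters()).device

    @property
    def dtype(self):
        return next(self.net.parameters()).dtype

w = Wrapper(nn.Linear(4, 4).bfloat16())
assert w.dtype == torch.bfloat16
```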
@@ -2637,10 +2647,14 @@ def preconditioned_network_forward(
         network_condition_kwargs: dict,
         clamp = False,
     ):
-        batch, device = noised_atom_pos.shape[0], noised_atom_pos.device
+        batch, dtype, device = (
+            noised_atom_pos.shape[0],
+            noised_atom_pos.dtype,
+            noised_atom_pos.device,
+        )

         if isinstance(sigma, float):
-            sigma = torch.full((batch,), sigma, device = device)
+            sigma = torch.full((batch,), sigma, dtype = dtype, device = device)

         padded_sigma = rearrange(sigma, 'b -> b 1 1')

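Worth spelling out why the second change is needed: `torch.full` with a Python float fill value defaults to the global default dtype (float32 unless overridden), so without an explicit `dtype` a float32 sigma would leak into a half-precision forward pass. A quick demonstration:

```python
import torch

batch, sigma = 4, 1.5

default = torch.full((batch,), sigma)
assert default.dtype == torch.float32   # global default dtype

explicit = torch.full((batch,), sigma, dtype = torch.bfloat16)
assert explicit.dtype == torch.bfloat16
```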
@@ -2668,7 +2682,7 @@ def sample_schedule(self, num_sample_steps = None):
         N = num_sample_steps
         inv_rho = 1 / self.rho

-        steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
+        steps = torch.arange(num_sample_steps, device = self.device, dtype = self.dtype)
         sigmas = (self.sigma_max ** inv_rho + steps / (N - 1) * (self.sigma_min ** inv_rho - self.sigma_max ** inv_rho)) ** self.rho

         sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
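This is the rho-spaced noise schedule from Karras et al. (2022); since every downstream sigma derives from `steps`, switching its dtype to `self.dtype` propagates the model precision through the whole schedule. A self-contained rendering (the defaults below are the EDM paper's, not necessarily this repo's):

```python
import torch
import torch.nn.functional as F

def sample_schedule(num_steps, sigma_min = 0.002, sigma_max = 80., rho = 7., dtype = torch.float32):
    inv_rho = 1 / rho
    steps = torch.arange(num_steps, dtype = dtype)
    sigmas = (sigma_max ** inv_rho + steps / (num_steps - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
    return F.pad(sigmas, (0, 1), value = 0.)  # last step is sigma value of 0.

sigmas = sample_schedule(8)
assert sigmas[-1] == 0. and torch.isclose(sigmas[0], torch.tensor(80.))
```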
@@ -2687,6 +2701,8 @@ def sample(
         **network_condition_kwargs
     ) -> Float['b m 3'] | Float['ts b m 3']:

+        dtype = self.dtype
+
         step_scale, num_sample_steps = self.step_scale, default(num_sample_steps, self.num_sample_steps)

         shape = (*atom_mask.shape, 3)
@@ -2709,7 +2725,7 @@ def sample(

         init_sigma = sigmas[0]

-        atom_pos = init_sigma * torch.randn(shape, device = self.device)
+        atom_pos = init_sigma * torch.randn(shape, dtype = dtype, device = self.device)

         # gradually denoise

@@ -2722,9 +2738,11 @@ def sample(
         for sigma, sigma_next, gamma in maybe_tqdm_wrapper(sigmas_and_gammas, desc = tqdm_pbar_title):
             sigma, sigma_next, gamma = tuple(t.item() for t in (sigma, sigma_next, gamma))

-            atom_pos = maybe_augment_fn(atom_pos)
+            atom_pos = maybe_augment_fn(atom_pos.float()).type(dtype)

-            eps = self.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
+            eps = self.S_noise * torch.randn(
+                shape, dtype = dtype, device = self.device
+            ) # stochastic sampling

             sigma_hat = sigma + gamma * sigma
             atom_pos_hat = atom_pos + sqrt(sigma_hat ** 2 - sigma ** 2) * eps
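Note the asymmetry inside the loop: the noise `eps` is drawn directly in the model dtype, but the rigid augmentation is round-tripped through float32 (`.float()` in, `.type(dtype)` out), since composing rotations in half precision degrades the geometry. A sketch of that round-trip with a hypothetical stand-in for the augmenter:

```python
import math
import torch

def rigid_augment(pos):
    # hypothetical stand-in for the repo's augmenter: a fixed rotation about z
    c, s = math.cos(0.3), math.sin(0.3)
    rot = torch.tensor([[c, -s, 0.], [s, c, 0.], [0., 0., 1.]])
    return pos @ rot.T

dtype = torch.bfloat16
atom_pos = torch.randn(2, 16, 3, dtype = dtype)

# geometry in float32, result cast back to the sampling dtype
atom_pos = rigid_augment(atom_pos.float()).type(dtype)
assert atom_pos.dtype == dtype
```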
@@ -2797,9 +2815,10 @@ def forward(

         # diffusion loss

+        dtype = atom_pos_ground_truth.dtype
         batch_size = atom_pos_ground_truth.shape[0]

-        sigmas = self.noise_distribution(batch_size)
+        sigmas = self.noise_distribution(batch_size).type(dtype)
         padded_sigmas = rearrange(sigmas, 'b -> b 1 1')

         noise = torch.randn_like(atom_pos_ground_truth)
@@ -2839,11 +2858,11 @@ def forward(
         )

         atom_pos_aligned_ground_truth = self.weighted_rigid_align(
-            pred_coords = denoised_atom_pos,
-            true_coords = atom_pos_ground_truth,
-            weights = align_weights,
+            pred_coords = denoised_atom_pos.float(),
+            true_coords = atom_pos_ground_truth.float(),
+            weights = align_weights.float(),
             mask = atom_mask,
-        )
+        ).type(dtype)

         # section 4.2 - multi-chain permutation alignment

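The float-in/cast-out here (and in the later hunks touching `weighted_rigid_align`) is forced: the alignment bottoms out in an SVD, and `torch.linalg.svd` only accepts float32/float64 inputs. A hedged sketch of the promotion wrapper around a plain Kabsch solve (unweighted and unmasked for brevity, unlike the repo's version):

```python
import torch

def kabsch_align(pred, true):
    # pred, true: [b, m, 3]; rotate centered `true` onto `pred`
    pred_c = pred - pred.mean(dim = 1, keepdim = True)
    true_c = true - true.mean(dim = 1, keepdim = True)
    cov = true_c.transpose(-1, -2) @ pred_c
    u, _, vh = torch.linalg.svd(cov)      # raises on fp16/bf16 inputs
    d = torch.det(u @ vh)                 # -1 where the best fit is a reflection
    vh = vh.clone()
    vh[..., -1, :] = vh[..., -1, :] * d[..., None]
    rot = vh.transpose(-1, -2) @ u.transpose(-1, -2)
    return true_c @ rot.transpose(-1, -2)

pred = torch.randn(2, 16, 3, dtype = torch.bfloat16)
true = torch.randn(2, 16, 3, dtype = torch.bfloat16)

aligned = kabsch_align(pred.float(), true.float()).type(pred.dtype)
assert aligned.dtype == torch.bfloat16
```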
@@ -3375,7 +3394,9 @@ def calculate_optimal_transform(
        selected anchor truth as well as a matrix that records how the atoms should be shifted after applying `r`.
        N.b., Optimal alignment requires 1) a rotation and 2) a shift of the positions.
        """
+        dtype = pred_pos.dtype
         batch_size = pred_pos.shape[0]
+
         input_mask = self.calculate_input_mask(
             true_masks = true_masks,
             anchor_gt_idx = anchor_gt_idx,
@@ -3389,13 +3410,13 @@ def calculate_optimal_transform(
             b = batch_size,
         )
         _, r, x = self.weighted_rigid_align(
-            pred_coords = anchor_pred_pos,
-            true_coords = anchor_true_pos,
+            pred_coords = anchor_pred_pos.float(),
+            true_coords = anchor_true_pos.float(),
             mask = input_mask,
             return_transforms = True,
         )

-        return r, x
+        return r.type(dtype), x.type(dtype)

     @staticmethod
     @typecheck
@@ -4498,7 +4519,9 @@ def compute_plddt(
         logits = rearrange(logits, "b plddt m -> b m plddt")
         num_bins = logits.shape[-1]
         bin_width = 1.0 / num_bins
-        bin_centers = torch.arange(0.5 * bin_width, 1.0, bin_width, device = logits.device)
+        bin_centers = torch.arange(
+            0.5 * bin_width, 1.0, bin_width, dtype = logits.dtype, device = logits.device
+        )
         probs = F.softmax(logits, dim = -1)

         predicted_lddt = einsum(probs, bin_centers, "b m plddt, plddt -> b m")
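pLDDT is the expectation of the bin centers under the softmaxed logits, so building `bin_centers` in `logits.dtype` keeps the reduction in a single precision instead of mixing float32 into a bf16 graph. A compact sketch:

```python
import torch
import torch.nn.functional as F

def compute_plddt(logits):
    # logits: [b, m, num_bins]; pLDDT = expected bin center under softmax
    num_bins = logits.shape[-1]
    bin_width = 1.0 / num_bins
    bin_centers = torch.arange(
        0.5 * bin_width, 1.0, bin_width, dtype = logits.dtype, device = logits.device
    )
    probs = F.softmax(logits, dim = -1)
    return (probs * bin_centers).sum(dim = -1)  # in [0, 1]; AF-style pLDDT scales by 100

plddt = compute_plddt(torch.randn(2, 16, 50, dtype = torch.bfloat16))
assert plddt.dtype == torch.bfloat16 and plddt.shape == (2, 16)
```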
@@ -5177,13 +5200,15 @@ def compute_gpde(
         :return: [b] global PDE
         """

+        dtype = pde_logits.dtype
+
         pde = self.compute_confidence_score.compute_pde(pde_logits, tok_repr_atm_mask)

         dist_logits = rearrange(dist_logits, "b dist i j -> b i j dist")
         dist_probs = F.softmax(dist_logits, dim = -1)

         # for distances greater than the last breaks
-        dist_breaks = F.pad(dist_breaks, (0, 1), value = 1e6)
+        dist_breaks = F.pad(dist_breaks.float(), (0, 1), value = 1e6).type(dtype)
         contact_mask = dist_breaks < self.contact_mask_threshold

         contact_prob = einx.where(
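The padded `1e6` is a sentinel break meaning "beyond the last bin". Padding in float32 and casting afterwards avoids committing to the half-precision representation too early; note that 1e6 is not even representable in float16, where the final cast lands on `inf`, which still compares correctly as a non-contact. A sketch of the sentinel pattern:

```python
import torch
import torch.nn.functional as F

dtype = torch.float16
contact_threshold = 8.0

dist_breaks = torch.linspace(2.0, 22.0, steps = 63)                 # float32 breaks
dist_breaks = F.pad(dist_breaks, (0, 1), value = 1e6).type(dtype)   # sentinel last break

contact_mask = dist_breaks < contact_threshold
assert not contact_mask[-1]   # the sentinel can never register as a contact
```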
@@ -5219,6 +5244,7 @@ def compute_lddt(
         :return: lDDT
         """

+        dtype = pred_coords.dtype
         atom_seq_len, device = pred_coords.shape[1], pred_coords.device

         # Compute distances between all pairs of atoms
@@ -5229,7 +5255,7 @@ def compute_lddt(
         dist_diff = torch.abs(true_dists - pred_dists)

         lddt = einx.subtract('thresholds, ... -> ... thresholds', self.lddt_thresholds, dist_diff)
-        lddt = (lddt >= 0).float().mean(dim = -1)
+        lddt = (lddt >= 0).type(dtype).mean(dim = -1)

         # Restrict to bespoke inclusion radius
         is_nucleotide = is_dna | is_rna
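The cast on the changed line is load-bearing twice over: `.mean()` is not defined for bool tensors, and `.float()` would pin the lDDT computation to float32. A plain-torch sketch of the thresholded mean (einx's `'thresholds, ... -> ... thresholds'` subtraction written as broadcasting):

```python
import torch

dtype = torch.bfloat16
lddt_thresholds = torch.tensor([0.5, 1.0, 2.0, 4.0], dtype = dtype)

dist_diff = torch.rand(2, 16, 16, dtype = dtype) * 5.0
within = lddt_thresholds - dist_diff[..., None]    # per-pair, per-threshold margin
lddt = (within >= 0).type(dtype).mean(dim = -1)    # fraction of thresholds satisfied
assert lddt.dtype == dtype and lddt.shape == (2, 16, 16)
```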
@@ -6267,6 +6293,8 @@ def forward(
         atom_seq_len = atom_inputs.shape[-2]
         single_structure_input = atom_inputs.shape[0] == 1

+        dtype = atom_inputs.dtype
+
         # validate atom and atompair input dimensions

         assert atom_inputs.shape[-1] == self.dim_atom_inputs, f'expected {self.dim_atom_inputs} for atom_inputs feature dimension, but received {atom_inputs.shape[-1]}'
@@ -6420,7 +6448,7 @@ def forward(
         seq_arange = torch.arange(seq_len, device = self.device)
         token_bonds = einx.subtract('i, j -> i j', seq_arange, seq_arange).abs() == 1

-        token_bonds_feats = self.token_bond_to_pairwise_feat(token_bonds.float())
+        token_bonds_feats = self.token_bond_to_pairwise_feat(token_bonds.type(dtype))

         pairwise_init = pairwise_init + token_bonds_feats

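The fallback `token_bonds` is a boolean chain-adjacency matrix (|i - j| == 1) that gets embedded by a linear projection, hence the cast to the projection's dtype. A sketch with a hypothetical stand-in for `self.token_bond_to_pairwise_feat`:

```python
import torch
from torch import nn

dtype = torch.bfloat16
seq_len, dim_pairwise = 8, 16

# hypothetical stand-in: project a scalar bond indicator to pairwise features
token_bond_to_pairwise_feat = nn.Linear(1, dim_pairwise).to(dtype)

seq_arange = torch.arange(seq_len)
token_bonds = (seq_arange[:, None] - seq_arange[None, :]).abs() == 1  # chain neighbors

feats = token_bond_to_pairwise_feat(token_bonds.type(dtype)[..., None])
assert feats.shape == (seq_len, seq_len, dim_pairwise) and feats.dtype == dtype
```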
@@ -6711,13 +6739,12 @@ def forward(
         fa_atom_mask, aug_atom_mask = atom_mask[:1], atom_mask[1:]

         fa_atom_pos = self.frame_average(
-            fa_atom_pos,
-            frame_average_mask = fa_atom_mask
-        )
+            fa_atom_pos.float(), frame_average_mask = fa_atom_mask
+        ).type(dtype)

         # normal random augmentations, 48 times in paper

-        atom_pos = self.augmenter(atom_pos, mask = aug_atom_mask)
+        atom_pos = self.augmenter(atom_pos.float(), mask = aug_atom_mask).type(dtype)

         # concat back the stochastic frame averaged position

@@ -6793,11 +6820,11 @@ def forward(

         try:
             atom_pos = self.weighted_rigid_align(
-                pred_coords = denoised_atom_pos,
-                true_coords = atom_pos,
-                weights = align_weights,
+                pred_coords = denoised_atom_pos.float(),
+                true_coords = atom_pos.float(),
+                weights = align_weights.float(),
                 mask = atom_mask,
-            )
+            ).type(dtype)
         except Exception as e:
             # NOTE: For many (random) unit test inputs, weighted rigid alignment can be unstable
             logger.warning(f"Skipping weighted rigid alignment due to: {e}")
@@ -7009,7 +7036,7 @@ def forward(
         lddt = einx.subtract(
             "thresholds, ... -> ... thresholds", self.lddt_thresholds, dist_diff
         )
-        lddt = (lddt >= 0).float().mean(dim = -1)
+        lddt = (lddt >= 0).type(dtype).mean(dim = -1)

         # calculate masked averaging,
         # after which we assign each value to one of 50 equally sized bins
@@ -7047,7 +7074,7 @@ def forward(
             else torch.full((batch_size,), False, device = self.device)
         )

-        confidence_weight = confidence_mask.float()
+        confidence_weight = confidence_mask.type(dtype)

         @typecheck
         def cross_entropy_with_weight(
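`confidence_weight` then multiplies per-sample loss terms, which is why it is cast to the working dtype rather than float32. The body of `cross_entropy_with_weight` is not shown in this diff; a hedged sketch of how such a per-sample weight typically enters the loss:

```python
import torch
import torch.nn.functional as F

def cross_entropy_with_weight(logits, labels, weight):
    # logits: [b, c, ...], labels: [b, ...], weight: [b] per-sample weight
    losses = F.cross_entropy(logits, labels, reduction = 'none')
    losses = losses * weight.view(-1, *([1] * (losses.ndim - 1)))
    return losses.mean()

logits = torch.randn(4, 10, 16)
labels = torch.randint(0, 10, (4, 16))
weight = torch.tensor([1., 1., 0., 1.])
loss = cross_entropy_with_weight(logits, labels, weight)
assert loss.ndim == 0
```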