@@ -2503,7 +2503,7 @@ def __init__(
25032503
25042504 self .atom_repr_to_atompair_feat_cond = nn .Sequential (
25052505 nn .LayerNorm (dim_atom ),
2506- LinearNoBiasThenOuterSum (dim_atom , dim_atompair ),
2506+ LinearNoBias (dim_atom , dim_atompair * 2 ),
25072507 nn .ReLU ()
25082508 )
25092509
@@ -2540,7 +2540,7 @@ def forward(
25402540 self ,
25412541 * ,
25422542 atom_inputs : Float ['b m dai' ],
2543- atompair_inputs : Float ['b m m dapi' ],
2543+ atompair_inputs : Float ['b m m dapi' ] | Float [ 'b nw w1 w2 dapi' ] ,
25442544 atom_mask : Bool ['b m' ],
25452545 additional_residue_feats : Float [f'b n { ADDITIONAL_RESIDUE_FEATS } ' ],
25462546 residue_atom_lens : Int ['b n' ],
@@ -2554,19 +2554,32 @@ def forward(
25542554 atom_feats = self .to_atom_feats (atom_inputs )
25552555 atompair_feats = self .to_atompair_feats (atompair_inputs )
25562556
2557+ # window the atom pair features before passing to atom encoder and decoder
2558+
2559+ is_windowed = atompair_inputs .ndim == 5
2560+
2561+ if not is_windowed :
2562+ atompair_feats = full_pairwise_repr_to_windowed (atompair_feats , window_size = w )
2563+
2564+ # condition atompair with atom repr
2565+
25572566 atom_feats_cond = self .atom_repr_to_atompair_feat_cond (atom_feats )
2558- atompair_feats = atom_feats_cond + atompair_feats
25592567
2560- # window the atom pair features before passing to atom encoder and decoder
2568+ atom_feats_cond = pad_to_multiple (atom_feats_cond , w , dim = 1 )
2569+ atom_feats_cond = rearrange (atom_feats_cond , 'b (nw w) dap -> b nw w dap' , w = w )
25612570
2562- windowed_atompair_feats = full_pairwise_repr_to_windowed (atompair_feats , window_size = w )
2571+ atom_feats_cond_row , atom_feats_cond_col = atom_feats_cond .chunk (2 , dim = - 1 )
2572+ atom_feats_cond_col = concat_neighboring_windows (atom_feats_cond_col , dim_seq = 1 , dim_window = - 2 )
2573+
2574+ atompair_feats = einx .add ('b nw w1 w2 dap, b nw w1 dap' ,atompair_feats , atom_feats_cond_row )
2575+ atompair_feats = einx .add ('b nw w1 w2 dap, b nw w2 dap' ,atompair_feats , atom_feats_cond_col )
25632576
25642577 # initial atom transformer
25652578
25662579 atom_feats = self .atom_transformer (
25672580 atom_feats ,
25682581 single_repr = atom_feats ,
2569- pairwise_repr = windowed_atompair_feats
2582+ pairwise_repr = atompair_feats
25702583 )
25712584
25722585 atompair_feats = self .atompair_feats_mlp (atompair_feats ) + atompair_feats
@@ -3035,7 +3048,7 @@ def forward(
30353048 self ,
30363049 * ,
30373050 atom_inputs : Float ['b m dai' ],
3038- atompair_inputs : Float ['b m m dapi' ],
3051+ atompair_inputs : Float ['b m m dapi' ] | Float [ 'b nw w1 w2 dapi' ] ,
30393052 additional_residue_feats : Float [f'b n { ADDITIONAL_RESIDUE_FEATS } ' ],
30403053 residue_atom_lens : Int ['b n' ],
30413054 atom_mask : Bool ['b m' ] | None = None ,
0 commit comments