@@ -3071,9 +3071,7 @@ def __init__(
30713071 # residue or nucleotide modifications
30723072
30733073 num_molecule_mods = default (num_molecule_mods , 0 )
3074-
30753074 has_molecule_mod_embeds = num_molecule_mods > 0
3076- self .num_molecule_mods = num_molecule_mods
30773075
30783076 if has_molecule_mod_embeds :
30793077 self .molecule_mod_embeds = nn .Embedding (num_molecule_mods , dim_single )
@@ -3233,6 +3231,11 @@ def __init__(
32333231
32343232 self .register_buffer ('zero' , torch .tensor (0. ), persistent = False )
32353233
3234+ # some shorthand for jaxtyping
3235+
3236+ self .dapi = self .dim_atompair_inputs
3237+ self .dai = self .dim_atom_inputs
3238+
32363239 @property
32373240 def device (self ):
32383241 return self .zero .device
@@ -3309,16 +3312,16 @@ def init_and_load(
33093312 def forward (
33103313 self ,
33113314 * ,
3312- atom_inputs : Float ['b m dai' ],
3313- atompair_inputs : Float ['b m m dapi' ] | Float ['b nw w1 w2 dapi' ],
3315+ atom_inputs : Float ['b m {self. dai} ' ],
3316+ atompair_inputs : Float ['b m m {self. dapi} ' ] | Float ['b nw w1 w2 {self. dapi} ' ],
33143317 additional_molecule_feats : Int [f'b n { ADDITIONAL_MOLECULE_FEATS } ' ],
33153318 is_molecule_types : Bool [f'b n { IS_MOLECULE_TYPES } ' ],
33163319 molecule_atom_lens : Int ['b n' ],
33173320 molecule_ids : Int ['b n' ],
33183321 additional_token_feats : Float ['b n {self.dim_additional_token_feats}' ] | None = None ,
33193322 atom_ids : Int ['b m' ] | None = None ,
33203323 atompair_ids : Int ['b m m' ] | Int ['b nw w1 w2' ] | None = None ,
3321- is_molecule_mod : Bool ['b n {self.num_molecule_mods} ' ] | None = None ,
3324+ is_molecule_mod : Bool ['b n num_mods ' ] | None = None ,
33223325 atom_mask : Bool ['b m' ] | None = None ,
33233326 atom_parent_ids : Int ['b m' ] | None = None ,
33243327 token_bonds : Bool ['b n n' ] | None = None ,
0 commit comments