@@ -345,6 +345,17 @@ def mean_pool_with_lens(
345345 lens : Int ['b n' ]
346346) -> Float ['b n d' ]:
347347
348+ summed , mask = sum_pool_with_lens (feats , lens )
349+ avg = einx .divide ('b n d, b n' , summed , lens .clamp (min = 1 ))
350+ avg = einx .where ('b n, b n d, -> b n d' , mask , avg , 0. )
351+ return avg
352+
353+ @typecheck
354+ def sum_pool_with_lens (
355+ feats : Float ['b m d' ],
356+ lens : Int ['b n' ]
357+ ) -> tuple [Float ['b n d' ], Bool ['b n' ]]:
358+
348359 seq_len = feats .shape [1 ]
349360
350361 mask = lens > 0
@@ -364,9 +375,7 @@ def mean_pool_with_lens(
364375 # subtract cumsum at one index from the previous one
365376 summed = sel_cumsum [:, 1 :] - sel_cumsum [:, :- 1 ]
366377
367- avg = einx .divide ('b n d, b n' , summed , lens .clamp (min = 1 ))
368- avg = einx .where ('b n, b n d, -> b n d' , mask , avg , 0. )
369- return avg
378+ return summed , mask
370379
371380@typecheck
372381def mean_pool_fixed_windows_with_mask (
@@ -5271,6 +5280,7 @@ def __init__(
52715280 is_fine_tuning : bool = False ,
52725281 weight_dict_config : dict = None ,
52735282 dssp_path : str = "mkdssp" ,
5283+ use_inhouse_rsa_calculation : bool = False
52745284 ):
52755285 super ().__init__ ()
52765286 self .compute_confidence_score = ComputeConfidenceScore (eps = eps )
@@ -5286,8 +5296,28 @@ def __init__(
52865296
52875297 self .dssp_path = dssp_path
52885298
5299+ self .use_inhouse_rsa_calculation = use_inhouse_rsa_calculation
5300+
5301+ atom_type_radii = tensor ([
5302+ 1.65 , # 0 - nitrogen
5303+ 1.87 , # 1 - carbon alpha
5304+ 1.76 , # 2 - carbon
5305+ 1.4 , # 3 - oxygen
5306+ 1.8 , # 4 - side atoms
5307+ 1.4 # 5 - water
5308+ ])
5309+
5310+ self .atom_type_index = dict (
5311+ N = 0 ,
5312+ CA = 1 ,
5313+ C = 2 ,
5314+ O = 3
5315+ ) # rest go to 4 (side chain atom)
5316+
5317+ self .register_buffer ('atom_radii' , atom_type_radii , persistent = False )
5318+
52895319 @property
5290- def can_calculate_unresolved_protein_rasa (self ):
5320+ def is_mkdssp_available (self ):
52915321 """Check if `mkdssp` is available.
52925322
52935323 :return: True if `mkdssp` is available
@@ -5298,6 +5328,10 @@ def can_calculate_unresolved_protein_rasa(self):
52985328 except sh .ErrorReturnCode_1 :
52995329 return False
53005330
5331+ @property
5332+ def can_calculate_unresolved_protein_rasa (self ):
5333+ return self .is_mkdssp_available or self .use_inhouse_rsa_calculation
5334+
53015335 @typecheck
53025336 def compute_gpde (
53035337 self ,
@@ -5673,6 +5707,155 @@ def _compute_unresolved_rasa(
56735707
56745708 return unresolved_rasa .mean ()
56755709
5710+ @typecheck
5711+ def _inhouse_compute_unresolved_rasa (
5712+ self ,
5713+ unresolved_cid : int ,
5714+ unresolved_residue_mask : Bool [" n" ],
5715+ asym_id : Int [" n" ],
5716+ molecule_ids : Int [" n" ],
5717+ molecule_atom_lens : Int [" n" ],
5718+ atom_pos : Float ["m 3" ],
5719+ atom_mask : Bool [" m" ],
5720+ fibonacci_sphere_n = 200 , # they use 200 in mkdssp, but can be tailored for efficiency
5721+ atom_distance_min_thres = 1e-4
5722+ ) -> Float ["" ]:
5723+ """Compute the unresolved relative solvent accessible surface area (RASA) for proteins.
5724+ using inhouse rebuilt RSA calculation
5725+
5726+ unresolved_cid: asym_id for protein chains with unresolved residues
5727+ unresolved_residue_mask: True for unresolved residues, False for resolved residues
5728+ asym_id: asym_id for each residue
5729+ molecule_ids: molecule_ids for each residue
5730+ molecule_atom_lens: number of atoms for each residue
5731+ atom_pos: [m 3] atom positions
5732+ atom_mask: True for valid atoms, False for missing/padding atoms
5733+ :return: unresolved RASA
5734+ """
5735+
5736+ assert self .can_calculate_unresolved_protein_rasa , "`mkdssp` needs to be installed"
5737+
5738+ num_atom = atom_pos .shape [0 ]
5739+
5740+ chain_mask = asym_id == unresolved_cid
5741+ chain_unresolved_residue_mask = unresolved_residue_mask [chain_mask ]
5742+ chain_asym_id = asym_id [chain_mask ]
5743+ chain_molecule_ids = molecule_ids [chain_mask ]
5744+ chain_molecule_atom_lens = molecule_atom_lens [chain_mask ]
5745+
5746+ chain_mask_to_atom = torch .repeat_interleave (chain_mask , molecule_atom_lens )
5747+
5748+ # if there's padding in num atom
5749+ num_pad = num_atom - molecule_atom_lens .sum ()
5750+ if num_pad > 0 :
5751+ chain_mask_to_atom = F .pad (chain_mask_to_atom , (0 , num_pad ), value = False )
5752+
5753+ chain_atom_pos = atom_pos [chain_mask_to_atom ]
5754+ chain_atom_mask = atom_mask [chain_mask_to_atom ]
5755+
5756+ structure = protein_structure_from_feature (
5757+ chain_asym_id ,
5758+ chain_molecule_ids ,
5759+ chain_molecule_atom_lens ,
5760+ chain_atom_pos ,
5761+ chain_atom_mask ,
5762+ )
5763+
5764+ # use the structure as source of truth, matching what xluo did
5765+
5766+ structure_atom_pos = []
5767+ structure_atom_type_for_radii = []
5768+ side_atom_index = len (self .atom_type_index )
5769+
5770+ for atom in structure .get_atoms ():
5771+
5772+ one_atom_pos = list (atom .get_vector ())
5773+ one_atom_type = self .atom_type_index .get (atom .name , side_atom_index )
5774+
5775+ structure_atom_pos .append (one_atom_pos )
5776+ structure_atom_type_for_radii .append (one_atom_type )
5777+
5778+ structure_atom_pos : Float [' m 3' ] = tensor (structure_atom_pos )
5779+ structure_atom_type_for_radii : Int [' m' ] = tensor (structure_atom_type_for_radii )
5780+
5781+ atom_radii : Float [' m' ] = self .atom_radii [structure_atom_type_for_radii ]
5782+
5783+ water_radii = self .atom_radii [- 1 ]
5784+
5785+ # atom radii is always summed with water radii
5786+
5787+ atom_radii += water_radii
5788+ atom_radii_sq = atom_radii .pow (2 ) # always use square of radii or distance for comparison - save on sqrt
5789+
5790+ # write custom RSA function here
5791+
5792+ # first constitute the fibonacci sphere
5793+
5794+ num_surface_dots = fibonacci_sphere_n * 2 + 1
5795+ golden_ratio = 1. + sqrt (5. ) / 2
5796+ weight = (4. * pi ) / num_surface_dots
5797+
5798+ arange = torch .arange (- fibonacci_sphere_n , fibonacci_sphere_n + 1 ) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
5799+
5800+ lat = torch .asin ((2. * arange ) / num_surface_dots )
5801+ lon = torch .fmod (arange , golden_ratio ) * 2 * pi / golden_ratio
5802+
5803+ # ein:
5804+ # sd - surface dots
5805+ # c - coordinate (3)
5806+ # i, j - source and target atom
5807+
5808+ unit_surface_dots : Float ['sd 3' ] = torch .stack ((
5809+ lon .sin () * lat .cos (),
5810+ lon .cos () * lat .cos (),
5811+ lat .sin ()
5812+ ), dim = - 1 )
5813+
5814+ # first get atom relative positions + distance
5815+ # for determining whether to include pairs of atom in calculation for the `free` adjective
5816+
5817+ atom_rel_pos = einx .subtract ('j c, i c -> i j c' , structure_atom_pos , structure_atom_pos )
5818+ atom_rel_dist_sq = atom_rel_pos .pow (2 ).sum (dim = - 1 )
5819+
5820+ max_distance_include = einx .add ('i, j -> i j' , atom_radii , atom_radii ).pow (2 )
5821+
5822+ include_in_free_calc = (
5823+ (atom_rel_dist_sq < max_distance_include ) &
5824+ (atom_rel_dist_sq > atom_distance_min_thres )
5825+ )
5826+
5827+ # overall logic
5828+
5829+ surface_dots = einx .multiply ('m, sd c -> m sd c' , atom_radii , unit_surface_dots )
5830+
5831+ dist_from_surface_dots_sq = einx .subtract ('i j c, i sd c -> i sd j c' , atom_rel_pos , surface_dots ).pow (2 ).sum (dim = - 1 )
5832+
5833+ target_atom_close_to_surface_dots = einx .less ('j, i sd j -> i sd j' , atom_radii_sq , dist_from_surface_dots_sq )
5834+
5835+ target_atom_close_or_not_included = einx .logical_or ('i sd j, i j -> i sd j' , target_atom_close_to_surface_dots , ~ include_in_free_calc )
5836+
5837+ is_free = reduce (target_atom_close_or_not_included , 'i sd j -> i sd' , 'all' ) # basically the most important line, calculating whether an atom is free by some distance measure
5838+
5839+ score = reduce (is_free .float () * weight , 'm sd -> m' , 'sum' )
5840+
5841+ per_atom_accessible_surface_score = score * atom_radii_sq
5842+
5843+ # sum up all surface scores for atoms per residue
5844+ # the final score seems to be the average of the rsa across all residues (selected by `chain_unresolved_residue_mask`)
5845+
5846+ rasa , mask = sum_pool_with_lens (
5847+ rearrange (per_atom_accessible_surface_score , '... -> 1 ... 1' ),
5848+ rearrange (chain_molecule_atom_lens , '... -> 1 ...' )
5849+ )
5850+
5851+ rasa = einx .where ('b n, b n d, -> b n d' , mask , rasa , 0. )
5852+
5853+ rasa = rearrange (rasa , '1 n 1 -> n' )
5854+
5855+ unresolved_rasa = rasa [chain_unresolved_residue_mask ]
5856+
5857+ return unresolved_rasa .mean ()
5858+
56765859 @typecheck
56775860 def compute_unresolved_rasa (
56785861 self ,
@@ -5707,8 +5890,16 @@ def compute_unresolved_rasa(
57075890 weight = weight_dict .get ("unresolved" , {}).get ("unresolved" , None )
57085891 assert weight , "Weight not found for unresolved"
57095892
5893+ # for migrating to a rewritten computation of RSA
5894+ # to remove mkdssp dependency for model selection
5895+
5896+ if self .use_inhouse_rsa_calculation :
5897+ compute_unresolved_rasa_function = self ._inhouse_compute_unresolved_rasa
5898+ else :
5899+ compute_unresolved_rasa_function = self ._compute_unresolved_rasa
5900+
57105901 unresolved_rasa = [
5711- self . _compute_unresolved_rasa (* args )
5902+ compute_unresolved_rasa_function (* args )
57125903 for args in zip (
57135904 unresolved_cid ,
57145905 unresolved_residue_mask ,
0 commit comments