@@ -5288,6 +5288,7 @@ def __init__(
52885288 contact_mask_threshold : float = 8.0 ,
52895289 is_fine_tuning : bool = False ,
52905290 weight_dict_config : dict = None ,
5291+ fibonacci_sphere_n = 200 , # more points equal better approximation at cost of compute
52915292 ):
52925293 super ().__init__ ()
52935294 self .compute_confidence_score = ComputeConfidenceScore (eps = eps )
@@ -5301,6 +5302,8 @@ def __init__(
53015302 self .register_buffer ("dist_breaks" , dist_breaks )
53025303 self .register_buffer ('lddt_thresholds' , torch .tensor ([0.5 , 1.0 , 2.0 , 4.0 ]))
53035304
5305+ # for rsa calculation
5306+
53045307 atom_type_radii = tensor ([
53055308 1.65 , # 0 - nitrogen
53065309 1.87 , # 1 - carbon alpha
@@ -5319,6 +5322,31 @@ def __init__(
53195322
53205323 self .register_buffer ('atom_radii' , atom_type_radii , persistent = False )
53215324
5325+ # constitute the fibonacci sphere
5326+
5327+ num_surface_dots = fibonacci_sphere_n * 2 + 1
5328+ golden_ratio = 1. + sqrt (5. ) / 2
5329+ weight = (4. * pi ) / num_surface_dots
5330+
5331+ arange = torch .arange (- fibonacci_sphere_n , fibonacci_sphere_n + 1 , device = self .device ) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
5332+
5333+ lat = torch .asin ((2. * arange ) / num_surface_dots )
5334+ lon = torch .fmod (arange , golden_ratio ) * 2 * pi / golden_ratio
5335+
5336+ # ein:
5337+ # sd - surface dots
5338+ # c - coordinate (3)
5339+ # i, j - source and target atom
5340+
5341+ unit_surface_dots : Float ['sd 3' ] = torch .stack ((
5342+ lon .sin () * lat .cos (),
5343+ lon .cos () * lat .cos (),
5344+ lat .sin ()
5345+ ), dim = - 1 )
5346+
5347+ self .register_buffer ('unit_surface_dots' , unit_surface_dots )
5348+ self .surface_weight = weight
5349+
53225350 @property
53235351 def device (self ):
53245352 return self .atom_radii .device
@@ -5651,7 +5679,6 @@ def calc_atom_access_surface_score(
56515679 atom_pos : Float ['m 3' ],
56525680 atom_type : Int ['m' ],
56535681 molecule_atom_lens : Int ['n' ] | None = None ,
5654- fibonacci_sphere_n = 200 , # more points equal better approximation at cost of compute
56555682 atom_distance_min_thres = 1e-4
56565683 ) -> Float ['m' ] | Float ['n' ]:
56575684
@@ -5666,28 +5693,6 @@ def calc_atom_access_surface_score(
56665693
56675694 # write custom RSA function here
56685695
5669- # first constitute the fibonacci sphere
5670-
5671- num_surface_dots = fibonacci_sphere_n * 2 + 1
5672- golden_ratio = 1. + sqrt (5. ) / 2
5673- weight = (4. * pi ) / num_surface_dots
5674-
5675- arange = torch .arange (- fibonacci_sphere_n , fibonacci_sphere_n + 1 , device = self .device ) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
5676-
5677- lat = torch .asin ((2. * arange ) / num_surface_dots )
5678- lon = torch .fmod (arange , golden_ratio ) * 2 * pi / golden_ratio
5679-
5680- # ein:
5681- # sd - surface dots
5682- # c - coordinate (3)
5683- # i, j - source and target atom
5684-
5685- unit_surface_dots : Float ['sd 3' ] = torch .stack ((
5686- lon .sin () * lat .cos (),
5687- lon .cos () * lat .cos (),
5688- lat .sin ()
5689- ), dim = - 1 )
5690-
56915696 # get atom relative positions + distance
56925697 # for determining whether to include pairs of atom in calculation for the `free` adjective
56935698
@@ -5715,7 +5720,7 @@ def calc_atom_access_surface_score(
57155720
57165721 # overall logic
57175722
5718- surface_dots = einx .multiply ('m, sd c -> m sd c' , atom_radii , unit_surface_dots )
5723+ surface_dots = einx .multiply ('m, sd c -> m sd c' , atom_radii , self . unit_surface_dots )
57195724
57205725 dist_from_surface_dots_sq = einx .subtract ('i j c, i sd c -> i sd j c' , atom_rel_pos , surface_dots ).pow (2 ).sum (dim = - 1 )
57215726
@@ -5725,7 +5730,7 @@ def calc_atom_access_surface_score(
57255730
57265731 is_free = reduce (target_atom_close_or_not_included , 'i sd j -> i sd' , 'all' ) # basically the most important line, calculating whether an atom is free by some distance measure
57275732
5728- score = reduce (is_free .float () * weight , 'm sd -> m' , 'sum' )
5733+ score = reduce (is_free .float () * self . surface_weight , 'm sd -> m' , 'sum' )
57295734
57305735 per_atom_access_surface_score = score * atom_radii_sq
57315736
0 commit comments