@@ -5319,6 +5319,10 @@ def __init__(
53195319
53205320 self .register_buffer ('atom_radii' , atom_type_radii , persistent = False )
53215321
5322+ @property
5323+ def device (self ):
5324+ return self .atom_radii .device
5325+
53225326 @typecheck
53235327 def compute_gpde (
53245328 self ,
@@ -5629,10 +5633,10 @@ def calc_atom_access_surface_score_from_structure(
56295633 structure_atom_pos .append (one_atom_pos )
56305634 structure_atom_type_for_radii .append (one_atom_type )
56315635
5632- structure_atom_pos : Float ['m 3' ] = tensor (structure_atom_pos )
5633- structure_atom_type_for_radii : Int ['m' ] = tensor (structure_atom_type_for_radii )
5636+ structure_atom_pos : Float ['m 3' ] = tensor (structure_atom_pos , device = self . device )
5637+ structure_atom_type_for_radii : Int ['m' ] = tensor (structure_atom_type_for_radii , device = self . device )
56345638
5635- structure_atoms_per_residue : Int ['n' ] = tensor ([len ([* residue .get_atoms ()]) for residue in structure .get_residues ()]).long ()
5639+ structure_atoms_per_residue : Int ['n' ] = tensor ([len ([* residue .get_atoms ()]) for residue in structure .get_residues ()], device = self . device ).long ()
56365640
56375641 return self .calc_atom_access_surface_score (
56385642 atom_pos = structure_atom_pos ,
@@ -5668,7 +5672,7 @@ def calc_atom_access_surface_score(
56685672 golden_ratio = 1. + sqrt (5. ) / 2
56695673 weight = (4. * pi ) / num_surface_dots
56705674
5671- arange = torch .arange (- fibonacci_sphere_n , fibonacci_sphere_n + 1 ) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
5675+ arange = torch .arange (- fibonacci_sphere_n , fibonacci_sphere_n + 1 , device = self . device ) # for example, N = 3 -> [-3, -2, -1, 0, 1, 2, 3]
56725676
56735677 lat = torch .asin ((2. * arange ) / num_surface_dots )
56745678 lon = torch .fmod (arange , golden_ratio ) * 2 * pi / golden_ratio
0 commit comments