2727from torch .nn .utils .rnn import pad_sequence
2828from torch .utils .data import Dataset
2929from torchtyping import TensorType
30- from typing import Sequence , Union , NamedTuple , Optional , Tuple
30+ from typing import Sequence , Union , NamedTuple
3131from zlib import decompress
3232from .._abc import default_collate_fn_map
3333
@@ -89,8 +89,7 @@ def collate_molecules(batch, *, padding_left: bool = False, collate_fn_map=None)
8989
9090class MoleculeDataset (Dataset ):
9191 def __init__ (self , molecules : Sequence [Union [MoleculeContainer , bytes ]], * ,
92- hydrogens : Optional [Sequence [Sequence [Tuple [int , ...]]]] = None , cls_token : int = 1 ,
93- max_distance : int = 10 , add_cls : bool = True , max_neighbors : int = 14 ,
92+ cls_token : int = 1 , max_distance : int = 10 , add_cls : bool = True , max_neighbors : int = 14 ,
9493 symmetric_attention : bool = True , components_attention : bool = True ,
9594 unpack : bool = False , compressed : bool = True , distance_cutoff = None ):
9695 """
@@ -106,7 +105,6 @@ def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
106105 that code unreachable atoms (e.g. salts).
107106
108107 :param molecules: molecules collection
109- :param hydrogens: shared hydrogen mapping. First element is hydrogen donor, other are acceptors
110108 :param max_distance: set distances greater than cutoff to cutoff value
111109 :param add_cls: add special token at first position
112110 :param max_neighbors: set neighbors count greater than cutoff to cutoff value
@@ -116,10 +114,7 @@ def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
116114 :param compressed: packed molecules are compressed
117115 :param cls_token: idx of cls token
118116 """
119- assert hydrogens is None or len (hydrogens ) == len (molecules ), 'hydrogens and molecules must have the same size'
120-
121117 self .molecules = molecules
122- self .hydrogens = hydrogens
123118 # distance_cutoff is deprecated
124119 self .max_distance = distance_cutoff if distance_cutoff is not None else max_distance
125120 self .add_cls = add_cls
@@ -132,14 +127,6 @@ def __init__(self, molecules: Sequence[Union[MoleculeContainer, bytes]], *,
132127
133128 def __getitem__ (self , item : int ) -> MoleculeDataPoint :
134129 mol = self .molecules [item ]
135-
136- if self .hydrogens is not None :
137- hmap = self .hydrogens [item ]
138- pad = len (hmap )
139- else :
140- hmap = None
141- pad = 0
142-
143130 if self .unpack :
144131 try :
145132 from ._unpack import unpack
@@ -148,22 +135,15 @@ def __getitem__(self, item: int) -> MoleculeDataPoint:
148135 else :
149136 if self .compressed :
150137 mol = decompress (mol )
151- atoms , neighbors , distances , _ , mapping = unpack (mol , self .add_cls , self .symmetric_attention ,
152- self .components_attention , self .max_neighbors ,
153- self .max_distance , pad )
154- if pad :
155- for n , da in enumerate (hmap , - pad ):
156- neighbors [mapping [da [0 ]]] -= 1
157- for m in da :
158- m = mapping [m ]
159- distances [n , m ] = distances [m , n ] = 1
138+ atoms , neighbors , distances , _ = unpack (mol , self .add_cls , self .symmetric_attention ,
139+ self .components_attention , self .max_neighbors ,
140+ self .max_distance )
160141 if self .add_cls and self .cls_token != 1 :
161142 atoms [0 ] = self .cls_token
162143 return MoleculeDataPoint (IntTensor (atoms ), IntTensor (neighbors ), IntTensor (distances ))
163144
164145 nc = self .max_neighbors
165- lp = len (mol ) + pad
166- mapping = {}
146+ lp = len (mol )
167147
168148 if self .add_cls :
169149 lp += 1
@@ -176,7 +156,6 @@ def __getitem__(self, item: int) -> MoleculeDataPoint:
176156 ngb = mol ._bonds # noqa speedup
177157 hgs = mol ._hydrogens # noqa
178158 for i , (n , a ) in enumerate (mol .atoms (), self .add_cls ):
179- mapping [n ] = i
180159 atoms [i ] = a .atomic_number + 2
181160 nb = len (ngb [n ]) + (hgs [n ] or 0 ) # treat bad valence as 0-hydrogen
182161 if nb > nc :
@@ -188,23 +167,7 @@ def __getitem__(self, item: int) -> MoleculeDataPoint:
188167 minimum (distances , self .max_distance + 2 , out = distances )
189168 distances = IntTensor (distances )
190169
191- if pad :
192- atoms [- pad :] = 2 # set explicit hydrogens
193- tmp = eye (lp , dtype = int32 )
194- if self .add_cls :
195- tmp [0 ] = 1 # enable CLS to atom attention
196- tmp [1 :, 0 ] = 1 if self .symmetric_attention else 0 # enable or disable atom to CLS attention
197- tmp [1 :- pad , 1 :- pad ] = distances
198- else :
199- tmp [:- pad , :- pad ] = distances
200- distances = tmp
201-
202- for n , da in enumerate (hmap , - pad ):
203- neighbors [mapping [da [0 ]]] -= 1
204- for m in da :
205- m = mapping [m ]
206- distances [n , m ] = distances [m , n ] = 1
207- elif self .add_cls :
170+ if self .add_cls :
208171 tmp = ones ((lp , lp ), dtype = int32 )
209172 if not self .symmetric_attention :
210173 tmp [1 :, 0 ] = 0 # disable atom to CLS attention
0 commit comments