@@ -167,6 +167,42 @@ def loop(
167167 )
168168
169169
170+ def maybe_convert (a ):
171+ return np .asarray (a ) if isinstance (a , jnp .ndarray ) else a
172+
173+
174+ class TensorDataset (torch .utils .data .Dataset ):
175+ def __init__ (self , tensors , x_transform = None , q_transform = None , a_transform = None ):
176+ self .names = ["x" , "q" , "a" ]
177+ self .data = {
178+ name : torch .as_tensor (np .copy (maybe_convert (t ))) if exists (t ) else None
179+ for name , t in zip (self .names , tensors )
180+ }
181+
182+ self .transforms = {
183+ name : transform if exists (transform ) else None
184+ for name , transform in zip (self .names , [x_transform , q_transform , a_transform ])
185+ }
186+
187+ # Sanity check: all non-None tensors must have same first dimension
188+ lengths = [v .shape [0 ] for v in self .data .values () if v is not None ]
189+ assert len (set (lengths )) == 1 , "All input tensors must have the same length."
190+
191+ def __getitem__ (self , index ):
192+ output = []
193+ for key in self .names :
194+ tensor = self .data .get (key )
195+ if exists (tensor ):
196+ val = tensor [index ]
197+ if self .transforms [key ]:
198+ val = self .transforms [key ](val )
199+ output .append (val )
200+ return tuple (output )
201+
202+ def __len__ (self ):
203+ return next (v .shape [0 ] for v in self .data .values () if v is not None )
204+
205+
170206@jaxtyped (typechecker = typechecker )
171207@dataclass
172208class ScalerDataset :
@@ -220,43 +256,6 @@ class ScalerDataset:
220256 ]
221257
222258
223- def maybe_convert (a ):
224- return np .asarray (a ) if isinstance (a , jnp .ndarray ) else a
225-
226-
227- class TensorDataset (torch .utils .data .Dataset ):
228- def __init__ (self , tensors , x_transform = None , q_transform = None , a_transform = None ):
229- self .names = ["x" , "q" , "a" ]
230- self .data = {
231- name : torch .as_tensor (np .copy (maybe_convert (t ))) if exists (t ) else None
232- for name , t in zip (self .names , tensors )
233- }
234-
235- self .transforms = {
236- name : transform if exists (transform ) else None
237- for name , transform in zip (self .names , [x_transform , q_transform , a_transform ])
238- }
239-
240- # Sanity check: all non-None tensors must have same first dimension
241- lengths = [v .shape [0 ] for v in self .data .values () if v is not None ]
242- assert len (set (lengths )) == 1 , "All input tensors must have the same length."
243-
244- def __getitem__ (self , index ):
245- output = []
246- for key in self .names :
247- tensor = self .data .get (key )
248- if exists (tensor ):
249- val = tensor [index ]
250- if self .transforms [key ]:
251- val = self .transforms [key ](val )
252- val = jnp .asarray (val .numpy ())
253- output .append (val )
254- return tuple (output )
255-
256- def __len__ (self ):
257- return next (v .shape [0 ] for v in self .data .values () if v is not None )
258-
259-
260259@jaxtyped (typechecker = typechecker )
261260def dataset_from_tensors (
262261 X : Float [Array , "n ..." ],
0 commit comments