@@ -40,51 +40,43 @@ class JaxDataArray(Tidy3dBaseModel):
40
40
description = "Dictionary storing the coordinates, namely ``(direction, f, mode_index)``." ,
41
41
)
42
42
43
- @pd .validator ("coords" , always = True )
44
- def _convert_coords_to_list (cls , val ):
45
- """Convert supplied coordinates to Dict[str, list]."""
46
- return {coord_name : list (coord_list ) for coord_name , coord_list in val .items ()}
47
-
48
43
@pd .validator ("values" , always = True )
49
44
def _convert_values_to_np (cls , val ):
50
45
"""Convert supplied values to numpy if they are list (from file)."""
51
46
if isinstance (val , list ):
52
47
return np .array (val )
53
48
return val
54
49
55
- def __eq__ (self , other ) -> bool :
56
- """Check if two ``JaxDataArray`` instances are equal."""
57
- return jnp .array_equal (self .values , other .values )
58
-
59
- # removed because it was slowing things down.
60
- # @pd.validator("coords", always=True)
61
- # def _coords_match_values(cls, val, values):
62
- # """Make sure the coordinate dimensions and shapes match the values data."""
63
-
64
- # values = values.get("values")
65
-
66
- # # if values did not pass validation, just skip this validator
67
- # if values is None:
68
- # return None
69
-
70
- # # compute the shape, otherwise exit.
71
- # try:
72
- # shape = jnp.array(values).shape
73
- # except TypeError:
74
- # return val
50
+ @pd .validator ("coords" , always = True )
51
+ def _coords_match_values (cls , val , values ):
52
+ """Make sure the coordinate dimensions and shapes match the values data."""
53
+
54
+ _values = values .get ("values" )
55
+
56
+ # get the shape, handling both regular and jax objects
57
+ try :
58
+ values_shape = np .array (_values ).shape
59
+ except TypeError :
60
+ values_shape = jnp .array (_values ).shape
61
+
62
+ for (key , coord_val ), size_dim in zip (val .items (), values_shape ):
63
+ if len (coord_val ) != size_dim :
64
+ raise ValueError (
65
+ f"JaxDataArray coord { key } has { len (coord_val )} elements, "
66
+ "which doesn't match the values array "
67
+ f"with size { size_dim } along that dimension."
68
+ )
75
69
76
- # if len(shape) != len(val):
77
- # raise AdjointError(f"'values' has '{len(shape)}' dims, but given '{len(val)}'.")
70
+ return val
78
71
79
- # # make sure each coordinate list has same length as values along that axis
80
- # for len_dim, (coord_name, coord_list) in zip(shape, val.items()):
81
- # if len_dim != len(coord_list):
82
- # raise AdjointError(
83
- # f"coordinate '{coord_name}' has '{len(coord_list)}' elements, "
84
- # f"expected '{len_dim}' to match number of 'values' along this dimension."
85
- # )
72
+ @pd .validator ("coords" , always = True )
73
+ def _convert_coords_to_list (cls , val ):
74
+ """Convert supplied coordinates to Dict[str, list]."""
75
+ return {coord_name : list (coord_list ) for coord_name , coord_list in val .items ()}
86
76
87
- # return val
77
+ def __eq__ (self , other ) -> bool :
78
+ """Check if two ``JaxDataArray`` instances are equal."""
79
+ return jnp .array_equal (self .values , other .values )
88
80
89
81
def to_hdf5 (self , fname : str , group_path : str ) -> None :
90
82
"""Save an xr.DataArray to the hdf5 file with a given path to the group."""
@@ -198,11 +190,18 @@ def __mul__(self, other: JaxDataArray) -> JaxDataArray:
198
190
new_values = self .as_jnp_array * other .as_jnp_array
199
191
elif isinstance (other , xr .DataArray ):
200
192
201
- other_values = other .values .reshape (self .values .shape )
202
- new_values = self .as_jnp_array * other_values
193
+ # handle case where other is missing dims present in self
194
+ new_shape = list (self .shape )
195
+ for dim_index , dim in enumerate (self .coords .keys ()):
196
+ if dim not in other .dims :
197
+ other = other .expand_dims (dim = dim )
198
+ new_shape [dim_index ] = 1
203
199
200
+ other_values = other .values .reshape (new_shape )
201
+ new_values = self .as_jnp_array * other_values
204
202
else :
205
203
new_values = self .as_jnp_array * other
204
+
206
205
return self .updated_copy (values = new_values )
207
206
208
207
def __rmul__ (self , other ) -> JaxDataArray :
@@ -265,8 +264,10 @@ def isel_single(self, coord_name: str, coord_index: int) -> JaxDataArray:
265
264
266
265
# if the coord index has more than one item, keep that coordinate
267
266
coord_index = np .array (coord_index )
268
- if coord_index .size > 1 :
269
- new_coords [coord_name ] = coord_index .tolist ()
267
+ if len (coord_index .shape ) >= 1 :
268
+ coord_indices = coord_index .tolist ()
269
+ new_coord_vals = [self .coords [coord_name ][coord_index ] for coord_index in coord_indices ]
270
+ new_coords [coord_name ] = new_coord_vals
270
271
else :
271
272
new_coords .pop (coord_name )
272
273
@@ -306,20 +307,36 @@ def sel(self, indexers: dict = None, method: str = "nearest", **sel_kwargs) -> J
306
307
isel_kwargs = {}
307
308
for coord_name , sel_kwarg in sel_kwargs .items ():
308
309
coord_list = self .get_coord_list (coord_name )
309
- if sel_kwarg not in coord_list :
310
- raise DataError (f"Could not select '{ coord_name } ={ sel_kwarg } ', value not found." )
311
- coord_index = coord_list .index (sel_kwarg )
312
- isel_kwargs [coord_name ] = coord_index
310
+ if isinstance (sel_kwarg , (tuple , list , np .ndarray )):
311
+ sel_kwarg = list (sel_kwarg )
312
+ isel_kwargs [coord_name ] = []
313
+ for _sel_kwarg in sel_kwarg :
314
+ if _sel_kwarg not in coord_list :
315
+ raise DataError (
316
+ f"Could not select '{ coord_name } ={ _sel_kwarg } ', value not found."
317
+ )
318
+ coord_index = coord_list .index (_sel_kwarg )
319
+ isel_kwargs [coord_name ].append (coord_index )
320
+ else :
321
+ if sel_kwarg not in coord_list :
322
+ raise DataError (
323
+ f"Could not select '{ coord_name } ={ sel_kwarg } ', value not found."
324
+ )
325
+ coord_index = coord_list .index (sel_kwarg )
326
+ isel_kwargs [coord_name ] = coord_index
313
327
return self .isel (** isel_kwargs )
314
328
315
329
def assign_coords (self , coords : dict = None , ** coords_kwargs ) -> JaxDataArray :
316
330
"""Assign new coordinates to this object."""
317
331
318
332
update_kwargs = self .coords .copy ()
319
333
320
- update_kwargs .update (coords_kwargs )
334
+ for key , val in coords_kwargs .items ():
335
+ update_kwargs [key ] = val
336
+
321
337
if coords :
322
- update_kwargs .update (coords )
338
+ for key , val in coords .items ():
339
+ update_kwargs [key ] = val
323
340
324
341
update_kwargs = {key : np .array (value ).tolist () for key , value in update_kwargs .items ()}
325
342
return self .updated_copy (coords = update_kwargs )
0 commit comments