Skip to content

Commit 5b66b9e

Browse files
committed
fix: test also lazy arrays
1 parent d80dda9 commit 5b66b9e

File tree

3 files changed

+12
-7
lines changed

3 files changed

+12
-7
lines changed

src/mrinufft/_array_compat.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,13 @@ def _convert(_array_to_xp, args, kwargs=None):
321321
elif isinstance(_arg, (tuple, list)):
322322
args[n], _ = _convert(_array_to_xp, _arg)
323323
# objects with attributes that are arrays are also converted
324+
elif hasattr(_arg, "__dict__") and not isinstance:
325+
process_dict_vals, _ = _convert(*_arg.__dict__.values())
326+
for k, v in zip(_arg.__dict__.keys(), process_dict_vals):
327+
try:
328+
setattr(_arg, k, v)
329+
except Exception:
330+
pass
324331
# convert keyworded
325332
if kwargs:
326333
process_kwargs_vals, _ = _convert(_array_to_xp, kwargs.values())

src/mrinufft/extras/field_map.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -270,10 +270,6 @@ def __len__(self) -> int:
270270
"""Get number of interpolators."""
271271
return len(self.C_small)
272272

273-
def __getattr__(self, name):
274-
"""Get other attribute from array."""
275-
return getattr(self.C_small, name)
276-
277273
@property
278274
def shape(self):
279275
"""Overall shape of the lazy array."""

tests/test_offres_exp_approx.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def case_complex2D(self, N=64, b0range=(-300, 300), t2svalue=15.0):
6060
@_param_array_interface_np_cp
6161
@parametrize_with_cases("b0_map, r2s_map, mask", cases=CasesB0maps)
6262
@parametrize("method", ["svd-full", "mti", "mfi"])
63-
@parametrize("L", [40, -1])
64-
def test_b0map_coeff(b0_map, r2s_map, mask, method, L, array_interface):
63+
@parametrize("L, lazy", [(40, True), (-1, True), (40, False)])
64+
def test_b0map_coeff(b0_map, r2s_map, mask, method, L, lazy, array_interface):
6565
"""Test exponential approximation for B0 field only."""
6666
# Generate readout times
6767
Nt = 400
@@ -80,7 +80,7 @@ def test_b0map_coeff(b0_map, r2s_map, mask, method, L, array_interface):
8080
to_interface(tread, array_interface),
8181
to_interface(mask, array_interface),
8282
L=L,
83-
lazy=False,
83+
lazy=lazy,
8484
n_bins=4096,
8585
**kwargs,
8686
)
@@ -92,6 +92,8 @@ def test_b0map_coeff(b0_map, r2s_map, mask, method, L, array_interface):
9292
assert B.shape == (Nt, L)
9393
assert C.shape == (L, *b0_map.shape)
9494

95+
if lazy:
96+
C = np.stack([C[l] for l in range(len(C))])
9597
# Check that the approximation match the full matrix.
9698
B = from_interface(B, array_interface)
9799
C = from_interface(C, array_interface)

0 commit comments

Comments
 (0)