@@ -279,17 +279,19 @@ def test_where_one_param(array_fixture):
279279 res = np .sort (res )
280280 nres = np .sort (nres )
281281 np .testing .assert_allclose (res [:], nres )
282+
282283 # Test with getitem
283284 sl = slice (100 )
284285 res = expr .where (a1 )[sl ]
286+ nres = na1 [sl ][ne_evaluate ("na1**2 + na2**2 > 2 * na1 * na2 + 1" )[sl ]]
285287 if len (a1 .shape ) == 1 or a1 .chunks == a1 .shape :
286288 # TODO: fix this, as it seems that is not working well for numexpr?
287289 if blosc2 .IS_WASM :
288290 return
289- np .testing .assert_allclose (res , nres [ sl ] )
291+ np .testing .assert_allclose (res , nres )
290292 else :
291293 # In this case, we cannot compare results, only the length
292- assert len (res ) == len (nres [ sl ] )
294+ assert len (res ) == len (nres )
293295
294296
295297# Test where indirectly via a condition in getitem in a NDArray
@@ -330,25 +332,26 @@ def test_where_getitem(array_fixture):
330332 # Test with partial slice
331333 sl = slice (100 )
332334 res = sa1 [a1 ** 2 + a2 ** 2 > 2 * a1 * a2 + 1 ][sl ]
335+ nres = nsa1 [sl ][ne_evaluate ("na1**2 + na2**2 > 2 * na1 * na2 + 1" )[sl ]]
333336 if len (a1 .shape ) == 1 or a1 .chunks == a1 .shape :
334337 # TODO: fix this, as it seems that is not working well for numexpr?
335338 if blosc2 .IS_WASM :
336339 return
337- np .testing .assert_allclose (res ["a" ], nres [sl ][ "a" ])
338- np .testing .assert_allclose (res ["b" ], nres [sl ][ "b" ])
340+ np .testing .assert_allclose (res ["a" ], nres ["a" ])
341+ np .testing .assert_allclose (res ["b" ], nres ["b" ])
339342 else :
340343 # In this case, we cannot compare results, only the length
341- assert len (res ["a" ]) == len (nres [sl ][ "a" ])
342- assert len (res ["b" ]) == len (nres [sl ][ "b" ])
344+ assert len (res ["a" ]) == len (nres ["a" ])
345+ assert len (res ["b" ]) == len (nres ["b" ])
343346 # string version
344347 res = sa1 ["a**2 + b**2 > 2 * a * b + 1" ][sl ]
345348 if len (a1 .shape ) == 1 or a1 .chunks == a1 .shape :
346- np .testing .assert_allclose (res ["a" ], nres [sl ][ "a" ])
347- np .testing .assert_allclose (res ["b" ], nres [sl ][ "b" ])
349+ np .testing .assert_allclose (res ["a" ], nres ["a" ])
350+ np .testing .assert_allclose (res ["b" ], nres ["b" ])
348351 else :
349352 # We cannot compare the results here, other than the length
350- assert len (res ["a" ]) == len (nres [sl ][ "a" ])
351- assert len (res ["b" ]) == len (nres [sl ][ "b" ])
353+ assert len (res ["a" ]) == len (nres ["a" ])
354+ assert len (res ["b" ]) == len (nres ["b" ])
352355
353356
354357# Test where indirectly via a condition in getitem in a NDField
@@ -631,3 +634,40 @@ def test_col_reduction(reduce_op):
631634 ns = nreduc (nC [nC > 0 ])
632635 np .testing .assert_allclose (s , ns )
633636 np .testing .assert_allclose (s2 , ns )
637+
638+
639+ def test_fields_indexing ():
640+ N = 1000
641+ it = ((- x + 1 , x - 2 , 0.1 * x ) for x in range (N ))
642+ sa = blosc2 .fromiter (
643+ it , dtype = [("A" , "i4" ), ("B" , "f4" ), ("C" , "f8" )], shape = (N ,), urlpath = "sa-1M.b2nd" , mode = "w"
644+ )
645+ expr = sa ["(A < B)" ]
646+ A = sa ["A" ][:]
647+ B = sa ["B" ][:]
648+ C = sa ["C" ][:]
649+ temp = sa [:]
650+ indices = A < B
651+ idx = np .argmax (indices )
652+
653+ # Returns less than 10 elements in general
654+ sliced = expr .compute (slice (0 , 10 ))
655+ gotitem = expr [:10 ]
656+ np .testing .assert_array_equal (sliced [:], gotitem )
657+ np .testing .assert_array_equal (gotitem , temp [:10 ][indices [:10 ]])
658+ # Actually this makes sense since one can understand this as a request to compute on a portion of operands.
659+ # If one desires a portion of the result, one should compute the whole expression and then slice it.
660+ # For a general slice it is quite difficult to simply stop when the desired slice has been obtained. Or
661+ # to try to optimise chunk computation order.
662+
663+ # Get first true element
664+ sliced = expr .compute (idx )
665+ gotitem = expr [idx ]
666+ np .testing .assert_array_equal (sliced [()], gotitem )
667+ np .testing .assert_array_equal (gotitem , temp [idx ])
668+
669+ # Should return void arrays here.
670+ sliced = expr .compute (0 ) # typically gives array of zeros
671+ gotitem = expr [0 ] # gives an error
672+ np .testing .assert_array_equal (sliced [()], gotitem )
673+ np .testing .assert_array_equal (gotitem , temp [0 ])
0 commit comments