@@ -391,7 +391,7 @@ def _masked_array_equal(
391391 array1 : np .ndarray ,
392392 array2 : np .ndarray ,
393393 equal_nan : bool ,
394- ) -> bool :
394+ ) -> np . ndarray :
395395 """Return whether two, possibly masked, arrays are equal."""
396396 mask1 = ma .getmask (array1 )
397397 mask2 = ma .getmask (array2 )
@@ -406,7 +406,9 @@ def _masked_array_equal(
406406 else :
407407 eq = np .array_equal (mask1 , mask2 )
408408
409- if eq :
409+ if not eq :
410+ eqs = np .zeros (array1 .shape , dtype = bool )
411+ else :
410412 # Compare data equality.
411413 if not (mask1 is ma .nomask or mask2 is ma .nomask ):
412414 # Ignore masked data.
@@ -422,50 +424,11 @@ def _masked_array_equal(
422424 else :
423425 ignore |= nanmask
424426
425- # This is faster than using np.array_equal with equal_nan=True.
426427 eqs = ma .getdata (array1 ) == ma .getdata (array2 )
427428 if ignore is not None :
428429 eqs = np .where (ignore , True , eqs )
429- eq = eqs .all ()
430-
431- return eq
432430
433-
434- def _apply_masked_array_equal (
435- blocks1 : list | np .ndarray ,
436- blocks2 : list | np .ndarray ,
437- equal_nan : bool ,
438- ) -> bool :
439- """Return whether two collections of arrays are equal or not.
440-
441- This function is for use with :func:`dask.array.blockwise`.
442-
443- Parameters
444- ----------
445- blocks1 :
446- The collection of arrays representing chunks from the first array. Can
447- be a numpy array or a (nested) list of numpy arrays.
448- blocks2 :
449- The collection of arrays representing chunks from the second array. Can
450- be a numpy array or a (nested) list of numpy arrays.
451- equal_nan :
452- Consder NaN values equal.
453-
454- Returns
455- -------
456- :
457- Whether the two collections are equal or not.
458-
459- """
460- if isinstance (blocks1 , np .ndarray ):
461- eq = _masked_array_equal (blocks1 , blocks2 , equal_nan = equal_nan )
462- else :
463- eq = True
464- for block1 , block2 in zip (blocks1 , blocks2 , strict = True ):
465- eq = _apply_masked_array_equal (block1 , block2 , equal_nan = equal_nan )
466- if not eq :
467- break
468- return eq
431+ return eqs
469432
470433
471434def array_equal (array1 , array2 , withnans : bool = False ) -> bool :
@@ -507,19 +470,22 @@ def normalise_array(array):
507470 eq = array1 .shape == array2 .shape
508471 if eq :
509472 if is_lazy_data (array1 ) or is_lazy_data (array2 ):
473+ # Use a separate map and reduce operation to avoid running out of memory.
474+ ndim = array1 .ndim
475+ indices = tuple (range (ndim ))
510476 eq = da .blockwise (
511- _apply_masked_array_equal ,
512- tuple () ,
477+ _masked_array_equal ,
478+ indices ,
513479 array1 ,
514- tuple ( range ( array1 . ndim )) ,
480+ indices ,
515481 array2 ,
516- tuple ( range ( array2 . ndim )) ,
482+ indices ,
517483 dtype = bool ,
518- meta = np .empty ((0 ,), dtype = bool ),
484+ meta = np .empty ((0 ,) * ndim , dtype = bool ),
519485 equal_nan = withnans ,
520- )
486+ ). all ()
521487 else :
522- eq = _masked_array_equal (array1 , array2 , equal_nan = withnans )
488+ eq = _masked_array_equal (array1 , array2 , equal_nan = withnans ). all ()
523489
524490 return bool (eq )
525491
0 commit comments