@@ -32,7 +32,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
3232 dtypes : Sequence [DType ] | None = None ,
3333 xp : ModuleType | None = None ,
3434 input_indices : Sequence [Sequence [Hashable ]] | None = None ,
35- core_indices : Sequence [Sequence [ Hashable ] ] | None = None ,
35+ core_indices : Sequence [Hashable ] | None = None ,
3636 output_indices : Sequence [Sequence [Hashable ]] | None = None ,
3737 adjust_chunks : Sequence [dict [Hashable , Callable [[int ], int ]]] | None = None ,
3838 new_axes : Sequence [dict [Hashable , int ]] | None = None ,
@@ -70,9 +70,9 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
7070 ndim=3 and 1, `input_indices` could be ``['ijk', 'j']`` or ``[(0, 1, 2),
7171 (1,)]``.
7272 Default: disallow Dask.
73- core_indices : Sequence[Sequence[ Hashable] ], optional
73+ core_indices : Sequence[Hashable], optional
7474 **Dask specific.**
75- Axes labels of each input array that cannot be broken into chunks.
75+ Axes of the input arrays that cannot be broken into chunks.
7676 Default: disallow Dask.
7777 output_indices : Sequence[Sequence[Hashable]], optional
7878 **Dask specific.**
@@ -144,7 +144,7 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
144144
145145 >>> apply_numpy_func(lambda x: x + x.sum(axis=0), x,
146146 ... input_indices=['ij'], output_indices=['ij'],
147- ... core_indices=[ 'i'] )
147+ ... core_indices='i')
148148
149149 This will cause `apply_numpy_func` to raise if the first axis of `x` is broken
150150 along multiple chunks.
@@ -177,9 +177,6 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
177177 if len (input_indices ) != len (args ):
178178 msg = f"got { len (input_indices )} input_indices and { len (args )} args"
179179 raise ValueError (msg )
180- if len (core_indices ) != len (args ):
181- msg = f"got { len (core_indices )} input_indices and { len (args )} args"
182- raise ValueError (msg )
183180 if len (output_indices ) != len (shapes ):
184181 msg = f"got { len (output_indices )} input_indices and { len (shapes )} shapes"
185182 raise NotImplementedError (msg )
@@ -197,19 +194,9 @@ def apply_numpy_func( # type: ignore[no-any-explicit]
197194 raise ValueError (msg )
198195
199196 # core_indices validation
200- for core_idx , inp_idx , arg in zip (
201- core_indices , input_indices , args , strict = True
202- ):
203- for i in core_idx :
204- try :
205- axis = list (inp_idx ).index (i )
206- except ValueError :
207- msg = (
208- f"Index { i } found in core indices but not in "
209- "matching input_indices"
210- )
211- raise ValueError (msg ) from None
212- if len (arg .chunks [axis ]) > 1 :
197+ for inp_idx , arg in zip (input_indices , args , strict = True ):
198+ for i , chunks in zip (inp_idx , arg .chunks , strict = True ):
199+ if i in core_indices and len (chunks ) > 1 :
213200 msg = f"Core index { i } is broken into multiple chunks"
214201 raise ValueError (msg )
215202
0 commit comments