@@ -242,17 +242,12 @@ def preprocess(self, index_info: IndexInfo, context: IndexHandlerContext) -> Non
242242 index_value = [tileable .index_value , tileable .columns_value ][input_axis ]
243243
244244 # check if chunks have unknown shape
245- check = False
246- if index_value .has_value ():
247- # index_value has value,
248- check = True
249- elif self ._slice_all (index_info .raw_index ):
250- # if slice on all data
251- check = True
252-
253- if check :
254- if any (np .isnan (ns ) for ns in tileable .nsplits [input_axis ]):
255- yield []
245+ if (
246+ not self ._slice_all (index_info .raw_index )
247+ and index_value .has_value ()
248+ and any (np .isnan (ns ) for ns in tileable .nsplits [input_axis ])
249+ ): # pragma: no cover
250+ yield []
256251
257252 def set_chunk_index_info (
258253 cls ,
@@ -297,6 +292,27 @@ def set_chunk_index_info(
297292
298293 chunk_index_info .set (ChunkIndexAxisInfo (** kw ))
299294
295+ def _process_slice_all_index (
296+ self ,
297+ tileable : Tileable ,
298+ index_info : IndexInfo ,
299+ input_axis : int ,
300+ context : IndexHandlerContext ,
301+ ) -> None :
302+ index_to_info = context .chunk_index_to_info .copy ()
303+ for chunk_index , chunk_index_info in index_to_info .items ():
304+ i = chunk_index [input_axis ]
305+ size = tileable .nsplits [input_axis ][i ]
306+ self .set_chunk_index_info (
307+ context ,
308+ index_info ,
309+ chunk_index ,
310+ chunk_index_info ,
311+ i ,
312+ slice (None ),
313+ size ,
314+ )
315+
300316 def _process_has_value_index (
301317 self ,
302318 tileable : Tileable ,
@@ -306,17 +322,14 @@ def _process_has_value_index(
306322 context : IndexHandlerContext ,
307323 ) -> None :
308324 pd_index = index_value .to_pandas ()
309- if self ._slice_all (index_info .raw_index ):
310- slc = slice (None )
311- else :
312- # turn label-based slice into position-based slice
313- start , end = pd_index .slice_locs (
314- index_info .raw_index .start ,
315- index_info .raw_index .stop ,
316- index_info .raw_index .step ,
317- kind = "loc" ,
318- )
319- slc = slice (start , end , index_info .raw_index .step )
325+ # turn label-based slice into position-based slice
326+ start , end = pd_index .slice_locs (
327+ index_info .raw_index .start ,
328+ index_info .raw_index .stop ,
329+ index_info .raw_index .step ,
330+ kind = "loc" ,
331+ )
332+ slc = slice (start , end , index_info .raw_index .step )
320333
321334 cum_nsplit = [0 ] + np .cumsum (tileable .nsplits [index_info .input_axis ]).tolist ()
322335 # split position-based slice into chunk slices
@@ -379,7 +392,9 @@ def process(self, index_info: IndexInfo, context: IndexHandlerContext) -> None:
379392 else :
380393 index_value = [tileable .index_value , tileable .columns_value ][input_axis ]
381394
382- if index_value .has_value () or self ._slice_all (index_info .raw_index ):
395+ if self ._slice_all (index_info .raw_index ):
396+ self ._process_slice_all_index (tileable , index_info , input_axis , context )
397+ elif index_value .has_value ():
383398 self ._process_has_value_index (
384399 tileable , index_info , index_value , input_axis , context
385400 )
@@ -829,10 +844,13 @@ def parse(self, raw_index, context: IndexHandlerContext) -> IndexInfo:
829844 def preprocess (self , index_info : IndexInfo , context : IndexHandlerContext ) -> None :
830845 tileable = context .tileable
831846 op = context .op
832- if has_unknown_shape (tileable ):
833- yield []
834847
835848 input_axis = index_info .input_axis
849+
850+ # check unknown shape
851+ if any (np .isnan (s ) for s in tileable .nsplits [input_axis ]):
852+ yield []
853+
836854 if tileable .ndim == 2 :
837855 index_value = [tileable .index_value , tileable .columns_value ][input_axis ]
838856 else :
0 commit comments