2323 UFuncTypeError = None
2424
2525from ...typing import TileableType , ChunkType , OperandType
26- from ...utils import calc_data_size
26+ from ...utils import calc_data_size , tokenize
2727from ..context import Context
2828from ..mode import is_eager_mode
2929from ..entity import (
3030 OutputType ,
31- TILEABLE_TYPE ,
3231 ExecutableTuple ,
3332 get_chunk_types ,
3433 get_tileable_types ,
@@ -46,13 +45,11 @@ class TileableOperandMixin:
4645 def check_inputs (self , inputs : List [TileableType ]):
4746 if not inputs :
4847 return
48+
49+ from ...dataframe .core import DATAFRAME_TYPE
50+
4951 for inp in inputs :
50- if isinstance (inp , TILEABLE_TYPE ):
51- i = inp .extra_params ["_i" ]
52- if not inp .op .output_types :
53- continue
54- if inp .op .output_types [i ] != OutputType .dataframe :
55- continue
52+ if isinstance (inp , DATAFRAME_TYPE ):
5653 dtypes = getattr (inp , "dtypes" , None )
5754 if dtypes is None :
5855 raise ValueError (
@@ -62,24 +59,25 @@ def check_inputs(self, inputs: List[TileableType]):
6259
6360 @classmethod
6461 def _check_if_gpu (cls , inputs : List [TileableType ]):
65- if (
66- inputs is not None
67- and len (
68- [
69- inp
70- for inp in inputs
71- if inp is not None and getattr (inp , "op" , None ) is not None
72- ]
73- )
74- > 0
75- ):
76- if all (inp .op .gpu is True for inp in inputs ):
77- return True
78- elif all (inp .op .gpu is False for inp in inputs ):
79- return False
62+ if not inputs :
63+ return None
64+ true_num = 0
65+ for inp in inputs :
66+ op = getattr (inp , "op" , None )
67+ if op is None or op .gpu is None :
68+ return None
69+ true_num += int (op .gpu )
70+ if true_num == len (inputs ):
71+ return True
72+ elif true_num == 0 :
73+ return False
74+ return None
75+
76+ def _tokenize_output (self , output_idx : int , ** kw ):
77+ return tokenize (self ._key , output_idx )
8078
8179 def _create_chunk (self , output_idx : int , index : Tuple [int ], ** kw ) -> ChunkType :
82- output_type = kw .pop ("output_type" , self ._get_output_type (output_idx ) )
80+ output_type = kw .pop ("output_type" , None ) or self ._get_output_type (output_idx )
8381 if not output_type :
8482 raise ValueError ("output_type should be specified" )
8583
@@ -92,6 +90,11 @@ def _create_chunk(self, output_idx: int, index: Tuple[int], **kw) -> ChunkType:
9290 if output_type == OutputType .scalar :
9391 # tensor
9492 kw ["order" ] = "C_ORDER"
93+
94+ # key of output chunks may only contain keys for its output ids
95+ if "_key" not in kw :
96+ kw ["_key" ] = self ._tokenize_output (output_idx , ** kw )
97+
9598 data = chunk_data_type (** kw )
9699 return chunk_type (data )
97100
@@ -189,6 +192,7 @@ def _create_tileable(self, output_idx: int, **kw) -> TileableType:
189192
190193 if isinstance (output_type , (list , tuple )):
191194 output_type = output_type [output_idx ]
195+
192196 tileable_type , tileable_data_type = get_tileable_types (output_type )
193197 kw ["_i" ] = output_idx
194198 kw ["op" ] = self
@@ -197,6 +201,11 @@ def _create_tileable(self, output_idx: int, **kw) -> TileableType:
197201 kw ["order" ] = "C_ORDER"
198202
199203 kw = self ._fill_nan_shape (kw )
204+
 205+ # key of output tileables may only contain keys for its output ids
206+ if "_key" not in kw :
207+ kw ["_key" ] = self ._tokenize_output (output_idx , ** kw )
208+
200209 data = tileable_data_type (** kw )
201210 return tileable_type (data )
202211
@@ -207,12 +216,11 @@ def _new_tileables(
207216 if output_limit is None :
208217 output_limit = getattr (self , "output_limit" )
209218
210- self .check_inputs (inputs )
211- getattr (self , "_set_inputs" )(inputs )
212- if getattr (self , "gpu" , None ) is None :
219+ self ._set_inputs (inputs )
220+ if self .gpu is None :
213221 self .gpu = self ._check_if_gpu (self ._inputs )
214222 if getattr (self , "_key" , None ) is None :
215- getattr ( self , " _update_key" ) () # update key when inputs are set
223+ self . _update_key () # update key when inputs are set
216224
217225 tileables = []
218226 for j in range (output_limit ):
@@ -222,7 +230,7 @@ def _new_tileables(
222230 tileable = self ._create_tileable (j , ** create_tensor_kw )
223231 tileables .append (tileable )
224232
225- setattr ( self , " outputs" , tileables )
233+ self . outputs = tileables
226234 if len (tileables ) > 1 :
227235 # for each output tileable, hold the reference to the other outputs
228236 # so that either no one or everyone are gc collected
0 commit comments