2020from pandas .core .dtypes .cast import find_common_type
2121from pandas .core .indexing import IndexingError
2222
23+ from .iloc import DataFrameIlocSetItem
2324from ... import opcodes as OperandDef
24- from ...core import ENTITY_TYPE
25+ from ...core import ENTITY_TYPE , OutputType
2526from ...core .operand import OperandStage
26- from ...serialization .serializables import KeyField , ListField
27+ from ...serialization .serializables import KeyField , ListField , AnyField
2728from ...tensor .datasource import asarray
2829from ...tensor .utils import calc_sliced_size , filter_inputs
2930from ...utils import lazy_import , is_full_slice
3031from ..core import IndexValue , DATAFRAME_TYPE
3132from ..operands import DataFrameOperand , DataFrameOperandMixin
32- from ..utils import parse_index
33+ from ..utils import parse_index , is_index_value_identical
3334from .index_lib import DataFrameLocIndexesHandler
3435
3536
3637cudf = lazy_import ("cudf" )
3738
3839
39- def process_loc_indexes (inp , indexes ):
40+ def process_loc_indexes (inp , indexes , fetch_index : bool = True ):
4041 ndim = inp .ndim
4142
4243 if not isinstance (indexes , tuple ):
@@ -51,7 +52,7 @@ def process_loc_indexes(inp, indexes):
5152 if isinstance (index , (list , np .ndarray , pd .Series , ENTITY_TYPE )):
5253 if not isinstance (index , ENTITY_TYPE ):
5354 index = np .asarray (index )
54- else :
55+ elif fetch_index :
5556 index = asarray (index )
5657 if ax == 1 :
5758 # do not support tensor index on axis 1
@@ -116,6 +117,125 @@ def __getitem__(self, indexes):
116117 op = DataFrameLocGetItem (indexes = indexes )
117118 return op (self ._obj )
118119
def __setitem__(self, indexes, value):
    """Assign a scalar ``value`` to the ``.loc`` selection ``indexes``.

    The assignment is expressed as a new operand applied to the wrapped
    object, and ``self._obj.data`` is rebound to that operand's result so
    the wrapped object reflects the update.

    Parameters
    ----------
    indexes
        The ``.loc`` key; normalized through ``process_loc_indexes``.
    value
        The value to assign. Must satisfy ``np.isscalar``.

    Raises
    ------
    NotImplementedError
        If ``value`` is not a scalar, or if the wrapped object is not a
        ``DATAFRAME_TYPE``.
    """
    if not np.isscalar(value):
        raise NotImplementedError("Only scalar value is supported to set by loc")
    if not isinstance(self._obj, DATAFRAME_TYPE):
        raise NotImplementedError("Only DataFrame is supported to set by loc")

    # fetch_index=False: entity-typed indexes are not passed through
    # ``asarray`` by process_loc_indexes, so they stay as they were given.
    indexes = process_loc_indexes(self._obj, indexes, fetch_index=False)
    use_iloc, new_indexes = self._use_iloc(indexes)
    if use_iloc:
        # The selection can be expressed positionally; reuse the iloc operand.
        op = DataFrameIlocSetItem(indexes=new_indexes, value=value)
        ret = op(self._obj)
    else:
        # Entity indexes become extra operand inputs; every other index is
        # stored on the operand itself. Relative order is kept in both lists.
        indices_tileable = []
        other_indices = []
        for idx in indexes:
            if isinstance(idx, ENTITY_TYPE):
                indices_tileable.append(idx)
            else:
                other_indices.append(idx)
        op = DataFramelocSetItem(indexes=other_indices, value=value)
        ret = op([self._obj] + indices_tileable)

    # Rebind the wrapped object's data to the result of the set-item operand.
    self._obj.data = ret.data
141+
142+
class DataFramelocSetItem(DataFrameOperand, DataFrameOperandMixin):
    """Operand that assigns a value to a ``.loc`` selection of a DataFrame.

    The first input is the target DataFrame. An optional second input is an
    entity used as the row indexer; the remaining (non-entity) indexers are
    stored in ``indexes``.
    """

    # NOTE(review): this reuses the DATAFRAME_ILOC_SETITEM opcode although the
    # class implements loc-based assignment -- confirm whether a dedicated
    # opcode is intended.
    _op_type_ = OperandDef.DATAFRAME_ILOC_SETITEM

    _indexes = ListField("indexes")
    _value = AnyField("value")

    def __init__(
        self,
        indexes=None,
        value=None,
        gpu=None,
        sparse=False,
        output_types=None,
        **kw,
    ):
        super().__init__(
            _indexes=indexes,
            _value=value,
            gpu=gpu,
            sparse=sparse,
            _output_types=output_types,
            **kw,
        )
        # Default to a DataFrame output when no output types were supplied.
        if not self.output_types:
            self.output_types = [OutputType.dataframe]

    @property
    def indexes(self):
        """The indexers stored on this operand."""
        return self._indexes

    @property
    def value(self):
        """The value to assign to the selection."""
        return self._value

    def __call__(self, inputs):
        """Build the output DataFrame, copying the metadata of ``inputs[0]``."""
        target = inputs[0]
        return self.new_dataframe(
            inputs,
            shape=target.shape,
            dtypes=target.dtypes,
            index_value=target.index_value,
            columns_value=target.columns_value,
        )

    @classmethod
    def tile(cls, op):
        """Create one output chunk per input chunk, keeping chunk metadata.

        Raises
        ------
        NotImplementedError
            When an index input is present and either its index value is not
            identical to the DataFrame's, or the DataFrame has more than one
            split along axis 1.
        """
        in_df = op.inputs[0]
        out_df = op.outputs[0]

        if len(op.inputs) > 1:
            index_series = op.inputs[1]
            if not is_index_value_identical(in_df, index_series):
                raise NotImplementedError("Only identical index value is supported")
            if len(in_df.nsplits[1]) != 1:
                raise NotImplementedError("Column-split chunks are not supported")
            # Pair each DataFrame chunk with the index chunk at the same position.
            per_chunk_inputs = [
                [target_chunk, index_chunk]
                for target_chunk, index_chunk in zip(
                    in_df.chunks, index_series.chunks
                )
            ]
        else:
            per_chunk_inputs = [[target_chunk] for target_chunk in in_df.chunks]

        out_chunks = []
        for chunk_inputs in per_chunk_inputs:
            # The output chunk takes its metadata from the DataFrame chunk.
            target_chunk = chunk_inputs[0]
            chunk_op = op.copy().reset_key()
            out_chunks.append(
                chunk_op.new_chunk(
                    chunk_inputs,
                    shape=target_chunk.shape,
                    index=target_chunk.index,
                    dtypes=target_chunk.dtypes,
                    index_value=target_chunk.index_value,
                    columns_value=target_chunk.columns_value,
                )
            )

        new_op = op.copy()
        return new_op.new_dataframes(
            op.inputs,
            shape=out_df.shape,
            dtypes=out_df.dtypes,
            index_value=out_df.index_value,
            columns_value=out_df.columns_value,
            chunks=out_chunks,
            nsplits=in_df.nsplits,
        )

    @classmethod
    def execute(cls, ctx, op):
        """Apply the assignment to a deep copy of the input chunk's data."""
        out_chunk = op.outputs[0]
        # Assign into a deep copy rather than the object stored under the input key.
        result = ctx[op.inputs[0].key].copy(deep=True)
        if len(op.inputs) > 1:
            # The extra input is used as the leading indexer.
            row_index = ctx[op.inputs[1].key]
            loc_key = (row_index,) + tuple(op.indexes)
        else:
            loc_key = tuple(op.indexes)
        result.loc[loc_key] = op.value
        ctx[out_chunk.key] = result
238+
119239
120240class DataFrameLocGetItem (DataFrameOperand , DataFrameOperandMixin ):
121241 _op_type_ = OperandDef .DATAFRAME_LOC_GETITEM
0 commit comments