Commit 3949506

[Operand] support loc setitem (#3291)
* support loc setitem
* check column splits
* fix iloc indexes
* refine test_ownership_when_scale_in ut
* fix process_loc_indexes
* lint code
* fix loc to iloc
* add loc row index test
1 parent 9d9cc6e commit 3949506
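
From the user side, the commit enables scalar assignment through DataFrame.loc on a chunked Mars DataFrame. A minimal sketch of the newly supported patterns, mirroring the test added in this commit (the chunk size and the explicit mars.new_session() call are illustrative, not part of the change):

    import mars
    import mars.dataframe as md
    import pandas as pd

    mars.new_session()  # assumed local session; the test suite uses its own setup fixture

    raw = pd.DataFrame({"a": [1, 2, 3, 4, 2, 4, 5, 7, 2, 8, 9], 1: [10] * 11})

    df = md.DataFrame(raw, chunk_size=3)
    df.loc[df["a"] <= 4, 1] = "v1"      # boolean-mask row indexer (a tileable), scalar value
    expected = raw.copy(True)
    expected.loc[expected["a"] <= 4, 1] = "v1"
    pd.testing.assert_frame_equal(df.to_pandas(), expected)

    df2 = md.DataFrame(raw, chunk_size=3)
    df2.loc[1:3, 1] = "v2"              # label slice plus column label; may be delegated to the iloc path
    expected2 = raw.copy(True)
    expected2.loc[1:3, 1] = "v2"
    pd.testing.assert_frame_equal(df2.to_pandas(), expected2)

Only scalar values are accepted on the right-hand side; non-scalar values and non-DataFrame targets still raise NotImplementedError, as the new __setitem__ below shows.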

File tree

4 files changed: +151 −10 lines

  mars/dataframe/indexing/loc.py
  mars/dataframe/indexing/tests/test_indexing_execution.py
  mars/deploy/oscar/tests/test_ray_scheduling.py
  mars/services/scheduling/supervisor/autoscale.py

mars/dataframe/indexing/loc.py

Lines changed: 125 additions & 5 deletions
@@ -20,23 +20,24 @@
 from pandas.core.dtypes.cast import find_common_type
 from pandas.core.indexing import IndexingError
 
+from .iloc import DataFrameIlocSetItem
 from ... import opcodes as OperandDef
-from ...core import ENTITY_TYPE
+from ...core import ENTITY_TYPE, OutputType
 from ...core.operand import OperandStage
-from ...serialization.serializables import KeyField, ListField
+from ...serialization.serializables import KeyField, ListField, AnyField
 from ...tensor.datasource import asarray
 from ...tensor.utils import calc_sliced_size, filter_inputs
 from ...utils import lazy_import, is_full_slice
 from ..core import IndexValue, DATAFRAME_TYPE
 from ..operands import DataFrameOperand, DataFrameOperandMixin
-from ..utils import parse_index
+from ..utils import parse_index, is_index_value_identical
 from .index_lib import DataFrameLocIndexesHandler
 
 
 cudf = lazy_import("cudf")
 
 
-def process_loc_indexes(inp, indexes):
+def process_loc_indexes(inp, indexes, fetch_index: bool = True):
     ndim = inp.ndim
 
     if not isinstance(indexes, tuple):
@@ -51,7 +52,7 @@ def process_loc_indexes(inp, indexes):
         if isinstance(index, (list, np.ndarray, pd.Series, ENTITY_TYPE)):
             if not isinstance(index, ENTITY_TYPE):
                 index = np.asarray(index)
-            else:
+            elif fetch_index:
                 index = asarray(index)
                 if ax == 1:
                     # do not support tensor index on axis 1
@@ -116,6 +117,125 @@ def __getitem__(self, indexes):
         op = DataFrameLocGetItem(indexes=indexes)
         return op(self._obj)
 
+    def __setitem__(self, indexes, value):
+        if not np.isscalar(value):
+            raise NotImplementedError("Only scalar value is supported to set by loc")
+        if not isinstance(self._obj, DATAFRAME_TYPE):
+            raise NotImplementedError("Only DataFrame is supported to set by loc")
+        indexes = process_loc_indexes(self._obj, indexes, fetch_index=False)
+        use_iloc, new_indexes = self._use_iloc(indexes)
+        if use_iloc:
+            op = DataFrameIlocSetItem(indexes=new_indexes, value=value)
+            ret = op(self._obj)
+            self._obj.data = ret.data
+        else:
+            other_indices = []
+            indices_tileable = [
+                idx
+                for idx in indexes
+                if isinstance(idx, ENTITY_TYPE) or other_indices.append(idx)
+            ]
+            op = DataFramelocSetItem(indexes=other_indices, value=value)
+            ret = op([self._obj] + indices_tileable)
+            self._obj.data = ret.data
+
+
+class DataFramelocSetItem(DataFrameOperand, DataFrameOperandMixin):
+    _op_type_ = OperandDef.DATAFRAME_ILOC_SETITEM
+
+    _indexes = ListField("indexes")
+    _value = AnyField("value")
+
+    def __init__(
+        self, indexes=None, value=None, gpu=None, sparse=False, output_types=None, **kw
+    ):
+        super().__init__(
+            _indexes=indexes,
+            _value=value,
+            gpu=gpu,
+            sparse=sparse,
+            _output_types=output_types,
+            **kw,
+        )
+        if not self.output_types:
+            self.output_types = [OutputType.dataframe]
+
+    @property
+    def indexes(self):
+        return self._indexes
+
+    @property
+    def value(self):
+        return self._value
+
+    def __call__(self, inputs):
+        df = inputs[0]
+        return self.new_dataframe(
+            inputs,
+            shape=df.shape,
+            dtypes=df.dtypes,
+            index_value=df.index_value,
+            columns_value=df.columns_value,
+        )
+
+    @classmethod
+    def tile(cls, op):
+        in_df = op.inputs[0]
+        out_df = op.outputs[0]
+        out_chunks = []
+        if len(op.inputs) > 1:
+            index_series = op.inputs[1]
+            is_identical = is_index_value_identical(in_df, index_series)
+            if not is_identical:
+                raise NotImplementedError("Only identical index value is supported")
+            if len(in_df.nsplits[1]) != 1:
+                raise NotImplementedError("Column-split chunks are not supported")
+            for target_chunk, index_chunk in zip(in_df.chunks, index_series.chunks):
+                chunk_op = op.copy().reset_key()
+                out_chunk = chunk_op.new_chunk(
+                    [target_chunk, index_chunk],
+                    shape=target_chunk.shape,
+                    index=target_chunk.index,
+                    dtypes=target_chunk.dtypes,
+                    index_value=target_chunk.index_value,
+                    columns_value=target_chunk.columns_value,
+                )
+                out_chunks.append(out_chunk)
+        else:
+            for target_chunk in in_df.chunks:
+                chunk_op = op.copy().reset_key()
+                out_chunk = chunk_op.new_chunk(
+                    [target_chunk],
+                    shape=target_chunk.shape,
+                    index=target_chunk.index,
+                    dtypes=target_chunk.dtypes,
+                    index_value=target_chunk.index_value,
+                    columns_value=target_chunk.columns_value,
+                )
+                out_chunks.append(out_chunk)
+
+        new_op = op.copy()
+        return new_op.new_dataframes(
+            op.inputs,
+            shape=out_df.shape,
+            dtypes=out_df.dtypes,
+            index_value=out_df.index_value,
+            columns_value=out_df.columns_value,
+            chunks=out_chunks,
+            nsplits=in_df.nsplits,
+        )
+
+    @classmethod
+    def execute(cls, ctx, op):
+        chunk = op.outputs[0]
+        r = ctx[op.inputs[0].key].copy(deep=True)
+        if len(op.inputs) > 1:
+            row_index = ctx[op.inputs[1].key]
+            r.loc[(row_index,) + tuple(op.indexes)] = op.value
+        else:
+            r.loc[tuple(op.indexes)] = op.value
+        ctx[chunk.key] = r
+
 
 class DataFrameLocGetItem(DataFrameOperand, DataFrameOperandMixin):
     _op_type_ = OperandDef.DATAFRAME_LOC_GETITEM
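
At execution time each chunk is handled with plain pandas: the worker deep-copies its chunk, combines the aligned row-index chunk (when a tileable row indexer was supplied) with the remaining label indexers kept on the operand, and performs an ordinary .loc assignment. A pandas-only sketch of that per-chunk step (variable names and values are illustrative):

    import pandas as pd

    chunk_data = pd.DataFrame({"a": [1, 5, 2], 1: [10, 10, 10]})
    row_mask = pd.Series([True, False, True])   # aligned boolean index chunk
    remaining_indexes = (1,)                     # label indexers stored on the op
    value = -1                                   # scalar right-hand side

    result = chunk_data.copy(deep=True)          # the input chunk is never mutated in place
    result.loc[(row_mask,) + tuple(remaining_indexes)] = value

Because every chunk is copied and assigned in place of itself, the output chunk keeps the input chunk's shape and index metadata, which is why tile() can reuse the input's nsplits directly.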

mars/dataframe/indexing/tests/test_indexing_execution.py

Lines changed: 21 additions & 0 deletions
@@ -1636,6 +1636,27 @@ def test_sample_execution(setup):
     pd.testing.assert_series_equal(r1.execute().fetch(), r2.execute().fetch())
 
 
+def test_loc_setitem(setup):
+    raw_df = pd.DataFrame({"a": [1, 2, 3, 4, 2, 4, 5, 7, 2, 8, 9], 1: [10] * 11})
+    md_data = md.DataFrame(raw_df, chunk_size=3)
+    md_data.loc[md_data["a"] <= 4, 1] = "v1"
+    pd_data = raw_df.copy(True)
+    pd_data.loc[pd_data["a"] <= 4, 1] = "v1"
+    pd.testing.assert_frame_equal(md_data.to_pandas(), pd_data)
+
+    md_data1 = md.DataFrame(raw_df, chunk_size=3)
+    md_data1.loc[1:3] = "v2"
+    pd_data1 = raw_df.copy(True)
+    pd_data1.loc[1:3] = "v2"
+    pd.testing.assert_frame_equal(md_data1.to_pandas(), pd_data1)
+
+    md_data2 = md.DataFrame(raw_df, chunk_size=3)
+    md_data2.loc[1:3, 1] = "v2"
+    pd_data2 = raw_df.copy(True)
+    pd_data2.loc[1:3, 1] = "v2"
+    pd.testing.assert_frame_equal(md_data2.to_pandas(), pd_data2)
+
+
 def test_add_prefix_suffix(setup):
     rs = np.random.RandomState(0)
     raw = pd.DataFrame(rs.rand(10, 4), columns=["A", "B", "C", "D"])

mars/deploy/oscar/tests/test_ray_scheduling.py

Lines changed: 4 additions & 4 deletions
@@ -247,9 +247,9 @@ async def test_ownership_when_scale_in(ray_large_cluster):
         supervisor_mem=200 * 1024**2,
         config={
             "scheduling.autoscale.enabled": True,
-            "scheduling.autoscale.scheduler_check_interval": 1,
-            "scheduling.autoscale.scheduler_backlog_timeout": 1,
-            "scheduling.autoscale.worker_idle_timeout": 10,
+            "scheduling.autoscale.scheduler_check_interval": 0.1,
+            "scheduling.autoscale.scheduler_backlog_timeout": 0.5,
+            "scheduling.autoscale.worker_idle_timeout": 1,
             "scheduling.autoscale.min_workers": 1,
             "scheduling.autoscale.max_workers": 4,
         },
@@ -259,7 +259,7 @@ async def test_ownership_when_scale_in(ray_large_cluster):
         uid=AutoscalerActor.default_uid(),
         address=client._cluster.supervisor_address,
     )
-    num_chunks, chunk_size = 20, 4
+    num_chunks, chunk_size = 10, 4
     df = md.DataFrame(
         mt.random.rand(num_chunks * chunk_size, 4, chunk_size=chunk_size),
         columns=list("abcd"),

mars/services/scheduling/supervisor/autoscale.py

Lines changed: 1 addition & 1 deletion
@@ -82,7 +82,7 @@ async def request_worker(
         )
         if worker_address:
             self._dynamic_workers.add(worker_address)
-            logger.info(
+            logger.warning(
                 "Requested new worker %s in %.4f seconds, current dynamic worker nums is %s",
                 worker_address,
                 time.time() - start_time,
