Skip to content

Commit 4dabb97

Browse files
authored
Refine ThreadedServiceContext.get_chunks_meta usage (#3037)
1 parent de933b9 commit 4dabb97

File tree

6 files changed

+9
-34
lines changed

6 files changed

+9
-34
lines changed

mars/conftest.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from mars.config import option_context
2424
from mars.core.mode import is_kernel_mode, is_build_mode
25-
from mars.lib.aio import stop_isolation
2625
from mars.oscar.backends.router import Router
2726
from mars.oscar.backends.ray.communication import RayServer
2827
from mars.serialization.ray import register_ray_serializers, unregister_ray_serializers
@@ -160,13 +159,7 @@ async def ray_create_mars_cluster(request):
160159

161160

162161
@pytest.fixture(scope="module")
163-
def _stop_isolation():
164-
yield
165-
stop_isolation()
166-
167-
168-
@pytest.fixture(scope="module")
169-
def _new_test_session(_stop_isolation):
162+
def _new_test_session():
170163
from .deploy.oscar.tests.session import new_test_session
171164

172165
sess = new_test_session(
@@ -184,7 +177,7 @@ def _new_test_session(_stop_isolation):
184177

185178

186179
@pytest.fixture(scope="module")
187-
def _new_integrated_test_session(_stop_isolation):
180+
def _new_integrated_test_session():
188181
from .deploy.oscar.tests.session import new_test_session
189182

190183
sess = new_test_session(
@@ -217,7 +210,7 @@ def _new_integrated_test_session(_stop_isolation):
217210

218211

219212
@pytest.fixture(scope="module")
220-
def _new_gpu_test_session(_stop_isolation): # pragma: no cover
213+
def _new_gpu_test_session(): # pragma: no cover
221214
from .deploy.oscar.tests.session import new_test_session
222215
from .resource import cuda_count
223216

mars/dataframe/base/standardize_range_index.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,13 @@ class ChunkStandardizeRangeIndex(DataFrameOperand, DataFrameOperandMixin):
2929
_op_type_ = OperandDef.STANDARDIZE_RANGE_INDEX
3030

3131
axis = Int32Field("axis")
32-
prev_keys = ListField("prev_keys", FieldTypes.string)
32+
prev_shapes = ListField("prev_shapes", FieldTypes.tuple)
3333

3434
@classmethod
3535
def execute(cls, ctx, op: "ChunkStandardizeRangeIndex"):
3636
xdf = cudf if op.gpu else pd
3737
in_data = ctx[op.inputs[0].key].copy()
38-
metas = ctx.get_chunks_meta(op.prev_keys, fields=["shape"])
39-
index_start = sum([m["shape"][op.axis] for m in metas])
38+
index_start = sum([shape[op.axis] for shape in op.prev_shapes])
4039
if op.axis == 0:
4140
in_data.index = xdf.RangeIndex(index_start, index_start + len(in_data))
4241
else:

mars/dataframe/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1086,7 +1086,7 @@ def standardize_range_index(chunks: List[ChunkType], axis: int = 0):
10861086
for c in chunks:
10871087
prev_chunks = row_chunks[: c.index[axis]]
10881088
op = ChunkStandardizeRangeIndex(
1089-
prev_keys=[p.key for p in prev_chunks], axis=axis
1089+
prev_shapes=[p.shape for p in prev_chunks], axis=axis
10901090
)
10911091
op.output_types = c.op.output_types
10921092
params = c.params.copy()

mars/remote/core.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from functools import partial
1717

1818
from .. import opcodes
19-
from ..core import ENTITY_TYPE, TILEABLE_TYPE, ChunkData
19+
from ..core import ENTITY_TYPE, ChunkData
2020
from ..core.custom_log import redirect_custom_log
2121
from ..core.operand import ObjectOperand
2222
from ..dataframe.core import DATAFRAME_TYPE, SERIES_TYPE, INDEX_TYPE
@@ -33,7 +33,6 @@
3333
enter_current_session,
3434
find_objects,
3535
replace_objects,
36-
get_chunk_params,
3736
)
3837
from .operands import RemoteOperandMixin
3938

@@ -185,17 +184,6 @@ def execute(cls, ctx, op: "RemoteFunction"):
185184
for inp, is_pure_dep in zip(op.inputs, op.pure_depends)
186185
if not is_pure_dep
187186
}
188-
for to_search in [op.function_args, op.function_kwargs]:
189-
tileables = find_objects(to_search, TILEABLE_TYPE)
190-
for tileable in tileables:
191-
chunks = tileable.chunks
192-
fields = get_chunk_params(chunks[0]).keys()
193-
metas = ctx.get_chunks_meta(
194-
[chunk.key for chunk in chunks], fields=fields
195-
)
196-
for chunk, meta in zip(chunks, metas):
197-
chunk.params = {field: meta[field] for field in fields}
198-
tileable.refresh_params()
199187

200188
function = op.function
201189
function_args = replace_objects(op.function_args, mapping)

mars/tensor/base/shape.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -94,13 +94,7 @@ def tile(cls, op):
9494

9595
@classmethod
9696
def execute(cls, ctx, op):
97-
chunk_idx_to_chunk_shapes = {
98-
c.index: cm["shape"]
99-
for c, cm in zip(
100-
op.inputs,
101-
ctx.get_chunks_meta([c.key for c in op.inputs], fields=["shape"]),
102-
)
103-
}
97+
chunk_idx_to_chunk_shapes = dict((c.index, c.shape) for c in op.inputs)
10498
nsplits = calc_nsplits(chunk_idx_to_chunk_shapes)
10599
shape = tuple(sum(ns) for ns in nsplits)
106100
for o, s in zip(op.outputs, shape):

mars/tensor/base/tests/test_base_execution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,6 +1755,7 @@ def test_trapz_execution(setup):
17551755
np.testing.assert_almost_equal(result, expected)
17561756

17571757

1758+
@pytest.mark.ray_dag
17581759
def test_shape(setup):
17591760
raw = np.random.RandomState(0).rand(4, 3)
17601761
x = mt.tensor(raw, chunk_size=2)

0 commit comments

Comments
 (0)