Skip to content

Commit 44ed47b

Browse files
author
Xuye (Chris) Qin
authored
Fix recursive_tile that it may cause duplicated tile for one tileable (#3021)
1 parent 4be160f commit 44ed47b

File tree

3 files changed

+29
-16
lines changed

3 files changed

+29
-16
lines changed

mars/core/entity/tests/test_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_recursive_tile(setup):
4444
d2 = mt.random.rand(10, chunk_size=5)
4545
op = _TestOperand()
4646
t = op.new_tensor([d1, d2], dtype=d1.dtype, shape=(20,), order=d1.order)
47-
t.execute()
47+
t.execute(extra_config={"check_duplicated_operand_keys": True})
4848

4949

5050
class _TestOperandWithDuplicatedSubmission(TensorOperand, TensorOperandMixin):

mars/core/entity/utils.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -76,21 +76,24 @@ def recursive_tile(
7676
q = [t for t in to_tile if t.is_coarse()]
7777
while q:
7878
t = q[-1]
79-
cs = [c for c in t.inputs if c.is_coarse()]
80-
if cs:
81-
q.extend(cs)
82-
continue
83-
for obj in handler.tile(t.op.outputs):
84-
to_update_inputs = []
85-
chunks = []
86-
for inp in t.op.inputs:
87-
chunks.extend(inp.chunks)
88-
if has_unknown_shape(inp):
89-
to_update_inputs.append(inp)
90-
if obj is None:
91-
yield chunks + to_update_inputs
92-
else:
93-
yield obj + to_update_inputs
79+
if t.is_coarse():
80+
# t may be put into q repeatedly,
81+
# so we check if it's tiled or not
82+
cs = [c for c in t.inputs if c.is_coarse()]
83+
if cs:
84+
q.extend(cs)
85+
continue
86+
for obj in handler.tile(t.op.outputs):
87+
to_update_inputs = []
88+
chunks = []
89+
for inp in t.op.inputs:
90+
chunks.extend(inp.chunks)
91+
if has_unknown_shape(inp):
92+
to_update_inputs.append(inp)
93+
if obj is None:
94+
yield chunks + to_update_inputs
95+
else:
96+
yield obj + to_update_inputs
9497
q.pop()
9598

9699
if not return_list:

mars/services/task/supervisor/tests/task_preprocessor.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,9 @@ def __init__(self, *args, **kwargs):
5151
for key in _check_args:
5252
check_options[key] = kwargs.get(key, True)
5353
self._check_options = check_options
54+
self._check_duplicated_operand_keys = bool(
55+
kwargs.get("check_duplicated_operand_keys")
56+
)
5457

5558
def _get_done(self):
5659
return super()._get_done()
@@ -145,6 +148,13 @@ def analyze(
145148
stage_id: str,
146149
op_to_bands: Dict[str, BandType] = None,
147150
) -> SubtaskGraph:
151+
# check if duplicated operand keys exist
152+
if self._check_duplicated_operand_keys and len(
153+
{c.key for c in chunk_graph}
154+
) < len(
155+
chunk_graph
156+
): # pragma: no cover
157+
raise AssertionError("Duplicated operands exist")
148158
# record shapes generated in tile
149159
for n in chunk_graph:
150160
self._raw_chunk_shapes[n.key] = getattr(n, "shape", None)

0 commit comments

Comments
 (0)