Skip to content

Commit 0b70425

Browse files
authored
Make Proxima work with latest Mars (#2599)
1 parent 4356fa1 commit 0b70425

File tree

2 files changed

+10
-4
lines changed

2 files changed

+10
-4
lines changed

mars/learn/proxima/core.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,13 @@ def rechunk_tensor(tensor, chunk_size):
6969
if start_chunk_index == end_chunk_index:
7070
t = tensor.chunks[start_chunk_index]
7171
slice_op = TensorSlice(
72-
(
72+
[
7373
slice(
7474
offset - tensor_cumnrows[start_chunk_index],
7575
split + offset - tensor_cumnrows[end_chunk_index],
7676
),
7777
slice(None),
78-
),
78+
],
7979
dtype=t.dtype,
8080
)
8181
out_groups.append(
@@ -93,7 +93,7 @@ def rechunk_tensor(tensor, chunk_size):
9393
start_chunk = tensor.chunks[start_chunk_index]
9494
start_slice = int(offset - tensor_cumnrows[start_chunk_index])
9595
slice_op = TensorSlice(
96-
(slice(start_slice, None), slice(None)), dtype=start_chunk.dtype
96+
[slice(start_slice, None), slice(None)], dtype=start_chunk.dtype
9797
)
9898
chunks.append(
9999
slice_op.new_chunk(
@@ -107,7 +107,7 @@ def rechunk_tensor(tensor, chunk_size):
107107
end_chunk = tensor.chunks[end_chunk_index]
108108
end_slice = int(split + offset - tensor_cumnrows[end_chunk_index])
109109
slice_op_end = TensorSlice(
110-
(slice(None, end_slice), slice(None)), dtype=start_chunk.dtype
110+
[slice(None, end_slice), slice(None)], dtype=start_chunk.dtype
111111
)
112112
chunks.append(
113113
slice_op_end.new_chunk(

mars/services/task/analyzer/analyzer.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,12 @@ def gen_subtask_graph(self) -> SubtaskGraph:
303303
"Assigned %s start chunks for task %s", len(start_ops), self._task.task_id
304304
)
305305

306+
# assign expect workers for those specified with `expect_worker`
307+
# skip `start_ops`, which have been assigned before
308+
for chunk in self._chunk_graph:
309+
if chunk not in start_ops and chunk.op.expect_worker is not None:
310+
chunk_to_bands[chunk] = self._to_band(chunk.op.expect_worker)
311+
306312
# fuse node
307313
if self._fuse_enabled:
308314
logger.debug("Start to fuse chunks for task %s", self._task.task_id)

0 commit comments

Comments
 (0)