Skip to content

Commit aa23fd0

Browse files
authored
[Ray] Implements get_chunks_meta for Ray execution context (#3052)
1 parent bc37acf commit aa23fd0

File tree

6 files changed

+159
-61
lines changed

6 files changed

+159
-61
lines changed

mars/dataframe/base/tests/test_base_execution.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -682,6 +682,7 @@ def test_datetime_method_execution(setup):
682682
pd.testing.assert_series_equal(result, expected)
683683

684684

685+
@pytest.mark.ray_dag
685686
def test_isin_execution(setup):
686687
# one chunk in multiple chunks
687688
a = pd.Series([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

mars/services/task/execution/ray/context.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
import functools
1616
import inspect
1717
import logging
18+
from dataclasses import asdict
1819
from typing import Union, Dict, List
1920

2021
from .....core.context import Context
22+
from .....storage.base import StorageLevel
2123
from .....utils import implements, lazy_import
2224
from ....context import ThreadedServiceContext
2325

@@ -116,9 +118,10 @@ def destroy_remote_object(self, name: str):
116118
class RayExecutionContext(_RayRemoteObjectContext, ThreadedServiceContext):
117119
"""The context for tiling."""
118120

119-
def __init__(self, task_context: Dict, *args, **kwargs):
121+
def __init__(self, task_context: Dict, task_chunks_meta: Dict, *args, **kwargs):
120122
super().__init__(*args, **kwargs)
121123
self._task_context = task_context
124+
self._task_chunks_meta = task_chunks_meta
122125

123126
@implements(Context.get_chunks_result)
124127
def get_chunks_result(self, data_keys: List[str]) -> List:
@@ -128,11 +131,60 @@ def get_chunks_result(self, data_keys: List[str]) -> List:
128131
logger.info("Got %s chunks result.", len(result))
129132
return result
130133

134+
@implements(Context.get_chunks_meta)
135+
def get_chunks_meta(
    self, data_keys: List[str], fields: List[str] = None, error="raise"
) -> List[Dict]:
    """Return the locally recorded meta for each chunk in ``data_keys``.

    Metas are read from ``self._task_chunks_meta`` and converted to plain
    dicts with ``dataclasses.asdict``.

    Parameters
    ----------
    data_keys
        Keys of the chunks to look up.
    fields
        Names of the meta fields to return. A requested field that the
        recorded meta does not have is reported as ``None``. When ``fields``
        is ``None`` (the default), every recorded field is returned.
    error
        Accepted for signature compatibility; it is not consulted here.

    Returns
    -------
    List[Dict]
        One dict per key, in the same order as ``data_keys``.

    Raises
    ------
    KeyError
        If a key has no entry in ``self._task_chunks_meta``.
    """
    result = []
    # TODO(fyrestone): Support get_chunks_meta from meta service if needed.
    for key in data_keys:
        meta = asdict(self._task_chunks_meta[key])
        if fields is not None:
            # Project onto the requested fields; iterating the default
            # ``None`` directly would raise TypeError.
            meta = {f: meta.get(f) for f in fields}
        result.append(meta)
    return result
146+
131147

132148
# TODO(fyrestone): Implement more APIs for Ray.
133149
class RayExecutionWorkerContext(_RayRemoteObjectContext, dict):
134150
"""The context for executing operands."""
135151

136-
@staticmethod
137-
def new_custom_log_dir():
152+
@classmethod
153+
@implements(Context.new_custom_log_dir)
154+
def new_custom_log_dir(cls):
    """Log that custom log directories are unsupported and return ``None``."""
    context_name = cls.__name__
    logger.info(
        "%s does not support register_custom_log_path / new_custom_log_dir",
        context_name,
    )
    return None
160+
161+
@staticmethod
162+
@implements(Context.register_custom_log_path)
163+
def register_custom_log_path(
    session_id: str,
    tileable_op_key: str,
    chunk_op_key: str,
    worker_address: str,
    log_path: str,
):
    """Unsupported in this context; always raises ``NotImplementedError``."""
    raise NotImplementedError
171+
172+
@classmethod
173+
@implements(Context.set_progress)
174+
def set_progress(cls, progress: float):
    """Log that progress reporting is unsupported; ``progress`` is ignored."""
    context_name = cls.__name__
    logger.info(
        "%s does not support set_running_operand_key / set_progress", context_name
    )
178+
179+
@staticmethod
180+
@implements(Context.set_running_operand_key)
181+
def set_running_operand_key(session_id: str, op_key: str):
    """Unsupported in this context; always raises ``NotImplementedError``."""
    raise NotImplementedError
183+
184+
@classmethod
185+
@implements(Context.get_storage_info)
186+
def get_storage_info(
    cls, address: str = None, level: StorageLevel = StorageLevel.MEMORY
):
    """Log that storage info is unsupported and return an empty dict.

    Both ``address`` and ``level`` are accepted but not used.
    """
    context_name = cls.__name__
    logger.info("%s does not support get_storage_info", context_name)
    return {}

mars/services/task/execution/ray/executor.py

Lines changed: 63 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import asyncio
1616
import functools
1717
import logging
18+
from dataclasses import dataclass
1819
from typing import List, Dict, Any, Set
1920
from .....core import ChunkGraph, Chunk, TileContext
2021
from .....core.context import set_context
@@ -30,6 +31,7 @@
3031
from .....serialization import serialize, deserialize
3132
from .....typing import BandType
3233
from .....utils import (
34+
calc_data_size,
3335
lazy_import,
3436
get_chunk_params,
3537
get_chunk_key_to_data_keys,
@@ -56,6 +58,11 @@
5658
logger = logging.getLogger(__name__)
5759

5860

61+
@dataclass
62+
class _RayChunkMeta:
63+
memory_size: int
64+
65+
5966
class RayTaskState(RayRemoteObjectManager):
6067
@classmethod
6168
def gen_name(cls, task_id: str):
@@ -102,11 +109,14 @@ def execute_subtask(
102109
if output_meta_keys:
103110
output_meta = {}
104111
for chunk in subtask_chunk_graph.result_chunks:
105-
if chunk.key in output_meta_keys:
112+
chunk_key = chunk.key
113+
if chunk_key in output_meta_keys and chunk_key not in output_meta:
106114
if isinstance(chunk.op, Fuse):
107115
# fuse op
108116
chunk = chunk.chunk
109-
output_meta[chunk.key] = get_chunk_params(chunk)
117+
data = context[chunk_key]
118+
memory_size = calc_data_size(data)
119+
output_meta[chunk_key] = get_chunk_params(chunk), memory_size
110120
assert len(output_meta_keys) == len(output_meta)
111121
output_values.append(output_meta)
112122
output_values.extend(output.values())
@@ -125,6 +135,7 @@ def __init__(
125135
task: Task,
126136
tile_context: TileContext,
127137
task_context: Dict[str, "ray.ObjectRef"],
138+
task_chunks_meta: Dict[str, _RayChunkMeta],
128139
task_state_actor: "ray.actor.ActorHandle",
129140
lifecycle_api: LifecycleAPI,
130141
meta_api: MetaAPI,
@@ -133,6 +144,7 @@ def __init__(
133144
self._task = task
134145
self._tile_context = tile_context
135146
self._task_context = task_context
147+
self._task_chunks_meta = task_chunks_meta
136148
self._task_state_actor = task_state_actor
137149
self._ray_executor = self._get_ray_executor()
138150

@@ -166,12 +178,16 @@ async def create(
166178
.remote()
167179
)
168180
task_context = {}
169-
await cls._init_context(task_context, task_state_actor, session_id, address)
181+
task_chunks_meta = {}
182+
await cls._init_context(
183+
task_context, task_chunks_meta, task_state_actor, session_id, address
184+
)
170185
return cls(
171186
config,
172187
task,
173188
tile_context,
174189
task_context,
190+
task_chunks_meta,
175191
task_state_actor,
176192
lifecycle_api,
177193
meta_api,
@@ -183,6 +199,7 @@ def destroy(self):
183199
self._task = None
184200
self._tile_context = None
185201
self._task_context = None
202+
self._task_chunks_meta = None
186203
self._task_state_actor = None
187204
self._ray_executor = None
188205

@@ -207,7 +224,7 @@ async def _get_apis(cls, session_id: str, address: str):
207224
)
208225

209226
@staticmethod
210-
@functools.lru_cache(maxsize=1)
227+
@functools.lru_cache(maxsize=None) # Specify maxsize=None to make it faster
211228
def _get_ray_executor():
    """Return ``execute_subtask`` wrapped as a Ray remote function."""
    # Export remote function once.
    remote_executor = ray.remote(execute_subtask)
    return remote_executor
@@ -216,13 +233,15 @@ def _get_ray_executor():
216233
async def _init_context(
217234
cls,
218235
task_context: Dict[str, "ray.ObjectRef"],
236+
task_chunks_meta: Dict[str, _RayChunkMeta],
219237
task_state_actor: "ray.actor.ActorHandle",
220238
session_id: str,
221239
address: str,
222240
):
223241
loop = asyncio.get_running_loop()
224242
context = RayExecutionContext(
225243
task_context,
244+
task_chunks_meta,
226245
task_state_actor,
227246
session_id,
228247
address,
@@ -293,7 +312,9 @@ async def execute_subtask_graph(
293312
logger.info("Getting %s metas of stage %s.", meta_count, stage_id)
294313
meta_list = await asyncio.gather(*output_meta_object_refs)
295314
for meta in meta_list:
296-
key_to_meta.update(meta)
315+
for key, (params, memory_size) in meta.items():
316+
key_to_meta[key] = params
317+
self._task_chunks_meta[key] = _RayChunkMeta(memory_size=memory_size)
297318
assert len(key_to_meta) == len(result_meta_keys)
298319
logger.info("Got %s metas of stage %s.", meta_count, stage_id)
299320

@@ -304,9 +325,9 @@ async def execute_subtask_graph(
304325
chunk_key = chunk.key
305326
object_ref = task_context[chunk_key]
306327
output_object_refs.add(object_ref)
307-
chunk_meta = key_to_meta.get(chunk_key)
308-
if chunk_meta is not None:
309-
chunk_to_meta[chunk] = ExecutionChunkResult(chunk_meta, object_ref)
328+
chunk_params = key_to_meta.get(chunk_key)
329+
if chunk_params is not None:
330+
chunk_to_meta[chunk] = ExecutionChunkResult(chunk_params, object_ref)
310331

311332
logger.info("Waiting for stage %s complete.", stage_id)
312333
# Patched the asyncio.to_thread for Python < 3.9 at mars/lib/aio/__init__.py
@@ -319,36 +340,42 @@ async def execute_subtask_graph(
319340
return chunk_to_meta
320341

321342
async def __aexit__(self, exc_type, exc_val, exc_tb):
322-
if exc_type is None:
323-
tileable_keys = []
324-
update_metas = []
325-
update_lifecycles = []
326-
for tileable in self._task.tileable_graph.result_tileables:
327-
tileable_keys.append(tileable.key)
328-
tileable = tileable.data if hasattr(tileable, "data") else tileable
329-
chunk_keys = []
330-
for chunk in self._tile_context[tileable].chunks:
331-
chunk_keys.append(chunk.key)
332-
if chunk.key in self._task_context:
333-
# Some tileable graph may have result chunks that not be executed,
334-
# for example:
335-
# r, b = cut(series, bins, retbins=True)
336-
# r_result = r.execute().fetch()
337-
# b_result = b.execute().fetch() <- This is the case
338-
object_ref = self._task_context[chunk.key]
339-
update_metas.append(
340-
self._meta_api.set_chunk_meta.delay(
341-
chunk,
342-
bands=[],
343-
object_ref=object_ref,
344-
)
343+
if exc_type is not None:
344+
return
345+
346+
# Update info if no exception occurs.
347+
tileable_keys = []
348+
update_metas = []
349+
update_lifecycles = []
350+
for tileable in self._task.tileable_graph.result_tileables:
351+
tileable_keys.append(tileable.key)
352+
tileable = tileable.data if hasattr(tileable, "data") else tileable
353+
chunk_keys = []
354+
for chunk in self._tile_context[tileable].chunks:
355+
chunk_key = chunk.key
356+
chunk_keys.append(chunk_key)
357+
if chunk_key in self._task_context:
358+
# Some tileable graph may have result chunks that not be executed,
359+
# for example:
360+
# r, b = cut(series, bins, retbins=True)
361+
# r_result = r.execute().fetch()
362+
# b_result = b.execute().fetch() <- This is the case
363+
object_ref = self._task_context[chunk_key]
364+
chunk_meta = self._task_chunks_meta[chunk_key]
365+
update_metas.append(
366+
self._meta_api.set_chunk_meta.delay(
367+
chunk,
368+
bands=[],
369+
object_ref=object_ref,
370+
memory_size=chunk_meta.memory_size,
345371
)
346-
update_lifecycles.append(
347-
self._lifecycle_api.track.delay(tileable.key, chunk_keys)
348372
)
349-
await self._meta_api.set_chunk_meta.batch(*update_metas)
350-
await self._lifecycle_api.track.batch(*update_lifecycles)
351-
await self._lifecycle_api.incref_tileables(tileable_keys)
373+
update_lifecycles.append(
374+
self._lifecycle_api.track.delay(tileable.key, chunk_keys)
375+
)
376+
await self._meta_api.set_chunk_meta.batch(*update_metas)
377+
await self._lifecycle_api.track.batch(*update_lifecycles)
378+
await self._lifecycle_api.incref_tileables(tileable_keys)
352379

353380
async def get_available_band_resources(self) -> Dict[BandType, Resource]:
354381
if self._available_band_resources is None:

mars/services/task/execution/ray/fetcher.py

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import functools
1617
from collections import namedtuple
1718
from typing import Dict, List
1819

@@ -23,17 +24,28 @@
2324
_FetchInfo = namedtuple("FetchInfo", ["key", "object_ref", "conditions"])
2425

2526

27+
def _query_object_with_condition(o, conditions):
28+
try:
29+
return o.iloc[conditions]
30+
except AttributeError:
31+
return o[conditions]
32+
33+
2634
@register_fetcher_cls
2735
class RayFetcher(Fetcher):
2836
name = "ray"
2937
required_meta_keys = ("object_refs",)
3038

3139
def __init__(self, **kwargs):
32-
_make_query_function_remote()
33-
3440
self._fetch_info_list = []
3541
self._no_conditions = True
3642

43+
@staticmethod
44+
@functools.lru_cache(maxsize=None) # Specify maxsize=None to make it faster
45+
def _remote_query_object_with_condition():
    """Return ``_query_object_with_condition`` wrapped as a Ray remote function."""
    # Export remote function once.
    remote_query = ray.remote(_query_object_with_condition)
    return remote_query
48+
3749
async def append(self, chunk_key: str, chunk_meta: Dict, conditions: List = None):
3850
if conditions is not None:
3951
self._no_conditions = False
@@ -51,24 +63,7 @@ async def get(self):
5163
if fetch_info.conditions is None:
5264
refs[index] = fetch_info.object_ref
5365
else:
54-
refs[index] = _remote_query_object_with_condition.remote(
66+
refs[index] = self._remote_query_object_with_condition().remote(
5567
fetch_info.object_ref, fetch_info.conditions
5668
)
5769
return await asyncio.gather(*refs)
58-
59-
60-
def _query_object_with_condition(o, conditions):
61-
try:
62-
return o.iloc[conditions]
63-
except AttributeError:
64-
return o[conditions]
65-
66-
67-
_remote_query_object_with_condition = None
68-
69-
70-
def _make_query_function_remote():
71-
global _remote_query_object_with_condition
72-
73-
if _remote_query_object_with_condition is None:
74-
_remote_query_object_with_condition = ray.remote(_query_object_with_condition)

0 commit comments

Comments
 (0)