Skip to content

Commit 4356fa1

Browse files
authored
Supports adding Mars extensions via setup entrypoints (#2589)
1 parent ea8dc9f commit 4356fa1

File tree

6 files changed

+178
-0
lines changed

6 files changed

+178
-0
lines changed

mars/core/entrypoints.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
import warnings
17+
import functools
18+
19+
from pkg_resources import iter_entry_points
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
# from https://github.com/numba/numba/blob/master/numba/core/entrypoints.py
25+
# Must put this here to avoid extensions re-triggering initialization
26+
@functools.lru_cache(maxsize=None)
27+
def init_extension_entrypoints():
28+
"""Execute all `mars_extensions` entry points with the name `init`
29+
If extensions have already been initialized, this function does nothing.
30+
"""
31+
for entry_point in iter_entry_points("mars_extensions", "init"):
32+
logger.info("Loading extension: %s", entry_point)
33+
try:
34+
func = entry_point.load()
35+
func()
36+
except Exception as e:
37+
msg = "Mars extension module '{}' failed to load due to '{}({})'."
38+
warnings.warn(
39+
msg.format(entry_point.module_name, type(e).__name__, str(e)),
40+
stacklevel=2,
41+
)
42+
logger.info("Extension loading failed for: %s", entry_point)
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Copyright 1999-2021 Alibaba Group Holding Ltd.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
import types
17+
import warnings
18+
import pkg_resources
19+
20+
21+
class _DummyClass(object):
22+
def __init__(self, value):
23+
self.value = value
24+
25+
def __repr__(self):
26+
return "_DummyClass(%f, %f)" % self.value
27+
28+
29+
def test_init_entrypoint():
30+
# FIXME: Python 2 workaround because nonlocal doesn't exist
31+
counters = {"init": 0}
32+
33+
def init_function():
34+
counters["init"] += 1
35+
36+
mod = types.ModuleType("_test_mars_extension")
37+
mod.init_func = init_function
38+
39+
try:
40+
# will remove this module at the end of the test
41+
sys.modules[mod.__name__] = mod
42+
43+
# We are registering an entry point using the "mars" package
44+
# ("distribution" in pkg_resources-speak) itself, though these are
45+
# normally registered by other packages.
46+
dist = "pymars"
47+
entrypoints = pkg_resources.get_entry_map(dist)
48+
my_entrypoint = pkg_resources.EntryPoint(
49+
"init", # name of entry point
50+
mod.__name__, # module with entry point object
51+
attrs=["init_func"], # name of entry point object
52+
dist=pkg_resources.get_distribution(dist),
53+
)
54+
entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint
55+
56+
from .. import entrypoints
57+
58+
# Allow reinitialization
59+
entrypoints.init_extension_entrypoints.cache_clear()
60+
61+
entrypoints.init_extension_entrypoints()
62+
63+
# was our init function called?
64+
assert counters["init"] == 1
65+
66+
# ensure we do not initialize twice
67+
entrypoints.init_extension_entrypoints()
68+
assert counters["init"] == 1
69+
finally:
70+
# remove fake module
71+
if mod.__name__ in sys.modules:
72+
del sys.modules[mod.__name__]
73+
74+
75+
def test_entrypoint_tolerance():
76+
# FIXME: Python 2 workaround because nonlocal doesn't exist
77+
counters = {"init": 0}
78+
79+
def init_function():
80+
counters["init"] += 1
81+
raise ValueError("broken")
82+
83+
mod = types.ModuleType("_test_mars_bad_extension")
84+
mod.init_func = init_function
85+
86+
try:
87+
# will remove this module at the end of the test
88+
sys.modules[mod.__name__] = mod
89+
90+
# We are registering an entry point using the "mars" package
91+
# ("distribution" in pkg_resources-speak) itself, though these are
92+
# normally registered by other packages.
93+
dist = "pymars"
94+
entrypoints = pkg_resources.get_entry_map(dist)
95+
my_entrypoint = pkg_resources.EntryPoint(
96+
"init", # name of entry point
97+
mod.__name__, # module with entry point object
98+
attrs=["init_func"], # name of entry point object
99+
dist=pkg_resources.get_distribution(dist),
100+
)
101+
entrypoints.setdefault("mars_extensions", {})["init"] = my_entrypoint
102+
103+
from .. import entrypoints
104+
105+
# Allow reinitialization
106+
entrypoints.init_extension_entrypoints.cache_clear()
107+
108+
with warnings.catch_warnings(record=True) as w:
109+
entrypoints.init_extension_entrypoints()
110+
111+
bad_str = "Mars extension module '_test_mars_bad_extension'"
112+
for x in w:
113+
if bad_str in str(x):
114+
break
115+
else:
116+
raise ValueError("Expected warning message not found")
117+
118+
# was our init function called?
119+
assert counters["init"] == 1
120+
121+
finally:
122+
# remove fake module
123+
if mod.__name__ in sys.modules:
124+
del sys.modules[mod.__name__]

mars/deploy/oscar/local.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import numpy as np
2323

2424
from ... import oscar as mo
25+
from ...core.entrypoints import init_extension_entrypoints
2526
from ...lib.aio import get_isolation, stop_isolation
2627
from ...resource import cpu_count, cuda_count
2728
from ...services import NodeRole
@@ -111,6 +112,8 @@ def __init__(
111112
web: Union[bool, str] = "auto",
112113
timeout: float = None,
113114
):
115+
# load third party extensions.
116+
init_extension_entrypoints()
114117
# load config file to dict.
115118
if not config or isinstance(config, str):
116119
config = load_config(config)

mars/deploy/oscar/ray.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import Union, Dict, List, Optional, AsyncGenerator
2121

2222
from ... import oscar as mo
23+
from ...core.entrypoints import init_extension_entrypoints
2324
from ...oscar.backends.ray.driver import RayActorDriver
2425
from ...oscar.backends.ray.utils import (
2526
process_placement_to_address,
@@ -371,6 +372,8 @@ def __init__(
371372
worker_mem: int = 32 * 1024 ** 3,
372373
config: Union[str, Dict] = None,
373374
):
375+
# load third party extensions.
376+
init_extension_entrypoints()
374377
self._cluster_name = cluster_name
375378
self._supervisor_mem = supervisor_mem
376379
self._worker_num = worker_num

mars/deploy/oscar/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from ... import oscar as mo
3737
from ...config import options
3838
from ...core import ChunkType, TileableType, TileableGraph, enter_mode
39+
from ...core.entrypoints import init_extension_entrypoints
3940
from ...core.operand import Fetch
4041
from ...lib.aio import (
4142
alru_cache,
@@ -1869,6 +1870,8 @@ def new_session(
18691870
new: bool = True,
18701871
**kwargs,
18711872
) -> AbstractSession:
1873+
# load third party extensions.
1874+
init_extension_entrypoints()
18721875
ensure_isolation_created(kwargs)
18731876

18741877
if address is None:

mars/oscar/backends/pool.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from abc import ABC, ABCMeta, abstractmethod
2424
from typing import Dict, List, Type, TypeVar, Coroutine, Callable, Union, Optional
2525

26+
from ...core.entrypoints import init_extension_entrypoints
2627
from ...utils import implements, to_binary
2728
from ...utils import lazy_import, register_asyncio_task_timeout_detector
2829
from ..api import Actor
@@ -141,6 +142,8 @@ def __init__(
141142
self._asyncio_task_timeout_detector_task = (
142143
register_asyncio_task_timeout_detector()
143144
)
145+
# load third party extensions.
146+
init_extension_entrypoints()
144147

145148
@property
146149
def router(self):

0 commit comments

Comments
 (0)