Skip to content

Commit f892599

Browse files
committed
split gem api from deployment service and sink ray operations to ray service
1 parent c0b08d0 commit f892599

File tree

6 files changed

+127
-106
lines changed

6 files changed

+127
-106
lines changed

rock/admin/core/ray_service.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
import ray
23
import time
34

@@ -34,6 +35,31 @@ def increment_ray_request_count(self):
3435

3536
def get_ray_rwlock(self):
3637
return self._ray_rwlock
38+
39+
async def async_ray_get_actor(self, sandbox_id: str):
40+
"""Async wrapper for ray.get_actor() using asyncio.to_thread for non-blocking execution."""
41+
self.increment_ray_request_count()
42+
try:
43+
result = await asyncio.to_thread(ray.get_actor, sandbox_id, namespace=self._config.namespace)
44+
except ValueError as e:
45+
logger.error(f"ray get actor, actor {sandbox_id} not exist", exc_info=e)
46+
raise e
47+
except Exception as e:
48+
logger.error("ray get actor failed", exc_info=e)
49+
error_msg = str(e.args[0]) if len(e.args) > 0 else f"ray get actor failed, {str(e)}"
50+
raise Exception(error_msg)
51+
return result
52+
53+
async def async_ray_get(self, ray_future: ray.ObjectRef):
54+
"""Async wrapper for ray.get() using asyncio.to_thread for non-blocking execution."""
55+
self.increment_ray_request_count()
56+
try:
57+
result = await asyncio.to_thread(ray.get, ray_future, timeout=60)
58+
except Exception as e:
59+
logger.error("ray get failed", exc_info=e)
60+
error_msg = str(e.args[0]) if len(e.args) > 0 else f"ray get failed, {str(e)}"
61+
raise Exception(error_msg)
62+
return result
3763

3864

3965
def _setup_ray_reconnect_scheduler(self):

rock/sandbox/gem_manager.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
from rock.config import RockConfig
1717
from rock.deployments.config import DockerDeploymentConfig
1818
from rock.sandbox.sandbox_manager import SandboxManager
19+
from rock.sandbox.service.env_service import RayEnvService
1920
from rock.utils.providers import RedisProvider
2021
from rock.admin.core.ray_service import RayService
2122

2223

2324
class GemManager(SandboxManager):
25+
_env_service: RayEnvService
2426
def __init__(
2527
self,
2628
rock_config: RockConfig,
@@ -30,6 +32,7 @@ def __init__(
3032
enable_runtime_auto_clear: bool = False,
3133
):
3234
super().__init__(rock_config, redis_provider, ray_namespace, ray_service, enable_runtime_auto_clear)
35+
self._env_service = RayEnvService(ray_namespace=ray_namespace, ray_service=ray_service)
3336

3437
async def env_make(self, env_id: str) -> EnvMakeResponse:
3538
config = DockerDeploymentConfig(image=env_vars.ROCK_ENVHUB_DEFAULT_DOCKER_IMAGE)
@@ -51,22 +54,22 @@ async def wait_until_alive(sandbox_id: str, interval: float = 1.0):
5154
except asyncio.TimeoutError:
5255
raise Exception("Sandbox startup timeout after 300s")
5356

54-
make_response = await self._deployment_service.env_make(
57+
make_response = await self._env_service.env_make(
5558
EnvMakeRequest(
5659
env_id=env_id,
5760
sandbox_id=sandbox_start_response.sandbox_id,
5861
)
5962
)
6063
return make_response
61-
64+
6265
async def env_step(self, request: EnvStepRequest) -> EnvStepResponse:
63-
return await self._deployment_service.env_step(request)
66+
return await self._env_service.env_step(request)
6467

6568
async def env_reset(self, request: EnvResetRequest) -> EnvResetResponse:
66-
return await self._deployment_service.env_reset(request)
69+
return await self._env_service.env_reset(request)
6770

6871
async def env_close(self, request: EnvCloseRequest) -> EnvCloseResponse:
69-
return await self._deployment_service.env_close(request)
72+
return await self._env_service.env_close(request)
7073

7174
async def env_list(self, sandbox_id: str) -> EnvListResponse:
72-
return await self._deployment_service.env_list(sandbox_id)
75+
return await self._env_service.env_list(sandbox_id)

rock/sandbox/service/deployment_service.py

Lines changed: 16 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from abc import abstractmethod
2-
import asyncio
3-
from rock.actions.envs.request import EnvCloseRequest, EnvMakeRequest, EnvResetRequest, EnvStepRequest
4-
from rock.actions.envs.response import EnvCloseResponse, EnvListResponse, EnvMakeResponse, EnvResetResponse, EnvStepResponse
2+
53
from rock.actions.sandbox.response import CommandResponse, State, SystemResourceMetrics
64
from rock.actions.sandbox.sandbox_info import SandboxInfo
75
from rock.admin.core.ray_service import RayService
@@ -50,26 +48,6 @@ async def get_sandbox_statistics(self, sandbox_id: str):
5048
async def commit(self, sandbox_id: str, image_tag: str, username: str, password: str) -> CommandResponse:
5149
...
5250

53-
@abstractmethod
54-
async def env_step(self, *args, **kwargs):
55-
...
56-
57-
@abstractmethod
58-
async def env_make(self, *args, **kwargs):
59-
...
60-
61-
@abstractmethod
62-
async def env_reset(self, *args, **kwargs):
63-
...
64-
65-
@abstractmethod
66-
async def env_list(self, *args, **kwargs):
67-
...
68-
69-
@abstractmethod
70-
async def env_close(self, *args, **kwargs):
71-
...
72-
7351
@abstractmethod
7452
async def collect_system_resource_metrics(self) -> SystemResourceMetrics:
7553
...
@@ -85,36 +63,10 @@ def _get_actor_name(self, sandbox_id):
8563

8664
async def is_deployment_alive(self, sandbox_id) -> bool:
8765
try:
88-
actor: SandboxActor = await self.async_ray_get_actor(sandbox_id)
66+
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
8967
except ValueError:
9068
return False
91-
return await self.async_ray_get(actor.is_alive.remote())
92-
93-
async def async_ray_get_actor(self, sandbox_id: str):
94-
"""Async wrapper for ray.get_actor() using asyncio.to_thread for non-blocking execution."""
95-
self._ray_service.increment_ray_request_count()
96-
try:
97-
actor_name = self._get_actor_name(sandbox_id)
98-
result = await asyncio.to_thread(ray.get_actor, actor_name, namespace=self._ray_namespace)
99-
except ValueError as e:
100-
logger.error(f"ray get actor, actor {sandbox_id} not exist", exc_info=e)
101-
raise e
102-
except Exception as e:
103-
logger.error("ray get actor failed", exc_info=e)
104-
error_msg = str(e.args[0]) if len(e.args) > 0 else f"ray get actor failed, {str(e)}"
105-
raise Exception(error_msg)
106-
return result
107-
108-
async def async_ray_get(self, ray_future: ray.ObjectRef):
109-
"""Async wrapper for ray.get() using asyncio.to_thread for non-blocking execution."""
110-
self._ray_service.increment_ray_request_count()
111-
try:
112-
result = await asyncio.to_thread(ray.get, ray_future, timeout=60)
113-
except Exception as e:
114-
logger.error("ray get failed", exc_info=e)
115-
error_msg = str(e.args[0]) if len(e.args) > 0 else f"ray get failed, {str(e)}"
116-
raise Exception(error_msg)
117-
return result
69+
return await self._ray_service.async_ray_get(actor.is_alive.remote())
11870

11971
async def submit(self, config: DockerDeploymentConfig, user_info: dict) -> SandboxInfo:
12072
async with self._ray_service.get_ray_rwlock().read_lock():
@@ -127,7 +79,7 @@ async def submit(self, config: DockerDeploymentConfig, user_info: dict) -> Sandb
12779
sandbox_actor.set_user_id.remote(user_id)
12880
sandbox_actor.set_experiment_id.remote(experiment_id)
12981
sandbox_actor.set_namespace.remote(namespace)
130-
sandbox_info: SandboxInfo = await self.async_ray_get(sandbox_actor.sandbox_info.remote())
82+
sandbox_info: SandboxInfo = await self._ray_service.async_ray_get(sandbox_actor.sandbox_info.remote())
13183
sandbox_info["user_id"] = user_id
13284
sandbox_info["experiment_id"] = experiment_id
13385
sandbox_info["namespace"] = namespace
@@ -155,78 +107,44 @@ def _generate_actor_options(self, config: DockerDeploymentConfig) -> dict:
155107

156108
async def stop(self, sandbox_id: str):
157109
async with self._ray_service.get_ray_rwlock().read_lock():
158-
actor: SandboxActor = await self.async_ray_get_actor(sandbox_id)
159-
await self.async_ray_get(actor.stop.remote())
110+
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
111+
await self._ray_service.async_ray_get(actor.stop.remote())
160112
logger.info(f"run time stop over {sandbox_id}")
161113
ray.kill(actor)
162114

163115
async def get_status(self, sandbox_id: str) -> SandboxInfo:
164116
async with self._ray_service.get_ray_rwlock().read_lock():
165-
actor: SandboxActor = await self.async_ray_get_actor(sandbox_id)
166-
sandbox_info: SandboxInfo = await self.async_ray_get(actor.sandbox_info.remote())
167-
remote_status: ServiceStatus = await self.async_ray_get(actor.get_status.remote())
117+
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
118+
sandbox_info: SandboxInfo = await self._ray_service.async_ray_get(actor.sandbox_info.remote())
119+
remote_status: ServiceStatus = await self._ray_service.async_ray_get(actor.get_status.remote())
168120
sandbox_info["phases"] = remote_status.phases
169121
sandbox_info["port_mapping"] = remote_status.get_port_mapping()
170-
alive = await self.async_ray_get(actor.is_alive.remote())
122+
alive = await self._ray_service.async_ray_get(actor.is_alive.remote())
171123
sandbox_info["alive"] = alive.is_alive
172124
if alive.is_alive:
173125
sandbox_info["state"] = State.RUNNING
174126
return sandbox_info
175127

176128
async def get_mount(self, sandbox_id: str):
177129
with self._ray_service.get_ray_rwlock().read_lock():
178-
actor = await self.async_ray_get_actor(sandbox_id)
179-
result = await self.async_ray_get(actor.get_mount.remote())
130+
actor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
131+
result = await self._ray_service.async_ray_get(actor.get_mount.remote())
180132
logger.info(f"get_mount: {result}")
181133
return result
182134

183135
async def get_sandbox_statistics(self, sandbox_id: str):
184136
async with self._ray_service.get_ray_rwlock().read_lock():
185-
actor = await self.async_ray_get_actor(sandbox_id)
186-
result = await self.async_ray_get(actor.get_sandbox_statistics.remote())
137+
actor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
138+
result = await self._ray_service.async_ray_get(actor.get_sandbox_statistics.remote())
187139
logger.info(f"get_sandbox_statistics: {result}")
188140
return result
189141

190142
async def commit(self, sandbox_id) -> CommandResponse:
191143
with self._ray_service.get_ray_rwlock().read_lock():
192-
actor = await self.async_ray_get_actor(sandbox_id)
193-
result = await self.async_ray_get(actor.commit.remote())
144+
actor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
145+
result = await self._ray_service.async_ray_get(actor.commit.remote())
194146
logger.info(f"commit: {result}")
195147
return result
196-
197-
async def env_step(self, request: EnvStepRequest) -> EnvStepResponse:
198-
sandbox_id = request.sandbox_id
199-
actor: SandboxActor = await self.async_ray_get_actor(sandbox_id)
200-
result = await self.async_ray_get(actor.env_step.remote(request))
201-
logger.info(f"env_step: {result}")
202-
return result
203-
204-
async def env_make(self, request: EnvMakeRequest) -> EnvMakeResponse:
205-
sandbox_id = request.sandbox_id
206-
actor: SandboxActor = await self.async_ray_get_actor(sandbox_id)
207-
result = await self.async_ray_get(actor.env_make.remote(request))
208-
logger.info(f"env_make: {result}")
209-
return result
210-
211-
async def env_reset(self, request: EnvResetRequest) -> EnvResetResponse:
212-
sandbox_id = request.sandbox_id
213-
actor: SandboxActor = await self.async_ray_get_actor(sandbox_id)
214-
result = await self.async_ray_get(actor.env_reset.remote(request))
215-
logger.info(f"env_reset: {result}")
216-
return result
217-
218-
async def env_close(self, request: EnvCloseRequest) -> EnvCloseResponse:
219-
sandbox_id = request.sandbox_id
220-
actor: SandboxActor = await self.async_ray_get_actor(sandbox_id)
221-
result = await self.async_ray_get(actor.env_close.remote(request))
222-
logger.info(f"env_close: {result}")
223-
return result
224-
225-
async def env_list(self, sandbox_id) -> EnvListResponse:
226-
actor: SandboxActor = await self.async_ray_get_actor(sandbox_id)
227-
result = await self.async_ray_get(actor.env_list.remote())
228-
logger.info(f"env_list: {result}")
229-
return result
230148

231149
async def collect_system_resource_metrics(self) -> SystemResourceMetrics:
232150
"""Collect system resource metrics"""
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from abc import ABC, abstractmethod
2+
3+
from rock.actions.envs.request import EnvCloseRequest, EnvMakeRequest, EnvResetRequest, EnvStepRequest
4+
from rock.actions.envs.response import EnvCloseResponse, EnvListResponse, EnvMakeResponse, EnvResetResponse, EnvStepResponse
5+
from rock.admin.core.ray_service import RayService
6+
from rock.logger import init_logger
7+
from rock.sandbox.sandbox_actor import SandboxActor
8+
9+
logger = init_logger(__name__)
10+
11+
12+
class AbstractEnvService(ABC):
13+
@abstractmethod
14+
async def env_step(self, request: EnvStepRequest) -> EnvStepResponse:
15+
...
16+
17+
@abstractmethod
18+
async def env_make(self, request: EnvMakeRequest) -> EnvMakeResponse:
19+
...
20+
21+
@abstractmethod
22+
async def env_reset(self, request: EnvResetRequest) -> EnvResetResponse:
23+
...
24+
25+
@abstractmethod
26+
async def env_close(self, request: EnvCloseRequest) -> EnvCloseResponse:
27+
...
28+
29+
@abstractmethod
30+
async def env_list(self, sandbox_id) -> EnvListResponse:
31+
...
32+
33+
34+
class RayEnvService(AbstractEnvService):
35+
def __init__(self, ray_namespace: str, ray_service: RayService):
36+
self._ray_namespace = ray_namespace
37+
self._ray_service = ray_service
38+
39+
def _get_actor_name(self, sandbox_id):
40+
return sandbox_id
41+
42+
async def env_step(self, request: EnvStepRequest) -> EnvStepResponse:
43+
sandbox_id = request.sandbox_id
44+
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
45+
result = await self._ray_service.async_ray_get(actor.env_step.remote(request))
46+
logger.info(f"env_step: {result}")
47+
return result
48+
49+
async def env_make(self, request: EnvMakeRequest) -> EnvMakeResponse:
50+
sandbox_id = request.sandbox_id
51+
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
52+
result = await self._ray_service.async_ray_get(actor.env_make.remote(request))
53+
logger.info(f"env_make: {result}")
54+
return result
55+
56+
async def env_reset(self, request: EnvResetRequest) -> EnvResetResponse:
57+
sandbox_id = request.sandbox_id
58+
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
59+
result = await self._ray_service.async_ray_get(actor.env_reset.remote(request))
60+
logger.info(f"env_reset: {result}")
61+
return result
62+
63+
async def env_close(self, request: EnvCloseRequest) -> EnvCloseResponse:
64+
sandbox_id = request.sandbox_id
65+
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
66+
result = await self._ray_service.async_ray_get(actor.env_close.remote(request))
67+
logger.info(f"env_close: {result}")
68+
return result
69+
70+
async def env_list(self, sandbox_id) -> EnvListResponse:
71+
actor: SandboxActor = await self._ray_service.async_ray_get_actor(self._get_actor_name(sandbox_id))
72+
result = await self._ray_service.async_ray_get(actor.env_list.remote())
73+
logger.info(f"env_list: {result}")
74+
return result

tests/unit/sandbox/service/test_deployment_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
async def test_get_actor_not_exist_raises_value_error(ray_deployment_service):
88
sandbox_id = "unknown"
99
with pytest.raises(Exception) as exc_info:
10-
await ray_deployment_service.async_ray_get_actor(sandbox_id)
10+
await ray_deployment_service._ray_service.async_ray_get_actor(sandbox_id)
1111
assert exc_info.type == ValueError
1212

1313

tests/unit/sandbox/test_sandbox_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ async def test_ray_actor_is_alive(sandbox_manager):
6767

6868
assert await wait_sandbox_instance_alive(sandbox_manager, response.sandbox_id)
6969

70-
sandbox_actor = await sandbox_manager._deployment_service.async_ray_get_actor(response.sandbox_id)
70+
sandbox_actor = await sandbox_manager._deployment_service._ray_service.async_ray_get_actor(response.sandbox_id)
7171
ray.kill(sandbox_actor)
7272

7373
assert not await sandbox_manager._deployment_service.is_deployment_alive(response.sandbox_id)

0 commit comments

Comments
 (0)