|
25 | 25 | from rock.utils.format import parse_memory_size |
26 | 26 | from rock.utils.providers import RedisProvider |
27 | 27 | from rock.admin.core.ray_service import RayService |
| 28 | +from rock.rocklet import __version__ as swe_version |
| 29 | +from rock.sandbox import __version__ as gateway_version |
28 | 30 |
|
29 | 31 | logger = init_logger(__name__) |
30 | 32 |
|
@@ -58,7 +60,6 @@ async def submit(self, config: DeploymentConfig, user_info: dict = {}): |
58 | 60 | async with self._ray_service.get_ray_rwlock().read_lock(): |
59 | 61 | deployment_config: DeploymentConfig = await self.deployment_manager.init_config(config) |
60 | 62 | sandbox_id = deployment_config.container_name |
61 | | - # deployment: AbstractDeployment = deployment_config.get_deployment() |
62 | 63 | self.validate_sandbox_spec(self.rock_config.runtime, config) |
63 | 64 | self._sandbox_meta[sandbox_id] = {"image": deployment_config.image} |
64 | 65 | sandbox_info: SandboxInfo = await self._deployment_service.submit(deployment_config, user_info) |
@@ -121,26 +122,46 @@ async def _clear_redis_keys(self, sandbox_id): |
121 | 122 | @monitor_sandbox_operation() |
122 | 123 | async def get_status(self, sandbox_id) -> SandboxStatusResponse: |
123 | 124 | async with self._ray_service.get_ray_rwlock().read_lock(): |
124 | | - response: SandboxStatusResponse = await self._deployment_service.get_status(sandbox_id) |
125 | | - sandbox_info: SandboxInfo = self.get_info_from_response(response) |
| 125 | + deployment_info: SandboxInfo = await self._deployment_service.get_status(sandbox_id) |
| 126 | + sandbox_info: SandboxInfo = None |
126 | 127 | if self._redis_provider: |
| 128 | + sandbox_info = await self.build_sandbox_info_from_redis(sandbox_id) |
| 129 | + if sandbox_info is None: |
| 130 | + sandbox_info = deployment_info |
| 131 | + else: |
| 132 | + sandbox_info["state"] = deployment_info.get("state") |
127 | 133 | await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) |
128 | 134 | await self._update_expire_time(sandbox_id) |
129 | | - # logger.info(f"sandbox {sandbox_id} status is {remote_status}, write to redis") |
130 | | - return response |
131 | | - |
132 | | - def get_info_from_response(self, response: SandboxStatusResponse) -> SandboxInfo: |
133 | | - return SandboxInfo( |
134 | | - host_name=response.host_name, |
135 | | - host_ip=response.host_ip, |
136 | | - user_id=response.user_id, |
137 | | - experiment_id=response.experiment_id, |
138 | | - namespace=response.namespace, |
139 | | - sandbox_id=response.sandbox_id, |
140 | | - cpus=response.cpus, |
141 | | - memory=response.memory, |
142 | | - port_mapping=response.port_mapping, |
143 | | - ) |
| 135 | + remote_info = {k: v for k, v in deployment_info.items() if k in ['status', 'port_mapping', 'alive']} |
| 136 | + sandbox_info.update(remote_info) |
| 137 | + logger.info(f"sandbox {sandbox_id} status is {sandbox_info}, write to redis") |
| 138 | + else: |
| 139 | + sandbox_info = deployment_info |
| 140 | + |
| 141 | + return SandboxStatusResponse( |
| 142 | + sandbox_id=sandbox_id, |
| 143 | + status=sandbox_info.get("status"), |
| 144 | + state=sandbox_info.get("state"), |
| 145 | + port_mapping=sandbox_info.get("port_mapping"), |
| 146 | + host_name=sandbox_info.get("host_name"), |
| 147 | + host_ip=sandbox_info.get("host_ip"), |
| 148 | + is_alive=sandbox_info.get("alive"), |
| 149 | + image=sandbox_info.get("image"), |
| 150 | + swe_rex_version=swe_version, |
| 151 | + gateway_version=gateway_version, |
| 152 | + user_id=sandbox_info.get("user_id"), |
| 153 | + experiment_id=sandbox_info.get("experiment_id"), |
| 154 | + namespace=sandbox_info.get("namespace"), |
| 155 | + cpus=sandbox_info.get("cpus"), |
| 156 | + memory=sandbox_info.get("memory"), |
| 157 | + ) |
| 158 | + |
| 159 | + async def build_sandbox_info_from_redis(self, sandbox_id: str) -> SandboxInfo | None: |
| 160 | + if self._redis_provider: |
| 161 | + sandbox_status = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$") |
| 162 | + if sandbox_status and len(sandbox_status) > 0: |
| 163 | + return sandbox_status[0] |
| 164 | + return None |
144 | 165 |
|
145 | 166 | async def create_session(self, request: CreateSessionRequest) -> CreateBashSessionResponse: |
146 | 167 | sandbox_actor = await self._deployment_service.async_ray_get_actor(request.sandbox_id) |
|
0 commit comments