Skip to content

Commit ddca701

Browse files
committed
fix test case: add rock auth in get status
1 parent 16af58e commit ddca701

File tree

3 files changed

+47
-40
lines changed

3 files changed

+47
-40
lines changed

rock/actions/sandbox/sandbox_info.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ class SandboxInfo(TypedDict, total=False):
2020
create_user_gray_flag: bool
2121
cpus: float
2222
memory: str
23+
alive: bool
2324

2425

2526
class SandboxListItem(SandboxInfo):

rock/sandbox/sandbox_manager.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
from rock.utils.format import parse_memory_size
2626
from rock.utils.providers import RedisProvider
2727
from rock.admin.core.ray_service import RayService
28+
from rock.rocklet import __version__ as swe_version
29+
from rock.sandbox import __version__ as gateway_version
2830

2931
logger = init_logger(__name__)
3032

@@ -58,7 +60,6 @@ async def submit(self, config: DeploymentConfig, user_info: dict = {}):
5860
async with self._ray_service.get_ray_rwlock().read_lock():
5961
deployment_config: DeploymentConfig = await self.deployment_manager.init_config(config)
6062
sandbox_id = deployment_config.container_name
61-
# deployment: AbstractDeployment = deployment_config.get_deployment()
6263
self.validate_sandbox_spec(self.rock_config.runtime, config)
6364
self._sandbox_meta[sandbox_id] = {"image": deployment_config.image}
6465
sandbox_info: SandboxInfo = await self._deployment_service.submit(deployment_config, user_info)
@@ -121,26 +122,46 @@ async def _clear_redis_keys(self, sandbox_id):
121122
@monitor_sandbox_operation()
122123
async def get_status(self, sandbox_id) -> SandboxStatusResponse:
123124
async with self._ray_service.get_ray_rwlock().read_lock():
124-
response: SandboxStatusResponse = await self._deployment_service.get_status(sandbox_id)
125-
sandbox_info: SandboxInfo = self.get_info_from_response(response)
125+
deployment_info: SandboxInfo = await self._deployment_service.get_status(sandbox_id)
126+
sandbox_info: SandboxInfo = None
126127
if self._redis_provider:
128+
sandbox_info = await self.build_sandbox_info_from_redis(sandbox_id)
129+
if sandbox_info is None:
130+
sandbox_info = deployment_info
131+
else:
132+
sandbox_info["state"] = deployment_info.get("state")
127133
await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info)
128134
await self._update_expire_time(sandbox_id)
129-
# logger.info(f"sandbox {sandbox_id} status is {remote_status}, write to redis")
130-
return response
131-
132-
def get_info_from_response(self, response: SandboxStatusResponse) -> SandboxInfo:
133-
return SandboxInfo(
134-
host_name=response.host_name,
135-
host_ip=response.host_ip,
136-
user_id=response.user_id,
137-
experiment_id=response.experiment_id,
138-
namespace=response.namespace,
139-
sandbox_id=response.sandbox_id,
140-
cpus=response.cpus,
141-
memory=response.memory,
142-
port_mapping=response.port_mapping,
143-
)
135+
remote_info = {k: v for k, v in deployment_info.items() if k in ['status', 'port_mapping', 'alive']}
136+
sandbox_info.update(remote_info)
137+
logger.info(f"sandbox {sandbox_id} status is {sandbox_info}, write to redis")
138+
else:
139+
sandbox_info = deployment_info
140+
141+
return SandboxStatusResponse(
142+
sandbox_id=sandbox_id,
143+
status=sandbox_info.get("status"),
144+
state=sandbox_info.get("state"),
145+
port_mapping=sandbox_info.get("port_mapping"),
146+
host_name=sandbox_info.get("host_name"),
147+
host_ip=sandbox_info.get("host_ip"),
148+
is_alive=sandbox_info.get("alive"),
149+
image=sandbox_info.get("image"),
150+
swe_rex_version=swe_version,
151+
gateway_version=gateway_version,
152+
user_id=sandbox_info.get("user_id"),
153+
experiment_id=sandbox_info.get("experiment_id"),
154+
namespace=sandbox_info.get("namespace"),
155+
cpus=sandbox_info.get("cpus"),
156+
memory=sandbox_info.get("memory"),
157+
)
158+
159+
async def build_sandbox_info_from_redis(self, sandbox_id: str) -> SandboxInfo | None:
160+
if self._redis_provider:
161+
sandbox_status = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$")
162+
if sandbox_status and len(sandbox_status) > 0:
163+
return sandbox_status[0]
164+
return None
144165

145166
async def create_session(self, request: CreateSessionRequest) -> CreateBashSessionResponse:
146167
sandbox_actor = await self._deployment_service.async_ray_get_actor(request.sandbox_id)

rock/sandbox/service/deployment_service.py

Lines changed: 7 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
from rock.sandbox.sandbox_actor import SandboxActor
1515
from rock.sdk.common.exceptions import BadRequestRockError
1616
from rock.utils.format import parse_memory_size
17-
from rock.rocklet import __version__ as swe_version
18-
from rock.sandbox import __version__ as gateway_version
1917

2018
logger = init_logger(__name__)
2119

@@ -26,12 +24,12 @@ async def get_deployment(self, sandbox_id: str) -> AbstractDeployment:
2624
...
2725

2826
@abstractmethod
29-
async def submit(self, config: DeploymentConfig, user_info: dict) -> SandboxStartResponse:
27+
async def submit(self, config: DeploymentConfig, user_info: dict) -> SandboxInfo:
3028
"""Get status of sandbox."""
3129
...
3230

3331
@abstractmethod
34-
async def get_status(self, *args, **kwargs) -> SandboxStatusResponse:
32+
async def get_status(self, *args, **kwargs) -> SandboxInfo:
3533
"""Get status of sandbox."""
3634
...
3735

@@ -129,30 +127,17 @@ async def stop(self, sandbox_id: str):
129127
logger.info(f"run time stop over {sandbox_id}")
130128
ray.kill(actor)
131129

132-
async def get_status(self, sandbox_id: str) -> SandboxStatusResponse:
130+
async def get_status(self, sandbox_id: str) -> SandboxInfo:
133131
actor: SandboxActor = await self.async_ray_get_actor(sandbox_id)
134132
sandbox_info: SandboxInfo = await self.async_ray_get(actor.sandbox_info.remote())
135133
remote_status: ServiceStatus = await self.async_ray_get(actor.get_status.remote())
134+
sandbox_info["status"] = remote_status.phases
135+
sandbox_info["port_mapping"] = remote_status.get_port_mapping()
136136
alive = await self.async_ray_get(actor.is_alive.remote())
137+
sandbox_info["alive"] = alive.is_alive
137138
if alive.is_alive:
138139
sandbox_info["state"] = State.RUNNING
139-
return SandboxStatusResponse(
140-
sandbox_id=sandbox_id,
141-
status=remote_status.phases,
142-
port_mapping=remote_status.get_port_mapping(),
143-
host_name=sandbox_info.get("host_name"),
144-
host_ip=sandbox_info.get("host_ip"),
145-
is_alive=alive.is_alive,
146-
image=sandbox_info.get("image"),
147-
swe_rex_version=swe_version,
148-
gateway_version=gateway_version,
149-
user_id=sandbox_info.get("user_id"),
150-
experiment_id=sandbox_info.get("experiment_id"),
151-
namespace=sandbox_info.get("namespace"),
152-
cpus=sandbox_info.get("cpus"),
153-
memory=sandbox_info.get("memory"),
154-
state=sandbox_info.get("state"),
155-
)
140+
return sandbox_info
156141

157142
async def get_mount(self, sandbox_id: str):
158143
actor = await self.async_ray_get_actor(sandbox_id)

0 commit comments

Comments
 (0)