55import random
66import time
77from dataclasses import asdict , dataclass
8- from typing import Any , Dict , Iterable , List , Tuple
8+ from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Tuple , cast
99
10- import ray
1110import grpc
11+ import ray
12+ from envd .generated import tools_pb2 as tools_pb2_module
13+ from envd .generated import tools_pb2_grpc as tools_pb2_grpc_module
1214from google .protobuf .json_format import MessageToDict
1315
14- from envd .generated import tools_pb2 , tools_pb2_grpc
1516from .config import SamplerConfig , StorageConfig
1617from .model import create_sampler , ensure_json_plan
1718from .storage import create_storage , write_rollout_records
1819
20+ tools_pb2 = cast (Any , tools_pb2_module )
21+ tools_pb2_grpc = cast (Any , tools_pb2_grpc_module )
22+ ray = cast (Any , ray )
23+
24+ if TYPE_CHECKING :
25+ from ray .actor import ActorHandle
26+ else : # pragma: no cover - typing fallback
27+ ActorHandle = Any
28+
1929
2030@dataclass (slots = True )
2131class ToolSpec :
@@ -39,15 +49,14 @@ def build_request(tool: str, args: Dict[str, Any]) -> Any:
3949 return spec .request_cls (** args )
4050
4151
42- def stub_method (stub : tools_pb2_grpc . ToolsStub , tool : str ):
52+ def stub_method (stub : Any , tool : str ):
4353 spec = TOOL_SPECS .get (tool )
4454 if not spec :
4555 raise ValueError (f"unknown tool: { tool } " )
4656 return getattr (stub , spec .attr )
4757
4858
49- @ray .remote (num_cpus = 0.05 )
50- class EnvClient :
59+ class EnvClientImpl :
5160 def __init__ (self , host : str ):
5261 self ._host = host
5362 self ._channel = grpc .aio .insecure_channel (host )
@@ -58,9 +67,9 @@ async def call(self, tool: str, args: Dict[str, Any], *, timeout: float | None =
5867 rpc = stub_method (self ._stub , tool )
5968 try :
6069 response = await rpc (request , timeout = timeout )
61- payload = MessageToDict (response , preserving_proto_field_name = True , including_default_value_fields = True )
70+ payload = MessageToDict (response , preserving_proto_field_name = True )
6271 return {"tool" : tool , "ok" : True , "response" : payload }
63- except grpc .aio .AioRpcError as exc : # noqa: D
72+ except grpc .aio .AioRpcError as exc : # noqa: D401
6473 return {
6574 "tool" : tool ,
6675 "ok" : False ,
@@ -71,9 +80,12 @@ async def call(self, tool: str, args: Dict[str, Any], *, timeout: float | None =
7180 async def close (self ) -> None :
7281 await self ._channel .close ()
7382
83+ @classmethod
84+ def remote (cls , host : str ) -> Any : # pragma: no cover - typing helper
85+ return cls (host )
7486
75- @ ray . remote ( num_cpus = 0.1 )
76- class ModelSampler :
87+
88+ class ModelSamplerImpl :
7789 def __init__ (self , cfg_dict : Dict [str , Any ]):
7890 self ._cfg = SamplerConfig (** cfg_dict )
7991 self ._backend = create_sampler (self ._cfg )
@@ -83,9 +95,12 @@ async def sample(self, prompt: str) -> str:
8395 normalized = await ensure_json_plan (plan )
8496 return normalized
8597
98+ @classmethod
99+ def remote (cls , cfg_dict : Dict [str , Any ]) -> Any : # pragma: no cover - typing helper
100+ return cls (cfg_dict )
101+
86102
87- @ray .remote (num_cpus = 0.2 )
88- class Sampler :
103+ class SamplerImpl :
89104 def __init__ (self , env_hosts : Iterable [str ], model_ref ):
90105 self ._envs = [EnvClient .remote (host ) for host in env_hosts ]
91106 self ._model = model_ref
@@ -130,10 +145,13 @@ async def rollout(self, prompt: str, *, budget_s: float = 45.0, group_timeout_s:
130145 latency = time .perf_counter () - started
131146 return {"trajectory" : trajectory , "latency_s" : latency }
132147
148+ @classmethod
149+ def remote (cls , env_hosts : Iterable [str ], model_ref : Any ) -> Any : # pragma: no cover - typing helper
150+ return cls (env_hosts , model_ref )
133151
134- @ ray . remote ( num_cpus = 0.05 )
135- class Controller :
136- def __init__ (self , sampler_refs : Iterable [ray . actor . ActorHandle ], storage_cfg : Dict [str , Any ]):
152+
153+ class ControllerImpl :
154+ def __init__ (self , sampler_refs : Iterable [ActorHandle ], storage_cfg : Dict [str , Any ]):
137155 self ._samplers = list (sampler_refs )
138156 self ._storage = create_storage (StorageConfig (** storage_cfg ))
139157
@@ -169,6 +187,22 @@ async def batch_rollouts(self, prompts: List[str], *, replicas: int = 3, timeout
169187 records = await write_rollout_records (self ._storage , pairs )
170188 return records
171189
190+ @classmethod
191+ def remote (cls , sampler_refs : Iterable [ActorHandle ], storage_cfg : Dict [str , Any ]) -> Any : # pragma: no cover
192+ return cls (sampler_refs , storage_cfg )
193+
194+
195+ if TYPE_CHECKING :
196+ EnvClient = EnvClientImpl
197+ ModelSampler = ModelSamplerImpl
198+ Sampler = SamplerImpl
199+ Controller = ControllerImpl
200+ else : # pragma: no cover - runtime actor binding
201+ EnvClient = ray .remote (num_cpus = 0.05 )(EnvClientImpl )
202+ ModelSampler = ray .remote (num_cpus = 0.1 )(ModelSamplerImpl )
203+ Sampler = ray .remote (num_cpus = 0.2 )(SamplerImpl )
204+ Controller = ray .remote (num_cpus = 0.05 )(ControllerImpl )
205+
172206
173207def bootstrap_ray (env_hosts : List [str ], * , num_samplers : int = 2 ):
174208 if not ray .is_initialized ():
0 commit comments