Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ async def setup(self):
# Setup scheduler
# TODO: Add support for `log_stats`
kv_cache_configs = await self.policy_worker.setup_kv_cache.call()
kv_cache_config = kv_cache_configs._values[0]
_, kv_cache_config = next(kv_cache_configs.items())
self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks
self.vllm_config.cache_config.num_cpu_blocks = 0

Expand Down Expand Up @@ -351,15 +351,14 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i
async def run(self):
# TODO: add support for `iteration_stats`
# TODO: move postprocessing out of loop to not block
parallel_config = self.vllm_config.parallel_config
output_rank = parallel_config.world_size - parallel_config.tensor_parallel_size
self.running = True
while self.running:
scheduler_output = self.scheduler.schedule()
worker_outputs = await self.policy_worker.execute_model.call(
scheduler_output
)
worker_output = worker_outputs._values[output_rank]
# the results of `execute_model` is gathered on the driver rank (rank 0)
_, worker_output = next(worker_outputs.items())
outputs = self.scheduler.update_from_output(scheduler_output, worker_output)
outputs = outputs.get(0) or EngineCoreOutputs()
await asyncio.sleep(0) # Release control before processing outputs
Expand Down
9 changes: 3 additions & 6 deletions src/forge/controller/service/replica.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,12 +242,9 @@ async def _process_single_request(self, request: ServiceRequest) -> bool:
try:
result = await endpoint_func.call(*request.args, **request.kwargs)
# Unwrap ValueMesh if configured to return first rank result
if (
self.return_first_rank_result
and hasattr(result, "_values")
and result._values
):
result = result._values[0]
if self.return_first_rank_result:
_, first_result = next(result.items())
result = first_result
request.future.set_result(result)
except ActorError as e:
logger.warning(f"Got failure on replica {self.idx}. Error:\n{e}")
Expand Down
Loading