diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index de21b3e9e..38889656d 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -248,7 +248,7 @@ async def setup(self): # Setup scheduler # TODO: Add support for `log_stats` kv_cache_configs = await self.policy_worker.setup_kv_cache.call() - kv_cache_config = kv_cache_configs._values[0] + _, kv_cache_config = next(kv_cache_configs.items()) self.vllm_config.cache_config.num_gpu_blocks = kv_cache_config.num_blocks self.vllm_config.cache_config.num_cpu_blocks = 0 @@ -351,15 +351,14 @@ def preprocess_add_request(self, request: EngineCoreRequest) -> tuple[Request, i async def run(self): # TODO: add support for `iteration_stats` # TODO: move postprocessing out of loop to not block - parallel_config = self.vllm_config.parallel_config - output_rank = parallel_config.world_size - parallel_config.tensor_parallel_size self.running = True while self.running: scheduler_output = self.scheduler.schedule() worker_outputs = await self.policy_worker.execute_model.call( scheduler_output ) - worker_output = worker_outputs._values[output_rank] + # the results of `execute_model` are gathered on the driver rank (rank 0) + _, worker_output = next(worker_outputs.items()) outputs = self.scheduler.update_from_output(scheduler_output, worker_output) outputs = outputs.get(0) or EngineCoreOutputs() await asyncio.sleep(0) # Release control before processing outputs diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index b84e5eec7..85df718df 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -242,12 +242,9 @@ async def _process_single_request(self, request: ServiceRequest) -> bool: try: result = await endpoint_func.call(*request.args, **request.kwargs) # Unwrap ValueMesh if configured to return first rank result - if ( - self.return_first_rank_result - and hasattr(result, "_values") - and result._values - ): - 
result = result._values[0] + if self.return_first_rank_result: + _, first_result = next(result.items()) + result = first_result request.future.set_result(result) except ActorError as e: logger.warning(f"Got failure on replica {self.idx}. Error:\n{e}")