diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 15642272c..3e7b25b7d 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -278,16 +278,16 @@ async def main(cfg: DictConfig):
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
-        pad_id = await dataloader.pad_token.choose()
+        pad_id = await dataloader.pad_token.route()
         while True:
-            sample = await dataloader.sample.choose()
+            sample = await dataloader.sample.route()
             if sample is None:
                 print("Dataloader is empty, exiting continuous rollout")
                 return
             prompt, target = sample["request"], sample["target"]
-            responses = await policy.generate.choose(prompt)
+            responses = await policy.generate.route(prompt)
             # TODO: this shall be part of the responses metadata instead of a separate call
-            version = await policy.get_version.choose()
+            version = await policy.get_version.route()
             group = Group.new_group(
                 group_id=rollout_count,
                 group_size=group_size,
@@ -311,29 +311,29 @@ async def continuous_rollouts():
                 episode.response = response.text
                 input_ids[i, :max_req_tokens] = episode.request_tensor
                 input_ids[i, max_req_tokens:] = episode.response_tensor
-                episode.reward = await reward_actor.evaluate_response.choose(
+                episode.reward = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )

             # Calculate reference logprobs
-            ref_logits = await ref_model.forward.choose(input_ids)
+            ref_logits = await ref_model.forward.route(input_ids)
             ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
             for i, episode in enumerate(group.episodes):
                 episode.ref_logprobs = ref_logprobs[i]
             del ref_logits, ref_logprobs, input_ids

             # Calculate advantages and add to replay buffer
-            advantages = await compute_advantages.compute.choose(group)
+            advantages = await compute_advantages.compute.route(group)
             for episode, advantage in zip(group.episodes, advantages):
                 episode.advantage = advantage
-                await replay_buffer.add.choose(episode)
+                await replay_buffer.add.route(episode)

             # Log metrics
             avg_response_len = (
                 sum(len(e.response_tokens) for e in group.episodes) / group_size
             )
             mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
-            buffer_size = await replay_buffer._numel.choose()
+            buffer_size = await replay_buffer._numel.route()
             mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
             avg_reward = sum(e.reward for e in group.episodes) / group_size
             mlogger.log("avg_reward/rollout", avg_reward, rollout_count)
@@ -343,16 +343,16 @@ async def continuous_rollouts():
     async def continuous_training():
         training_step = 0
         while True:
-            batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
+            batch = await replay_buffer.sample.route(curr_policy_version=training_step)
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
                 inputs, targets = batch
-                loss = await trainer.train_step.choose(inputs, targets)
+                loss = await trainer.train_step.route(inputs, targets)
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
-                await trainer.push_weights.call(training_step)
-                await policy.update_weights.call(training_step)
+                await trainer.push_weights.fanout(training_step)
+                await policy.update_weights.fanout(training_step)

     print("Starting GRPO training loops...")
     # TODO: Start multiple rollouts once all serivces support it
diff --git a/apps/rl/main.py b/apps/rl/main.py
index 9f7314f53..49084b50b 100644
--- a/apps/rl/main.py
+++ b/apps/rl/main.py
@@ -161,11 +161,11 @@ async def run(cfg: DictConfig):
             ref_logprobs=torch.randn((256,), generator=g),
             advantage=torch.randn((1,), generator=g),
         )
-        await replay_buffer.add.choose(e)
+        await replay_buffer.add.route(e)

     print("Train step...")
-    inputs, targets = await replay_buffer.sample.choose(curr_policy_version=0)
-    outputs = await trainer.train_step.choose(inputs, targets)
+    inputs, targets = await replay_buffer.sample.route(curr_policy_version=0)
+    outputs = await trainer.train_step.route(inputs, targets)
     print("Loss: ", outputs["loss"])

     print("Shutting down...")
diff --git a/apps/sft_v2/main.py b/apps/sft_v2/main.py
index 773effdfd..a2f4f1479 100644
--- a/apps/sft_v2/main.py
+++ b/apps/sft_v2/main.py
@@ -289,13 +289,13 @@ async def run(cfg: DictConfig) -> None:
     )

     logging.info("Created recipe, running setup.")
-    await recipe.setup.call()
+    await recipe.setup.fanout()

     logging.info("Recipe has been setup. Training now.")
-    await recipe.train.call()
+    await recipe.train.fanout()

     logging.info("Done training. Clean up")
-    await recipe.cleanup.call()
+    await recipe.cleanup.fanout()
     await recipe.mesh.stop()

     logging.info("All done!")
diff --git a/apps/toy_rl/sumdigits.py b/apps/toy_rl/sumdigits.py
index 4ef0833c2..4347f2549 100644
--- a/apps/toy_rl/sumdigits.py
+++ b/apps/toy_rl/sumdigits.py
@@ -464,16 +464,16 @@ async def main(cfg: DictConfig):
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
-        pad_id = await dataloader.pad_token.choose()
+        pad_id = await dataloader.pad_token.route()
         while True:
             # Pass rollout_count for curriculum learning
-            sample = await dataloader.sample.choose(rollout_count)
+            sample = await dataloader.sample.route(rollout_count)
             if sample is None:
                 print("Dataloader is empty, exiting continuous rollout")
                 return
             prompt, target = sample["request"], sample["target"]
-            responses = await policy.generate.choose(prompt)
-            version = await policy.get_version.choose()
+            responses = await policy.generate.route(prompt)
+            version = await policy.get_version.route()
             group = Group.new_group(
                 group_id=rollout_count,
                 group_size=group_size,
@@ -491,13 +491,13 @@
                 episode.response_tokens = response.token_ids
                 episode.response = response.text
                 episode.response_logprobs = response.logprobs
-                episode.ref_logprobs = await ref_model.forward.choose(episode)
-                episode.reward = await reward_actor.evaluate_response.choose(
+                episode.ref_logprobs = await ref_model.forward.route(episode)
+                episode.reward = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
                 episode.advantage = episode.reward  # simple case for now
             for episode in group.episodes:
-                await replay_buffer.add.choose(episode)
+                await replay_buffer.add.route(episode)
             avg_response_len = (
                 sum(len(e.response_tokens) for e in group.episodes) / group_size
             )
@@ -510,18 +510,18 @@
     async def continuous_training():
         training_step = 0
         while True:
-            batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
+            batch = await replay_buffer.sample.route(curr_policy_version=training_step)
             if batch is None:
                 await asyncio.sleep(0.1)
             else:
-                loss = await trainer.train_step.choose(batch[0])
+                loss = await trainer.train_step.route(batch[0])
                 training_step += 1
                 mlogger.log("loss/training_step", loss, training_step)
                 print(f"loss/training_step: {loss} at training step {training_step}")
-                await trainer.push_weights.call(training_step)
-                await policy.update_weights.call(training_step)
+                await trainer.push_weights.fanout(training_step)
+                await policy.update_weights.fanout(training_step)
                 # NOTE: hard-coded to be on-policy for faster convergence
-                await replay_buffer.clear.call()
+                await replay_buffer.clear.fanout()

     print("Starting training loop.")
     # TODO: Start multiple rollouts once all serivces support it
diff --git a/apps/vllm/main.py b/apps/vllm/main.py
index 6ba1bbbaf..6ba4c76c0 100644
--- a/apps/vllm/main.py
+++ b/apps/vllm/main.py
@@ -39,7 +39,7 @@ async def run(cfg: DictConfig):
     n = 100
     start = time.time()
     response_outputs: list[Completion] = await asyncio.gather(
-        *[policy.generate.choose(prompt=prompt) for _ in range(n)]
+        *[policy.generate.route(prompt=prompt) for _ in range(n)]
     )
     end = time.time()

diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py
index a70ec8ad2..eb6f0ddd7 100644
--- a/src/forge/controller/service/interface.py
+++ b/src/forge/controller/service/interface.py
@@ -78,29 +78,25 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):


 class ServiceEndpoint(Generic[P, R]):
-    """An endpoint object specific to services.
-
-    This loosely mimics the Endpoint APIs exposed in Monarch, with
-    a few key differences:
-    - Only choose and call are retained (dropping stream and call_one)
-    - Call returns a list directly rather than a ValueMesh.
-
-    These changes are made with Forge use cases in mind, but can
-    certainly be expanded/adapted in the future.
+    """
+    This extends Monarch's actor APIs for service endpoints.
+    - `route(*args, **kwargs)`: Routes the request to a single replica.
+    - `fanout(*args, **kwargs)`: Broadcasts the request to all healthy replicas.
+    Monarch's native actor APIs do not apply for services.
     """

     def __init__(self, service, endpoint_name: str):
         self.service = service
         self.endpoint_name = endpoint_name

-    async def choose(self, *args: P.args, **kwargs: P.kwargs) -> R:
+    async def route(self, *args: P.args, **kwargs: P.kwargs) -> R:
         """Chooses a replica to call based on context and load balancing strategy."""
         # Extract sess_id from kwargs if present
         sess_id = kwargs.pop("sess_id", None)
         return await self.service._call(sess_id, self.endpoint_name, *args, **kwargs)

-    async def call(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
+    async def fanout(self, *args: P.args, **kwargs: P.kwargs) -> List[R]:
         """Broadcasts a request to all healthy replicas and returns the results as a list."""
         result = await self.service.call_all(self.endpoint_name, *args, **kwargs)
         return result
diff --git a/src/forge/controller/service/service.py b/src/forge/controller/service/service.py
index 66c9c234e..0b655fb6a 100644
--- a/src/forge/controller/service/service.py
+++ b/src/forge/controller/service/service.py
@@ -75,7 +75,6 @@ class Service:
         _replicas: List of managed replica instances
         _active_sessions: Currently active sessions
         _metrics: Aggregated service and replica metrics
-        _endpoints: Dynamically registered actor endpoints
     """

     def __init__(
@@ -612,7 +611,6 @@ class ServiceActor(Actor):
         _replicas: List of managed replica instances
         _active_sessions: Currently active sessions
         _metrics: Aggregated service and replica metrics
-        _endpoints: Dynamically registered actor endpoints
     """

     def __init__(self, cfg: ServiceConfig, actor_def, actor_kwargs: dict):
diff --git a/tests/integration_tests/test_policy_update.py b/tests/integration_tests/test_policy_update.py
index d265df206..c737078b7 100644
--- a/tests/integration_tests/test_policy_update.py
+++ b/tests/integration_tests/test_policy_update.py
@@ -257,7 +257,7 @@ async def test_policy_update_single(self, expected_sd, trainer_cfg):
             procs=worker_size, with_gpus=True, num_replicas=1
         ).as_service(**trainer_cfg)
-        await rl_trainer.push_weights.choose(policy_version=0)
+        await rl_trainer.push_weights.route(policy_version=0)

         # 3. Policy pull weights
         policy_config, service_config = get_configs(
             worker_size=worker_size, tp_size=worker_size, model_name=self.model
@@ -265,9 +265,9 @@
         policy = await Policy.options(**asdict(service_config)).as_service(
             **policy_config
         )
-        await policy.update_weights.call()
+        await policy.update_weights.fanout()
         # 4. Validate weights
-        loaded_state_dict = await policy._get_model_params.choose()
+        loaded_state_dict = await policy._get_model_params.route()
         validate_loaded_tensors_equals_original(
             loaded_state_dict, expected_sd, tensor_parallel_size=1, rank=0
         )
@@ -297,7 +297,7 @@ async def test_policy_update_tp(self, expected_sd, trainer_cfg_tp):
             procs=trainer_worker_size, with_gpus=True, num_replicas=1
         ).as_service(**trainer_cfg_tp)
-        await rl_trainer.push_weights.call(policy_version=0)
+        await rl_trainer.push_weights.fanout(policy_version=0)

         # 3. Policy pull weights
         policy_config, service_config = get_configs(
             worker_size=policy_worker_size, tp_size=tp_size, model_name=self.model
@@ -305,13 +305,13 @@
         policy = await Policy.options(**asdict(service_config)).as_service(
             **policy_config
         )
-        await policy.update_weights.call()
+        await policy.update_weights.fanout()

         # validate loaded shard of each worker againt manually calculated shard (expected shard).

         # 4. Validate weight shards. We compare vLLM loades shard content with
         # Directly loaded HF shard content.
-        sharded_state_dicts = await policy._get_model_params.call()
+        sharded_state_dicts = await policy._get_model_params.fanout()
         validate_loaded_tensors_equals_original(
             sharded_state_dicts[0][0], expected_sd, tensor_parallel_size=tp_size, rank=0
         )
diff --git a/tests/unit_tests/test_replay_buffer.py b/tests/unit_tests/test_replay_buffer.py
index 4463c3f2c..f6c5bc574 100644
--- a/tests/unit_tests/test_replay_buffer.py
+++ b/tests/unit_tests/test_replay_buffer.py
@@ -11,15 +11,12 @@
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.types import Trajectory

-from monarch.actor import proc_mesh
-

 class TestReplayBuffer:
     @pytest_asyncio.fixture
     async def replay_buffer(self) -> ReplayBuffer:
-        mesh = await proc_mesh(gpus=1)
-        replay_buffer = await mesh.spawn(
-            "replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1
+        replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=True).as_actor(
+            batch_size=2, max_policy_age=1
         )
         await replay_buffer.setup.call()
         return replay_buffer
@@ -112,10 +109,9 @@ async def test_sample_with_evictions(self, replay_buffer) -> None:
     @pytest.mark.asyncio
     async def test_sample_dp_size(self) -> None:
         """Test that len(samples) == dp_size when sampling."""
-        mesh = await proc_mesh(gpus=1)
         # Create replay buffer with dp_size=3
-        replay_buffer = await mesh.spawn(
-            "replay_buffer", ReplayBuffer, batch_size=2, max_policy_age=1, dp_size=3
+        replay_buffer = await ReplayBuffer.options(procs=1, with_gpus=True).as_actor(
+            batch_size=2, max_policy_age=1, dp_size=3
         )
         await replay_buffer.setup.call()

diff --git a/tests/unit_tests/test_service.py b/tests/unit_tests/test_service.py
index 968f1eea2..612e4ff42 100644
--- a/tests/unit_tests/test_service.py
+++ b/tests/unit_tests/test_service.py
@@ -163,7 +163,7 @@ async def test_service_with_kwargs_config():
         assert cfg.num_replicas == 4
         assert cfg.procs == 1
         assert cfg.health_poll_rate == 0.5
-        assert await service.value.choose() == 20
+        assert await service.value.route() == 20

     finally:
         await service.shutdown()
@@ -177,7 +177,7 @@ async def test_service_default_config():
         cfg = service._service._cfg
         assert cfg.num_replicas == 1
         assert cfg.procs == 1
-        assert await service.value.choose() == 10
+        assert await service.value.route() == 10

     finally:
         await service.shutdown()
@@ -203,8 +203,8 @@ async def test_multiple_services_isolated_configs():
         assert cfg1 is not cfg2  # configs should not be the same object

         # Check actor values
-        val1 = await service1.value.choose()
-        val2 = await service2.value.choose()
+        val1 = await service1.value.route()
+        val2 = await service2.value.route()
         assert val1 == 10
         assert val2 == 20

@@ -231,8 +231,8 @@ async def test_basic_service_operations():
         assert isinstance(session1, str)

         # Test endpoint calls
-        await service.incr.choose(sess_id=session1)
-        result = await service.value.choose(sess_id=session1)
+        await service.incr.route(sess_id=session1)
+        result = await service.value.route(sess_id=session1)
         assert result == 1

         # Test session mapping
@@ -255,9 +255,9 @@ async def test_sessionless_calls():
     service = await Counter.options(procs=1, num_replicas=2).as_service(v=0)
     try:
         # Test sessionless calls
-        await service.incr.choose()
-        await service.incr.choose()
-        result = await service.value.choose()
+        await service.incr.route()
+        await service.incr.route()
+        result = await service.value.route()
         assert result is not None

         # No sessions should be created
@@ -274,7 +274,7 @@
         assert total_requests == 3

         # Users should be able to call endpoint with just args
-        result = await service.add_to_value.choose(5, multiplier=2)
+        result = await service.add_to_value.route(5, multiplier=2)
         assert result == 11  # 1 + 10

     finally:
@@ -289,18 +289,18 @@ async def test_session_context_manager():
     try:
         # Test context manager usage
         async with service.session():
-            await service.incr.choose()
-            await service.incr.choose()
-            result = await service.value.choose()
+            await service.incr.route()
+            await service.incr.route()
+            result = await service.value.route()
             assert result == 2

         # Test sequential context managers to avoid interference
         async def worker(increments: int):
             async with service.session():
-                initial = await service.value.choose()
+                initial = await service.value.route()
                 for _ in range(increments):
-                    await service.incr.choose()
-                final = await service.value.choose()
+                    await service.incr.route()
+                final = await service.value.route()
                 return final - initial

         # Run sessions sequentially to avoid concurrent modification
@@ -339,12 +339,12 @@ async def test_recovery_state_transitions():

         # Create session and make a successful call
         session = await service.start_session()
-        await service.incr.choose(sess_id=session)
-        result = await service.value.choose(sess_id=session)
+        await service.incr.route(sess_id=session)
+        result = await service.value.route(sess_id=session)
         assert result == 1

         # Cause failure - this should transition to RECOVERING
-        error_result = await service.fail_me.choose(sess_id=session)
+        error_result = await service.fail_me.route(sess_id=session)
         assert isinstance(error_result, RuntimeError)

         # Replica should now be in RECOVERING state
@@ -383,8 +383,8 @@

         # Test that we can make new calls after recovery
         new_session = await service.start_session()
-        await service.incr.choose(sess_id=new_session)
-        result = await service.value.choose(sess_id=new_session)
+        await service.incr.route(sess_id=new_session)
+        result = await service.value.route(sess_id=new_session)
         assert (
             result is not None
         )  # Should get a result (counter starts at 0 in new actor)
@@ -412,13 +412,13 @@ async def test_replica_failure_and_recovery():
     try:
         # Create session and cause failure
         session = await service.start_session()
-        await service.incr.choose(sess_id=session)
+        await service.incr.route(sess_id=session)

         state = await service._get_internal_state()
         original_replica_idx = state["session_replica_map"][session]

         # Cause failure
-        error_result = await service.fail_me.choose(sess_id=session)
+        error_result = await service.fail_me.route(sess_id=session)
         assert isinstance(error_result, RuntimeError)

         # Replica should be marked as failed
@@ -427,14 +427,14 @@
         assert not failed_replica["healthy"]

         # Session should be reassigned on next call
-        await service.incr.choose(sess_id=session)
+        await service.incr.route(sess_id=session)
         state = await service._get_internal_state()
         new_replica_idx = state["session_replica_map"][session]
         assert new_replica_idx != original_replica_idx

         # New sessions should avoid failed replica
         new_session = await service.start_session()
-        await service.incr.choose(sess_id=new_session)
+        await service.incr.route(sess_id=new_session)
         state = await service._get_internal_state()
         assigned_replica = state["replicas"][state["session_replica_map"][new_session]]
         assert assigned_replica["healthy"]
@@ -457,12 +457,12 @@ async def test_metrics_collection():

         session1 = await service.start_session()
         session2 = await service.start_session()
-        await service.incr.choose(sess_id=session1)
-        await service.incr.choose(sess_id=session1)
-        await service.incr.choose(sess_id=session2)
+        await service.incr.route(sess_id=session1)
+        await service.incr.route(sess_id=session1)
+        await service.incr.route(sess_id=session2)

         # Test failure metrics
-        error_result = await service.fail_me.choose(sess_id=session1)
+        error_result = await service.fail_me.route(sess_id=session1)
         assert isinstance(error_result, RuntimeError)

         # Get metrics
@@ -508,20 +508,20 @@ async def test_session_stickiness():
         session = await service.start_session()

         # Make multiple calls
-        await service.incr.choose(sess_id=session)
-        await service.incr.choose(sess_id=session)
-        await service.incr.choose(sess_id=session)
+        await service.incr.route(sess_id=session)
+        await service.incr.route(sess_id=session)
+        await service.incr.route(sess_id=session)

         # Should always route to same replica
         state = await service._get_internal_state()
         replica_idx = state["session_replica_map"][session]

-        await service.incr.choose(sess_id=session)
+        await service.incr.route(sess_id=session)
         state = await service._get_internal_state()
         assert state["session_replica_map"][session] == replica_idx

         # Verify counter was incremented correctly
-        result = await service.value.choose(sess_id=session)
+        result = await service.value.route(sess_id=session)
         assert result == 4

     finally:
@@ -537,20 +537,20 @@ async def test_load_balancing_multiple_sessions():
     try:
         # Create sessions with some load to trigger distribution
         session1 = await service.start_session()
-        await service.incr.choose(sess_id=session1)  # Load replica 0
+        await service.incr.route(sess_id=session1)  # Load replica 0

         session2 = await service.start_session()
-        await service.incr.choose(
+        await service.incr.route(
             sess_id=session2
         )  # Should go to replica 1 (least loaded)

         session3 = await service.start_session()
-        await service.incr.choose(
+        await service.incr.route(
             sess_id=session3
         )  # Should go to replica 0 or 1 based on load

         session4 = await service.start_session()
-        await service.incr.choose(sess_id=session4)  # Should balance the load
+        await service.incr.route(sess_id=session4)  # Should balance the load

         # Check that sessions are distributed (may not be perfectly even due to least-loaded logic)
         state = await service._get_internal_state()
@@ -588,10 +588,10 @@

         # Concurrent operations
         tasks = [
-            service.incr.choose(sess_id=session),  # Session call
-            service.incr.choose(sess_id=session),  # Session call
-            service.incr.choose(),  # Sessionless call
-            service.incr.choose(),  # Sessionless call
+            service.incr.route(sess_id=session),  # Session call
+            service.incr.route(sess_id=session),  # Session call
+            service.incr.route(),  # Sessionless call
+            service.incr.route(),  # Sessionless call
         ]

         await asyncio.gather(*tasks)
@@ -624,7 +624,7 @@

     try:
         # Test broadcast call to all replicas
-        results = await service.incr.call()
+        results = await service.incr.fanout()

         # Should get results from all healthy replicas
         assert isinstance(results, list)
@@ -634,7 +634,7 @@
         assert all(result is None for result in results)

         # Test getting values from all replicas
-        values = await service.value.call()
+        values = await service.value.fanout()
         assert isinstance(values, list)
         assert len(values) == 3

@@ -655,7 +655,7 @@
         # First, cause one replica to fail by calling fail_me on a specific session
         session = await service.start_session()
         try:
-            await service.fail_me.choose(sess_id=session)
+            await service.fail_me.route(sess_id=session)
         except RuntimeError:
             pass  # Expected failure

@@ -663,7 +663,7 @@
         await asyncio.sleep(0.1)

         # Now test broadcast call - should only hit healthy replicas
-        results = await service.incr.call()
+        results = await service.incr.fanout()

         # Should get results from healthy replicas only
         assert isinstance(results, list)
@@ -673,7 +673,7 @@
         assert len(results) == healthy_count

         # Get values from all healthy replicas
-        values = await service.value.call()
+        values = await service.value.fanout()
         assert len(values) == healthy_count

         # All healthy replicas should have incremented to 1
@@ -685,26 +685,26 @@


 @pytest.mark.timeout(10)
 @pytest.mark.asyncio
-async def test_broadcast_call_vs_choose():
-    """Test that broadcast call hits all replicas while choose hits only one."""
+async def test_broadcast_fanout_vs_route():
+    """Test that broadcast fanout hits all replicas while route hits only one."""
     service = await Counter.options(procs=1, num_replicas=3).as_service(v=0)
     try:
         # Use broadcast call to increment all replicas
-        await service.incr.call()
+        await service.incr.fanout()

         # Get values from all replicas
-        values_after_broadcast = await service.value.call()
+        values_after_broadcast = await service.value.fanout()
         assert len(values_after_broadcast) == 3
         assert all(value == 1 for value in values_after_broadcast)

-        # Use choose to increment only one replica
-        await service.incr.choose()
+        # Use route to increment only one replica
+        await service.incr.route()

         # Get values again - one replica should be at 2, others at 1
-        values_after_choose = await service.value.call()
-        assert len(values_after_choose) == 3
-        assert sorted(values_after_choose) == [1, 1, 2]  # One replica incremented twice
+        values_after_route = await service.value.fanout()
+        assert len(values_after_route) == 3
+        assert sorted(values_after_route) == [1, 1, 2]  # One replica incremented twice

         # Verify metrics show the correct number of requests
         metrics = await service.get_metrics_summary()
@@ -712,7 +712,8 @@
             replica_metrics["total_requests"]
             for replica_metrics in metrics["replicas"].values()
         )
-        # incr.call() (3 requests) + value.call() (3 requests) + incr.choose() (1 request) + value.call() (3 requests) = 10 total
+        # incr.fanout() (3 requests) + value.fanout() (3 requests) +
+        # incr.route() (1 request) + value.fanout() (3 requests) = 10 total
         assert total_requests == 10

     finally:
@@ -759,11 +760,11 @@ async def test_round_robin_router_distribution():

     service = await Counter.options(procs=1, num_replicas=3).as_service(v=0)
     try:
-        # Make multiple sessionless calls using choose()
+        # Make multiple sessionless calls using route()
         results = []
         for _ in range(6):
-            await service.incr.choose()
-            values = await service.value.call()
+            await service.incr.route()
+            values = await service.value.fanout()
             print(values)
             results.append(values)
         print("results: ", results)
@@ -789,12 +790,12 @@ async def test_session_router_assigns_and_updates_session_map_in_service():

     try:
         # First call with sess_id -> assign a replica
-        await service.incr.choose(sess_id="sess1")
-        values1 = await service.value.call()
+        await service.incr.route(sess_id="sess1")
+        values1 = await service.value.fanout()

         # Second call with same sess_id -> must hit same replica
-        await service.incr.choose(sess_id="sess1")
-        values2 = await service.value.call()
+        await service.incr.route(sess_id="sess1")
+        values2 = await service.value.fanout()

         # Difference should only be on one replica (sticky session)
         diffs = [v2 - v1 for v1, v2 in zip(values1, values2)]
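
Usage sketch (appended for reviewers; not part of the patch). The rename is mechanical: choose() becomes route(), which load-balances a request onto a single replica, and call() becomes fanout(), which broadcasts to every healthy replica and returns a list of results. The snippet below exercises the new endpoint API against the Counter actor used in tests/unit_tests/test_service.py; the import path and the asyncio entry point are illustrative assumptions, not code taken from this diff.

import asyncio

# Assumption: Counter is the simple counter actor defined in tests/unit_tests/test_service.py.
from tests.unit_tests.test_service import Counter


async def demo() -> None:
    # Spin up a service with three replicas of the Counter actor.
    service = await Counter.options(procs=1, num_replicas=3).as_service(v=0)
    try:
        # route(): pick one replica via the load balancer (previously choose()).
        await service.incr.route()

        # fanout(): broadcast to all healthy replicas, one result per replica (previously call()).
        values = await service.value.fanout()
        assert len(values) == 3  # one entry per replica

        # Sticky sessions still go through route(sess_id=...).
        session = await service.start_session()
        await service.incr.route(sess_id=session)
        print(await service.value.route(sess_id=session))
    finally:
        await service.shutdown()


if __name__ == "__main__":
    asyncio.run(demo())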