Skip to content

Commit a8a3290

Browse files
authored
[data] feat: TransferQueue - Support AgentLoop performance metrics & minor fix (verl-project#4289)
### What does this PR do? 1. Support performance metrics statistics that requires tensor data 2. Add stand-alone config structure for TransferQueue 3. Modify TransferQueue initialization process to suit for multiple backends 4. Fix `create_transferqueue_client` usage 5. Unify some function names 6. Add TODO ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) --------- Signed-off-by: 0oshowero0 <o0shower0o@outlook.com>
1 parent e5243f5 commit a8a3290

File tree

9 files changed

+140
-120
lines changed

9 files changed

+140
-120
lines changed

recipe/transfer_queue/agent_loop.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import asyncio
15+
1416
import numpy as np
1517
import ray
1618
from transfer_queue import BatchMeta
@@ -65,19 +67,33 @@ def _performance_metrics(self, metrics: list[list[dict[str, str]]], output: Batc
6567
timing["agent_loop/tool_calls/max"] = t_tool_calls.max()
6668
timing["agent_loop/tool_calls/mean"] = t_tool_calls.mean()
6769

68-
# TODO (TQ): pass tq info throughout AgentLoop so we can retrieve tensor for these metrics
70+
# TODO (TQ): initialize tq during init when enable TQ switch is stable
71+
tq_client = self._create_transferqueue_client()
6972
# batch sequence generation is bounded by the slowest sample
70-
# slowest = np.argmax(t_generate_sequences + t_tool_calls)
71-
# attention_mask = output.extra_info.pop("attention_mask_perf")[slowest]
72-
# prompt_length = output.extra_info.pop("prompts_perf").shape[1]
73-
# timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest]
74-
# timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest]
75-
# timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item()
76-
# timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item()
73+
slowest = np.argmax(t_generate_sequences + t_tool_calls)
74+
attention_mask = asyncio.run(tq_client.async_get_data(output[slowest]))["attention_mask"]
75+
prompt_length = output.samples[0].fields["prompts"].shape[0]
76+
timing["agent_loop/slowest/generate_sequences"] = t_generate_sequences[slowest]
77+
timing["agent_loop/slowest/tool_calls"] = t_tool_calls[slowest]
78+
timing["agent_loop/slowest/prompt_length"] = attention_mask[:prompt_length].sum().item()
79+
timing["agent_loop/slowest/response_length"] = attention_mask[prompt_length:].sum().item()
7780

7881
return timing
7982

80-
def create_transferqueue_client(self, controller_info, config):
81-
ray.get(
82-
[worker.create_transferqueue_client.remote(controller_info, config) for worker in self.agent_loop_workers]
83+
def create_transferqueue_client_for_workers(self):
84+
# TODO (TQ): initialize tq during worker init when enable TQ switch is stable
85+
ray.get([worker.create_transferqueue_client.remote() for worker in self.agent_loop_workers])
86+
87+
def _create_transferqueue_client(self):
88+
"""Create a client for data system (TransferQueue)."""
89+
from verl.single_controller.ray.base import get_random_string
90+
from verl.utils.transferqueue_utils import create_transferqueue_client
91+
92+
client_name = get_random_string(length=6)
93+
94+
tq_client = create_transferqueue_client(
95+
client_id=f"AgentLoopManager_{client_name}",
96+
config=self.config.transfer_queue,
8397
)
98+
99+
return tq_client

recipe/transfer_queue/config/transfer_queue_ppo_megatron_trainer.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ defaults:
99
# config for TransferQueue
1010
transfer_queue:
1111
enable: True
12+
num_global_batch: 1
13+
storage_backend: AsyncSimpleStorageManager
14+
num_data_storage_units: 8

recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,6 @@ defaults:
99
# config for TransferQueue
1010
transfer_queue:
1111
enable: True
12+
num_global_batch: 1
13+
storage_backend: AsyncSimpleStorageManager
14+
num_data_storage_units: 8

recipe/transfer_queue/ray_trainer.py

Lines changed: 89 additions & 90 deletions
Large diffs are not rendered by default.

recipe/transfer_queue/run_qwen3-8b_transferqueue.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,5 @@ python3 -m recipe.transfer_queue.main_ppo \
6363
trainer.total_epochs=15 \
6464
trainer.total_training_steps=2 \
6565
trainer.val_before_train=False \
66-
+trainer.num_global_batch=1 \
67-
+trainer.num_data_storage_units=8 \
6866
2>&1 | tee "$log_file"
6967
echo "Finished, log is saved in: $log_file"

tests/special_e2e/run_transferqueue.sh

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,6 @@ common_params=(
122122
trainer.total_training_steps=2
123123
trainer.total_epochs=15
124124
trainer.val_before_train=True
125-
+trainer.num_global_batch=1
126-
+trainer.num_data_storage_units=8
127125
)
128126

129127
if [ "${ACTOR_STRATEGY}" == "fsdp" ]; then

verl/experimental/agent_loop/agent_loop.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -633,16 +633,18 @@ def _postprocess(self, inputs: list[_InternalAgentLoopOutput]) -> DataProto:
633633
meta_info={"metrics": metrics, "reward_extra_keys": reward_extra_keys},
634634
)
635635

636-
def create_transferqueue_client(self, controller_info, role):
637-
"""Create a client for data system(transfer queue)."""
636+
def create_transferqueue_client(
637+
self,
638+
):
639+
"""Create a client for data system (TransferQueue)."""
638640
from verl.single_controller.ray.base import get_random_string
639641
from verl.utils.transferqueue_utils import create_transferqueue_client
640642

641643
client_name = get_random_string(length=6)
642-
create_transferqueue_client(
643-
client_id=f"{role}_worker_{client_name}",
644-
controller_info=controller_info,
645-
config=self.config,
644+
645+
self.tq_client = create_transferqueue_client(
646+
client_id=f"AgentLoopWorker_{client_name}",
647+
config=self.config.transfer_queue,
646648
)
647649

648650

verl/single_controller/base/worker.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,12 @@ def set_dispatch_collect(self, mesh_name: str, dispatch_dp_rank: dict[str, int],
163163
self.__collect_dp_rank[mesh_name] = is_collect
164164

165165
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=True)
166-
def create_transferqueue_client(self, controller_info, config):
166+
def create_transferqueue_client(self, config):
167167
from verl.utils.transferqueue_utils import create_transferqueue_client
168168

169169
create_transferqueue_client(
170170
client_id=f"worker_{self.rank}",
171-
controller_info=controller_info,
172-
config=config,
171+
config=config.transfer_queue,
173172
)
174173

175174
@classmethod

verl/utils/transferqueue_utils.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
from transfer_queue import (
2626
AsyncTransferQueueClient,
2727
BatchMeta,
28-
ZMQServerInfo,
2928
)
3029

3130
except ImportError:
@@ -44,18 +43,21 @@ class BatchMeta:
4443

4544
def create_transferqueue_client(
4645
client_id: str,
47-
controller_info: "ZMQServerInfo",
4846
config,
49-
) -> None:
47+
) -> "AsyncTransferQueueClient":
5048
global _TRANSFER_QUEUE_CLIENT
51-
_TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, controller_info)
52-
_TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type="AsyncSimpleStorageManager", config=config)
49+
if _TRANSFER_QUEUE_CLIENT is None:
50+
_TRANSFER_QUEUE_CLIENT = AsyncTransferQueueClient(client_id, config.controller_info)
51+
_TRANSFER_QUEUE_CLIENT.initialize_storage_manager(manager_type=config.storage_backend, config=config)
52+
53+
return _TRANSFER_QUEUE_CLIENT
5354

5455

5556
def get_transferqueue_client() -> "AsyncTransferQueueClient":
5657
return _TRANSFER_QUEUE_CLIENT
5758

5859

60+
# TODO (TQ): verl will make all actor async, so this can be cleanup later.
5961
def _run_async_in_temp_loop(async_func: Callable[..., Any], *args, **kwargs) -> Any:
6062
# Use a temporary event loop in a new thread because event
6163
# loop may already exist in server mode
@@ -127,7 +129,7 @@ def _update_batchmeta_with_output(output: DataProto, batchmeta: "BatchMeta") ->
127129

128130

129131
def tqbridge(put_data: bool = True):
130-
""" "Creates a decorator for bridging BatchMeta and DataProto.
132+
"""Creates a decorator for bridging BatchMeta and DataProto.
131133
132134
This decorator automatically handles conversions between `BatchMeta` and
133135
`DataProto` in function parameters, and decides whether to sync function

0 commit comments

Comments
 (0)