Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xtuner/v1/data_proto/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from .rl_data import RolloutState, SampleParams
from .sequence_context import SequenceContext


__all__ = [
"SequenceContext",
"RolloutState",
"SampleParams",
]
7 changes: 4 additions & 3 deletions xtuner/v1/data_proto/rl_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,13 +68,14 @@ class RolloutState(BaseModel):
model_config = ConfigDict(extra="forbid")

# --- 数据 ---
message_uid: int # 通过计算原始的message的哈希值得到的id,一组的数据为同一个prompt_id
message: list[dict[str, Any]] # dataset输出,需要在AgentLoop中转换成input_ids
prompt_ids: list[int] # 原始 prompt的token ids
data_source: dict[str, Any] | None = None
mm_info: MultimodalInfo | None = None
reward_model: dict[str, Any] | None = None

message_uid: int | None = None # 通过计算原始的message的哈希值得到的id,一组的数据为同一个prompt_id
num_tokens: int | None = None # 用于 cache 管理

# --- InferEngine 输入 ---
session_uid: int | None = None
tokens: list[int] # 每一次推理引擎的实际输入
Expand All @@ -86,7 +87,7 @@ class RolloutState(BaseModel):
response: str | None = None
response_ids: list[int] | None = None
logprobs: list[float] | None = None
routed_experts: list[int] | RayObjectRef | None = None # type: ignore[valid-type]
routed_experts: list[int] | RayObjectRef | None = None # type: ignore[valid-type]
finish_reason: str | None = None

# --- Judger 输出 ---
Expand Down
6 changes: 5 additions & 1 deletion xtuner/v1/datasets/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,11 @@ def _tokenize_by_offset(
) -> dict:
line = data.decode()
tokenized = tokenize_fn(json.loads(line))
return {"num_tokens": tokenized["num_tokens"]}
if hasattr(tokenized, "num_tokens"):
num_tokens = tokenized.num_tokens
else:
num_tokens = tokenized["num_tokens"]
return {"num_tokens": num_tokens}

def count_tokens(self, offsets, cache_dir=None):
self.tokenize_fn.set_state("cache")
Expand Down
4 changes: 2 additions & 2 deletions xtuner/v1/datasets/rl_tokenize_fn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .rl_tokenize_fn import RLTokenizeFnConfig
from .text_tokenize_fn import RLTextTokenizeFnConfig


__all__ = [
"RLTokenizeFnConfig",
"RLTextTokenizeFnConfig",
]
172 changes: 0 additions & 172 deletions xtuner/v1/datasets/rl_tokenize_fn/rl_tokenize_fn.py

This file was deleted.

78 changes: 78 additions & 0 deletions xtuner/v1/datasets/rl_tokenize_fn/text_tokenize_fn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
from pydantic import BaseModel, ConfigDict

from transformers import PreTrainedTokenizer
from xtuner.v1.data_proto import RolloutState
from xtuner.v1.utils import get_logger

from ..utils import CachableTokenizeFunction


logger = get_logger()


class RLTextTokenizeFn(CachableTokenizeFunction[RolloutState]):
    """Tokenize a text-only RL dataset item into a :class:`RolloutState`.

    Renders the ``prompt`` messages through the tokenizer's chat template,
    tokenizes the result, and packages prompt ids / token count / reward
    metadata for downstream consumers (dataset cache filter, AgentLoop).
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        max_length: int | None = None,
    ):
        super().__init__(tokenizer)
        # Maximum allowed prompt length in tokens; ``None`` disables the limit.
        self.max_length = max_length

    def __call__(self, item: dict, **kwargs) -> RolloutState:
        """Convert one raw dataset item into a :class:`RolloutState`.

        example:
            item = {
                "data_source": data_source,
                "prompt": [
                    {
                        "role": "user",
                        "content": question,
                    }
                ],
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": solution},
                "extra_info": {
                    "split": split,
                    "index": idx,
                    "answer": answer_raw,
                    "question": question_raw,
                },
            }

        Raises:
            AssertionError: outside cache mode, when the tokenized prompt
                exceeds ``max_length``.
        """
        extra_info = item.get("extra_info", {})
        message = item["prompt"]

        # Render the chat template first (with the generation prompt appended),
        # then tokenize WITHOUT special tokens so the template alone controls
        # the final prompt layout.
        raw_prompt = self.tokenizer.apply_chat_template(message, add_generation_prompt=True, tokenize=False)
        data = self.tokenizer(raw_prompt, add_special_tokens=False)
        prompt_token_ids = data["input_ids"]
        num_tokens = len(prompt_token_ids)

        if self.state == "cache":
            # While building the cache, over-long samples are marked with
            # num_tokens == 0 so the dataset filter drops them instead of
            # failing the whole pass.
            if self.max_length is not None and num_tokens > self.max_length:
                num_tokens = 0  # will be filtered out by the dataset filter
        else:
            # Outside cache mode an over-long prompt is a hard error: the
            # filter should already have removed it.
            if self.max_length is not None:
                assert num_tokens <= self.max_length, f"num_tokens {num_tokens} > max_length {self.max_length}"

        rollout_state = RolloutState(
            prompt_ids=prompt_token_ids,
            message=message,
            # Fix: RolloutState.data_source is typed ``dict[str, Any] | None``;
            # the previous fallback string "default" would fail pydantic
            # validation whenever the item has no "data_source" key.
            data_source=item.get("data_source"),
            reward_model=item.get("reward_model", {}),
            num_tokens=num_tokens,
            extra_fields=extra_info,
        )
        return rollout_state

    def hash(self) -> str:
        # Never expected to be called: RL tokenize functions are not cached,
        # so there is no cache key to derive.
        raise ValueError("不应该触发这个方法, 因为 RLTokenizeFn 不需要缓存。")


class RLTextTokenizeFnConfig(BaseModel):
    """Pydantic config that builds an :class:`RLTextTokenizeFn`."""

    model_config = ConfigDict(title="Base RL dataset config for xtuner", extra="forbid")
    # Maximum prompt length in tokens; ``None`` means no limit is enforced.
    max_length: int | None = None

    def build(self, tokenizer: PreTrainedTokenizer, **kwargs) -> RLTextTokenizeFn:
        """Instantiate the tokenize function bound to *tokenizer*."""
        tokenize_fn = RLTextTokenizeFn(
            tokenizer=tokenizer,
            max_length=self.max_length,
        )
        return tokenize_fn
Loading