Skip to content

Commit 5961352

Browse files
authored
add text tokenize fn of rl (#1485)
1 parent b508219 commit 5961352

File tree

6 files changed

+92
-178
lines changed

6 files changed

+92
-178
lines changed

xtuner/v1/data_proto/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
from .rl_data import RolloutState, SampleParams
12
from .sequence_context import SequenceContext
23

34

45
__all__ = [
56
"SequenceContext",
7+
"RolloutState",
8+
"SampleParams",
69
]

xtuner/v1/data_proto/rl_data.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,14 @@ class RolloutState(BaseModel):
6868
model_config = ConfigDict(extra="forbid")
6969

7070
# --- 数据 ---
71-
message_uid: int # 通过计算原始的message的哈希值得到的id,一组的数据为同一个prompt_id
7271
message: list[dict[str, Any]] # dataset输出,需要在AgentLoop中转换成input_ids
7372
prompt_ids: list[int] # 原始 prompt的token ids
7473
data_source: dict[str, Any] | None = None
7574
mm_info: MultimodalInfo | None = None
7675
reward_model: dict[str, Any] | None = None
77-
76+
message_uid: int | None = None # 通过计算原始的message的哈希值得到的id,一组的数据为同一个prompt_id
77+
num_tokens: int | None = None # 用于 cache 管理
78+
7879
# --- InferEngine 输入 ---
7980
session_uid: int | None = None
8081
tokens: list[int] # 每一次推理引擎的实际输入
@@ -86,7 +87,7 @@ class RolloutState(BaseModel):
8687
response: str | None = None
8788
response_ids: list[int] | None = None
8889
logprobs: list[float] | None = None
89-
routed_experts: list[int] | RayObjectRef | None = None # type: ignore[valid-type]
90+
routed_experts: list[int] | RayObjectRef | None = None # type: ignore[valid-type]
9091
finish_reason: str | None = None
9192

9293
# --- Judger 输出 ---

xtuner/v1/datasets/jsonl.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,11 @@ def _tokenize_by_offset(
443443
) -> dict:
444444
line = data.decode()
445445
tokenized = tokenize_fn(json.loads(line))
446-
return {"num_tokens": tokenized["num_tokens"]}
446+
if hasattr(tokenized, "num_tokens"):
447+
num_tokens = tokenized.num_tokens
448+
else:
449+
num_tokens = tokenized["num_tokens"]
450+
return {"num_tokens": num_tokens}
447451

448452
def count_tokens(self, offsets, cache_dir=None):
449453
self.tokenize_fn.set_state("cache")
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from .rl_tokenize_fn import RLTokenizeFnConfig
1+
from .text_tokenize_fn import RLTextTokenizeFnConfig
22

33

44
__all__ = [
5-
"RLTokenizeFnConfig",
5+
"RLTextTokenizeFnConfig",
66
]

xtuner/v1/datasets/rl_tokenize_fn/rl_tokenize_fn.py

Lines changed: 0 additions & 172 deletions
This file was deleted.
Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
# Copyright (c) OpenMMLab. All rights reserved.
2+
from pydantic import BaseModel, ConfigDict
3+
4+
from transformers import PreTrainedTokenizer
5+
from xtuner.v1.data_proto import RolloutState
6+
from xtuner.v1.utils import get_logger
7+
8+
from ..utils import CachableTokenizeFunction
9+
10+
11+
logger = get_logger()
12+
13+
14+
class RLTextTokenizeFn(CachableTokenizeFunction[RolloutState]):
    """Tokenize a text-only RL dataset item into a ``RolloutState``.

    The chat messages under ``item["prompt"]`` are rendered through the
    tokenizer's chat template (with the generation prompt appended) and
    encoded into token ids; the resulting state carries the prompt ids,
    the raw messages, and bookkeeping fields used downstream.
    """

    def __init__(
        self,
        tokenizer: PreTrainedTokenizer,
        max_length: int | None = None,
    ):
        """Create the tokenize function.

        Args:
            tokenizer: Hugging Face tokenizer used for chat templating and encoding.
            max_length: Optional upper bound on the rendered prompt length;
                ``None`` disables all length checking.
        """
        super().__init__(tokenizer)
        self.max_length = max_length

    def __call__(self, item: dict, **kwargs) -> RolloutState:
        """Convert one dataset row into a ``RolloutState``.

        Expected ``item`` schema (example)::

            {
                "data_source": data_source,
                "prompt": [{"role": "user", "content": question}],
                "ability": "math",
                "reward_model": {"style": "rule", "ground_truth": solution},
                "extra_info": {
                    "split": split,
                    "index": idx,
                    "answer": answer_raw,
                    "question": question_raw,
                },
            }
        """
        messages = item["prompt"]

        chat_text = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, tokenize=False
        )
        token_ids = self.tokenizer(chat_text, add_special_tokens=False)["input_ids"]
        num_tokens = len(token_ids)

        over_limit = self.max_length is not None and num_tokens > self.max_length
        if self.state == "cache":
            # In cache mode an over-long sample is not an error: report zero
            # tokens so the dataset-level filter drops it.
            if over_limit:
                num_tokens = 0
        elif self.max_length is not None:
            assert num_tokens <= self.max_length, f"num_tokens {num_tokens} > max_length {self.max_length}"

        return RolloutState(
            prompt_ids=token_ids,
            message=messages,
            # NOTE(review): RolloutState appears to annotate ``data_source`` as
            # ``dict[str, Any] | None`` while the fallback here is the string
            # "default" — confirm which side is intended.
            data_source=item.get("data_source", "default"),
            reward_model=item.get("reward_model", {}),
            num_tokens=num_tokens,
            extra_fields=item.get("extra_info", {}),
        )

    def hash(self) -> str:
        # This tokenize fn is never cached, so hashing must not be reached.
        raise ValueError("不应该触发这个方法, 因为 RLTokenizeFn 不需要缓存。")
71+
72+
73+
class RLTextTokenizeFnConfig(BaseModel):
    """Pydantic config that builds an :class:`RLTextTokenizeFn`."""

    model_config = ConfigDict(title="Base RL dataset config for xtuner", extra="forbid")
    # Optional cap on the rendered prompt token count; ``None`` disables it.
    max_length: int | None = None

    def build(self, tokenizer: PreTrainedTokenizer, **kwargs) -> RLTextTokenizeFn:
        """Instantiate the tokenize function bound to *tokenizer*."""
        fn = RLTextTokenizeFn(tokenizer=tokenizer, max_length=self.max_length)
        return fn

0 commit comments

Comments
 (0)