Skip to content

Commit 6394d26

Browse files
joecummings
authored and mori360 committed
Yield per-document RoPE position ids from dataset (pytorch#2560)
1 parent a230cb5 commit 6394d26

File tree

5 files changed

+90
-4
lines changed

5 files changed

+90
-4
lines changed

tests/unit_tests/test_dataset_checkpointing.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ def test_c4_resumption(self):
5555
assert torch.equal(
5656
input_ids["input"], expected_input_ids["input"]
5757
)
58+
assert torch.equal(
59+
input_ids["positions"],
60+
expected_input_ids["positions"],
61+
)
5862
assert torch.equal(labels, expected_labels)
5963

6064
def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank):

torchtitan/components/validate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,17 @@ def post_dataloading_process(
185185
# extra_kwargs are.
186186
extra_kwargs: dict[str, Any] = {}
187187

188+
# TODO: deduplicate with Trainer.post_dataloading_process which has
189+
# the same logic; extract a shared function to prevent further drift.
190+
# For causal attention the whole packed sequence is one document,
191+
# so sequential RoPE positions (positions=None) are correct.
192+
model_config = getattr(model_parts[0], "config", None)
193+
layer = getattr(model_config, "layer", None)
194+
attn_config = getattr(layer, "attention", None) if layer else None
195+
attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
196+
if attn_mask_type != "block_causal":
197+
extra_inputs.pop("positions", None)
198+
188199
try:
189200
# pyrefly: ignore [not-callable]
190201
extra_kwargs["attention_masks"] = cast(

torchtitan/hf_datasets/text_datasets.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(
9696
# Variables for checkpointing
9797
self._sample_idx = 0
9898
self._token_buffer: list[int] = []
99+
self._position_buffer: list[int] = []
99100

100101
def _get_data_iter(self):
101102
# For map-style datasets, resume by skipping to the correct index
@@ -119,15 +120,27 @@ def __iter__(self):
119120
sample_text, add_bos=True, add_eos=True
120121
)
121122
self._token_buffer.extend(sample_tokens)
123+
# Per-document positions reset at document boundaries,
124+
# matching inference frameworks (e.g. vLLM) that start
125+
# positions at 0 per request. Positions wrap at seq_len
126+
# to stay within the RoPE cache, effectively chunking
127+
# long documents into seq_len-sized segments.
128+
# TODO: make overflow policy configurable (chunk / truncate / drop).
129+
self._position_buffer.extend(
130+
i % self.seq_len for i in range(len(sample_tokens))
131+
)
122132
self._sample_idx += 1
123133

124134
while len(self._token_buffer) >= max_buffer_token_len:
125135
x = torch.LongTensor(self._token_buffer[:max_buffer_token_len])
126-
# update tokens to the remaining tokens
136+
pos = torch.LongTensor(self._position_buffer[:max_buffer_token_len])
137+
# update buffers to the remaining tokens
127138
self._token_buffer = self._token_buffer[max_buffer_token_len:]
139+
self._position_buffer = self._position_buffer[max_buffer_token_len:]
128140
input = x[:-1]
129141
label = x[1:]
130-
yield {"input": input}, label
142+
positions = pos[:-1]
143+
yield {"input": input, "positions": positions}, label
131144

132145
if not self.infinite:
133146
logger.warning(f"Dataset {self.dataset_name} has run out of data")
@@ -145,6 +158,15 @@ def __iter__(self):
145158

146159
def load_state_dict(self, state_dict):
147160
self._token_buffer = state_dict["token_buffer"]
161+
if "position_buffer" not in state_dict:
162+
logger.warning(
163+
"Checkpoint missing 'position_buffer' key in dataset state. "
164+
"Falling back to empty position buffer. This is expected when "
165+
"resuming from a checkpoint saved before position tracking was "
166+
"added, but may cause incorrect RoPE positions with "
167+
"block_causal attention (document packing)."
168+
)
169+
self._position_buffer = state_dict.get("position_buffer", [])
148170

149171
if isinstance(self._data, Dataset):
150172
self._sample_idx = state_dict["sample_idx"]
@@ -153,7 +175,10 @@ def load_state_dict(self, state_dict):
153175
self._data.load_state_dict(state_dict["data"])
154176

155177
def state_dict(self):
156-
_state_dict: dict[str, Any] = {"token_buffer": self._token_buffer}
178+
_state_dict: dict[str, Any] = {
179+
"token_buffer": self._token_buffer,
180+
"position_buffer": self._position_buffer,
181+
}
157182

158183
if isinstance(self._data, Dataset):
159184
_state_dict["sample_idx"] = self._sample_idx

torchtitan/models/common/rope.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from typing import Literal
1010

1111
import torch
12+
from torch.distributed.tensor import DTensor, Replicate, Shard
1213

1314
from torchtitan.protocols.module import Module
1415

@@ -289,6 +290,43 @@ def _rotate_half(x: torch.Tensor) -> torch.Tensor:
289290
return torch.cat((-x2, x1), dim=-1)
290291

291292

293+
def _maybe_wrap_positions(
294+
positions: torch.Tensor | None,
295+
x: torch.Tensor,
296+
) -> torch.Tensor | None:
297+
"""Wrap positions as a DTensor deriving mesh and placements from x (xq/xk).
298+
299+
TODO: In a full DTensor rewrite, positions should be made a DTensor
300+
in/right after dataloading, together with inputs and labels.
301+
302+
When TP uses use_local_output=False (DeepSeek V3, Qwen3, GPT-OSS),
303+
x is a DTensor but positions is a plain tensor. The downstream
304+
torch.gather requires both operands to be the same type.
305+
306+
Positions (bsz, seqlen) has fewer dimensions than x (bsz, seqlen,
307+
n_heads, head_dim), so we only preserve Shard placements for shared
308+
dimensions. Shard dims beyond positions' rank (e.g. Shard(2) for TP
309+
on heads) become Replicate.
310+
"""
311+
if (
312+
positions is not None
313+
and isinstance(x, DTensor)
314+
and not isinstance(positions, DTensor)
315+
):
316+
ndim = positions.ndim
317+
placements = tuple(
318+
p if not isinstance(p, Shard) or p.dim < ndim else Replicate()
319+
for p in x.placements
320+
)
321+
positions = DTensor.from_local(
322+
positions,
323+
x.device_mesh,
324+
placements,
325+
run_check=False,
326+
)
327+
return positions
328+
329+
292330
# TODO: consolidate apply_rotary_emb_complex and apply_rotary_emb_single_complex
293331
def apply_rotary_emb_complex(
294332
xq: torch.Tensor,
@@ -304,6 +342,7 @@ def apply_rotary_emb_complex(
304342
freqs_cis: (max_seqlen, head_dim // 2) complex
305343
positions: optional position indices
306344
"""
345+
positions = _maybe_wrap_positions(positions, xq)
307346
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
308347
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
309348
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, xq_, positions)
@@ -324,6 +363,7 @@ def apply_rotary_emb_single_complex(
324363
freqs_cis: (max_seqlen, head_dim // 2) complex
325364
positions: optional position indices
326365
"""
366+
positions = _maybe_wrap_positions(positions, x)
327367
dtype = x.dtype
328368
x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2))
329369
freqs_cis = _reshape_for_broadcast_complex(freqs_cis, x, positions)
@@ -345,6 +385,7 @@ def apply_rotary_emb_cos_sin(
345385
rope_cache: (max_seqlen, head_dim * 2) with cos and sin concatenated
346386
positions: optional position indices
347387
"""
388+
positions = _maybe_wrap_positions(positions, xq)
348389
head_dim = xq.shape[-1]
349390
rope_cache = _reshape_for_broadcast_cos_sin(rope_cache, xq, positions)
350391
cos = rope_cache[..., :head_dim].to(device=xq.device)

torchtitan/trainer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -599,9 +599,14 @@ def post_dataloading_process(
599599
# extra_kwargs are.
600600
extra_kwargs: dict[str, Any] = {}
601601

602-
# TODO: improve the logic on obtaining attention masks
602+
# For causal attention the whole packed sequence is one document,
603+
# so sequential RoPE positions (positions=None) are correct.
603604
layer = getattr(self.model_config, "layer", None)
604605
attn_config = getattr(layer, "attention", None) if layer else None
606+
attn_mask_type = getattr(attn_config, "attn_mask_type", "causal")
607+
if attn_mask_type != "block_causal":
608+
extra_inputs.pop("positions", None)
609+
605610
attn_backend = getattr(attn_config, "attn_backend", "sdpa")
606611
if attn_backend in ["flex", "varlen"]:
607612
assert (

0 commit comments

Comments
 (0)