### Describe the bug
I am trying to determine how to save and load the StatefulDataLoader state with DCP and `split_dataset_by_node` for DDP. Currently I am running into slow resumes, with this warning:

```
Neither dataset nor iter(dataset) defines state_dict/load_state_dict so we are naively fast-forwarding your dataset by 5000 steps. For more efficient resumes, please implement `state_dict` and `load_state_dict` in your IterableDataset and/or iterator.
```
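For what it's worth, the underlying Hugging Face streaming dataset appears to be stateful on its own; a minimal sketch of checking that (`state_dict` on streaming datasets has been around since `datasets` 2.18; `"some/dataset"` is a placeholder):

```python
from datasets import load_dataset

# The raw HF streaming dataset exposes state_dict()/load_state_dict() itself,
# so the warning presumably comes from the torch IterableDataset wrapper below
# not forwarding them. "some/dataset" is a placeholder path.
ds = load_dataset("some/dataset", split="train", streaming=True)
for i, example in enumerate(ds):
    if i == 10:
        state = ds.state_dict()  # checkpoint the position in the stream
        break
ds.load_state_dict(state)  # later: resume iteration from that point
```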
### Steps to reproduce the bug
Say we have a streaming dataset:
```python
from typing import Optional

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer


class StreamingDataset(IterableDataset):
    def __init__(
        self,
        path: str,
        tokenizer: AutoTokenizer,
        name: Optional[str] = None,
        split: str = "train",
        max_length: int = 2048,
        ddp_rank: int = 0,
        ddp_world_size: int = 1,
    ):
        dataset = load_dataset(path, name, split=split, streaming=True)
        self.train_dataset = split_dataset_by_node(
            dataset=dataset, rank=ddp_rank, world_size=ddp_world_size
        )
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        for sample in iter(self.train_dataset):
            tokenized = self.tokenizer(
                sample["text"],
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_special_tokens_mask=True,
            )
            yield tokenized
```
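For completeness, the dataset is constructed per rank along these lines (a sketch; the dataset path and tokenizer name are placeholders):

```python
import torch.distributed as dist
from transformers import AutoTokenizer

# Placeholders: substitute your own dataset path and tokenizer.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
train_dataset = StreamingDataset(
    path="some/dataset",
    tokenizer=tokenizer,
    ddp_rank=dist.get_rank(),
    ddp_world_size=dist.get_world_size(),
)
```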
We load that dataset into the StatefulDataLoader:

```python
from torchdata.stateful_dataloader import StatefulDataLoader

trainloader = StatefulDataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    collate_fn=data_collator,
)
```
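The piece that matters for checkpointing is `trainloader.state_dict()`, which nests the dataset's state when the dataset exposes one. A quick way to inspect it (sketch):

```python
# Advance a few batches, then look at what would be checkpointed.
for i, batch in enumerate(trainloader):
    if i == 5:
        break
print(trainloader.state_dict())  # nested dict describing the loader's position
```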
We then have code for checkpointing and resuming the state using DCP:

```python
import os
from typing import Optional

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

from blitzbert.utils import print_rank_0


class Checkpoint:
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        trainloader,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.trainloader = trainloader
        self.step = step
        self.epoch = epoch

    def get_state_dict(self) -> dict:
        model_state_dict, optimizer_state_dict = get_state_dict(
            self.model, self.optimizer
        )
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
            "trainloader": self.trainloader.state_dict(),
            "step": self.step,
            "epoch": self.epoch,
        }


def save_checkpoint(
    args,
    model,
    optimizer,
    trainloader,
    step: Optional[int] = None,
    epoch: Optional[int] = None,
    final_checkpoint: bool = False,
):
    checkpointer = Checkpoint(
        model=model,
        optimizer=optimizer,
        trainloader=trainloader,
        step=step,
        epoch=epoch,
    )
    state_dict = checkpointer.get_state_dict()
    if final_checkpoint:
        print_rank_0("Saving final model")
        save_path = os.path.join(args.checkpoint_dir, "final_model")
        dcp.save(state_dict, checkpoint_id=save_path)
        dist.barrier()
        single_file_path = os.path.join(args.checkpoint_dir, "final_checkpoint.pth")
        dcp_to_torch_save(save_path, single_file_path)
    else:
        if step % args.checkpointing_steps == 0 and step != 0:
            print_rank_0(f"Saving model at step: {step}")
            save_path = os.path.join(args.checkpoint_dir, f"epoch_{epoch}_step_{step}")
            dcp.save(state_dict, checkpoint_id=save_path)
            dist.barrier()


def load_checkpoint(args, model, optimizer, trainloader):
    if not args.resume_from_checkpoint:
        return 0, 0

    checkpoint_path = args.resume_from_checkpoint
    print_rank_0(f"Resumed from checkpoint: {checkpoint_path}")
    checkpointer = Checkpoint(
        model=model,
        optimizer=optimizer,
        trainloader=trainloader,
    )
    state_dict = checkpointer.get_state_dict()
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=checkpoint_path,
    )
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state_dict["model"],
        optim_state_dict=state_dict["optim"],
    )
    trainloader.load_state_dict(state_dict["trainloader"])
    step = state_dict["step"]
    epoch = state_dict["epoch"]
    return step, epoch
```
We then load the checkpoint:

```python
completed_steps, current_epoch = load_checkpoint(
    args=args, model=model, optimizer=optimizer, trainloader=trainloader
)
```

### Expected behavior
If I implement what the warning says, adding these methods to `StreamingDataset`:

```python
    def state_dict(self):
        return self.train_dataset.state_dict()

    def load_state_dict(self, state):
        self.train_dataset.load_state_dict(state)
```

I then get:
```
[rank0]: raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank0]: RuntimeError: Missing key in checkpoint state_dict: trainloader.dataset_state.examples_iterable.examples_iterable.previous_state.
```
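Presumably the flattened keys of the freshly constructed loader's state no longer line up with what is in the checkpoint. One direction I have considered, following the DCP tutorial pattern of wrapping non-tensor state in `torch.distributed.checkpoint.stateful.Stateful` so DCP saves and restores it as a unit, is (a sketch, untested):

```python
from torch.distributed.checkpoint.stateful import Stateful


class DataloaderWrapper(Stateful):
    """Hand DCP the whole dataloader state as one Stateful object, so loading
    goes through load_state_dict rather than key-by-key matching of the
    nested dict (sketch)."""

    def __init__(self, dataloader):
        self.dataloader = dataloader

    def state_dict(self):
        return self.dataloader.state_dict()

    def load_state_dict(self, state_dict):
        self.dataloader.load_state_dict(state_dict)


# In Checkpoint.get_state_dict, "trainloader" would then map to
# DataloaderWrapper(self.trainloader) instead of the raw state dict.
```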
How exactly should one save and resume the StatefulDataLoader state with Hugging Face datasets?
### Environment info

```
datasets>=4.4.1
```