Skip to content

Using Stateful Dataloader with Split Dataset By Node and DCP for DDP #7927

@conceptofmind

Description

@conceptofmind

Describe the bug

I am trying to determine how to save and load the Stateful Dataloader State with DCP and Split Dataset by Node for DDP.

Currently, I am running into an issue where resuming from a checkpoint is slow.

Neither dataset nor iter(dataset) defines state_dict/load_state_dict so we are naively fast-forwarding your dataset by 5000 steps. For more efficient resumes, please implement `state_dict` and `load_state_dict` in your IterableDataset and/or iterator.

Steps to reproduce the bug

Say we have a streaming dataset:

class StreamingDataset(IterableDataset):
    """Iterable dataset that streams a Hugging Face dataset, shards it across
    DDP ranks via ``split_dataset_by_node``, and tokenizes each sample lazily.

    Implements ``state_dict``/``load_state_dict`` by delegating to the
    underlying ``datasets.IterableDataset``, so ``StatefulDataLoader`` can
    checkpoint/restore the exact stream position instead of naively
    fast-forwarding the dataset on resume.
    """

    def __init__(
        self,
        path: str,
        tokenizer: AutoTokenizer,
        name: Optional[str] = None,
        split: str = "train",
        max_length: int = 2048,
        ddp_rank: int = 0,
        ddp_world_size: int = 1,
    ):
        dataset = load_dataset(path, name, split=split, streaming=True)
        # Shard the stream so each DDP rank iterates a disjoint slice.
        self.train_dataset = split_dataset_by_node(
            dataset=dataset, rank=ddp_rank, world_size=ddp_world_size
        )

        self.tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        for sample in iter(self.train_dataset):
            tokenized = self.tokenizer(
                sample["text"],
                padding="max_length",
                truncation=True,
                max_length=self.max_length,
                return_special_tokens_mask=True,
            )
            yield tokenized

    def state_dict(self) -> dict:
        """Expose the underlying HF dataset's resumable state.

        This is what StatefulDataLoader looks for to avoid the slow
        "naively fast-forwarding your dataset" resume path.
        """
        return self.train_dataset.state_dict()

    def load_state_dict(self, state_dict: dict) -> None:
        """Restore the underlying HF dataset's stream position."""
        self.train_dataset.load_state_dict(state_dict)

We load that dataset into the Stateful Dataloader:

    # Wrap the sharded streaming dataset in torchdata's StatefulDataLoader so
    # the loader's position can be captured with state_dict() and restored
    # with load_state_dict() on resume.
    # NOTE(review): num_workers defaults to 0 here, so dataset state is
    # tracked in-process — confirm resume behavior if workers are added.
    trainloader = StatefulDataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        collate_fn=data_collator,
    )

We then have code for checkpointing and resuming the state using DCP:

import os
from typing import Optional

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict

from blitzbert.utils import print_rank_0


class Checkpoint:
    def __init__(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        trainloader,
        step: Optional[int] = None,
        epoch: Optional[int] = None,
    ):
        self.model = model
        self.optimizer = optimizer
        self.trainloader = trainloader
        self.step = step
        self.epoch = epoch

    def get_state_dict(self) -> dict:
        model_state_dict, optimizer_state_dict = get_state_dict(
            self.model, self.optimizer
        )
        return {
            "model": model_state_dict,
            "optim": optimizer_state_dict,
            "trainloader": self.trainloader.state_dict(),
            "step": self.step,
            "epoch": self.epoch,
        }


def save_checkpoint(
    args,
    model,
    optimizer,
    trainloader,
    step: Optional[int] = None,
    epoch: Optional[int] = None,
    final_checkpoint: bool = False,
):
    """Save a distributed (DCP) checkpoint of model/optimizer/dataloader state.

    Args:
        args: Namespace providing ``checkpoint_dir`` and ``checkpointing_steps``.
        model, optimizer, trainloader: live training objects to snapshot.
        step, epoch: progress counters stored alongside the state.
        final_checkpoint: when True, always save and additionally consolidate
            the sharded checkpoint into a single ``.pth`` file.

    Periodic saves happen only when ``step`` is a positive multiple of
    ``args.checkpointing_steps``.
    """
    checkpointer = Checkpoint(
        model=model,
        optimizer=optimizer,
        trainloader=trainloader,
        step=step,
        epoch=epoch,
    )

    state_dict = checkpointer.get_state_dict()

    if final_checkpoint:
        print_rank_0("Saving final model")

        save_path = os.path.join(args.checkpoint_dir, "final_model")

        # dcp.save is collective: every rank participates, each writing its shard.
        dcp.save(state_dict, checkpoint_id=save_path)
        dist.barrier()

        # BUG FIX: consolidation must run on a single rank — having every
        # rank write the same .pth path concurrently is a file race that can
        # corrupt the output.
        if dist.get_rank() == 0:
            single_file_path = os.path.join(args.checkpoint_dir, "final_checkpoint.pth")
            dcp_to_torch_save(save_path, single_file_path)
        dist.barrier()
    else:
        # BUG FIX: guard against step=None (the parameter default) before the
        # modulo, which would otherwise raise TypeError.
        if step is not None and step > 0 and step % args.checkpointing_steps == 0:
            print_rank_0(f"Saving model at step: {step}")
            save_path = os.path.join(args.checkpoint_dir, f"epoch_{epoch}_step_{step}")
            dcp.save(state_dict, checkpoint_id=save_path)
            dist.barrier()


def load_checkpoint(args, model, optimizer, trainloader):
    """Restore training state from a DCP checkpoint directory.

    Returns a ``(step, epoch)`` tuple, or ``(0, 0)`` when no resume path is
    configured via ``args.resume_from_checkpoint``.
    """
    checkpoint_path = args.resume_from_checkpoint
    if not checkpoint_path:
        return 0, 0

    print_rank_0(f"Resumed from checkpoint: {checkpoint_path}")

    # Build a template state dict from the live objects so DCP knows which
    # shapes/placements to load into.
    state_dict = Checkpoint(
        model=model,
        optimizer=optimizer,
        trainloader=trainloader,
    ).get_state_dict()

    dcp.load(
        state_dict=state_dict,
        checkpoint_id=checkpoint_path,
    )

    # Push the loaded model/optimizer tensors back into the live objects.
    set_state_dict(
        model,
        optimizer,
        model_state_dict=state_dict["model"],
        optim_state_dict=state_dict["optim"],
    )

    # Restore the dataloader's stream position.
    trainloader.load_state_dict(state_dict["trainloader"])

    return state_dict["step"], state_dict["epoch"]

and then loading the checkpoint:

        # Resume model/optimizer/dataloader state; yields (0, 0) when
        # args.resume_from_checkpoint is unset.
        completed_steps, current_epoch = load_checkpoint(
            args=args, model=model, optimizer=optimizer, trainloader=trainloader
        )

Expected behavior

If I implement what the warning says:

    def state_dict(self):
        # Delegate to the underlying HF streaming dataset so
        # StatefulDataLoader checkpoints the exact stream position instead
        # of naively fast-forwarding on resume.
        return self.train_dataset.state_dict()

    def load_state_dict(self, state):
        # Restore the underlying HF streaming dataset's position from a
        # previously captured state dict.
        self.train_dataset.load_state_dict(state)

I then get:

[rank0]:     raise RuntimeError(f"Missing key in checkpoint state_dict: {fqn}.")
[rank0]: RuntimeError: Missing key in checkpoint state_dict: trainloader.dataset_state.examples_iterable.examples_iterable.previous_state.

How exactly should one be saving and resuming the Stateful Dataloader with Hugging Face datasets?

Environment info

"datasets>=4.4.1",

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions