infscale/execution/pipeline.py (9 changes: 7 additions, 2 deletions)
@@ -331,7 +331,6 @@ async def _run_server(self):
         # For this we need to run configure() in a thread so the event loop stays responsive
         await asyncio.to_thread(
             self.dataset.configure,
-            self._micro_batch_size,
             self.device,
             self.spec.reqgen_config.params.in_memory,
             self.spec.reqgen_config.params.replay,
@@ -464,7 +463,13 @@ def _init_assets(self) -> None:
         )

         # load dataset
-        self.dataset = HuggingFaceDataset(mmd, path, name, split)
+        self.dataset = HuggingFaceDataset(
+            mmd,
+            path,
+            dataset_name=name,
+            split=split,
+            micro_batch_size=self._micro_batch_size,
+        )
         self.device = torch.device(self.spec.device)

         # load model intermediate representation
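Net effect of this file's change: the micro-batch size is fixed when the dataset is constructed instead of being supplied later through `configure()`. A minimal sketch of the before/after call pattern, using only names visible in the hunks above (`mmd`, `path`, `name`, `split`, the batch size value); this is not the full `_init_assets()` / `_run_server()` code:

```python
# Before this PR: batch size arrived late, via configure().
# dataset = HuggingFaceDataset(mmd, path, name, split)
# dataset.configure(micro_batch_size, device, in_memory, replay)

# After this PR: batch size is a constructor argument ...
dataset = HuggingFaceDataset(
    mmd,
    path,
    dataset_name=name,
    split=split,
    micro_batch_size=micro_batch_size,
)

# ... and configure() keeps only the runtime settings.
dataset.configure(device, in_memory, replay)
```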
infscale/module/dataset.py (12 changes: 7 additions, 5 deletions)
@@ -23,7 +23,6 @@
 # This file was modified from
 # https://github.com/SymbioticLab/Oobleck/blob/3b7a0c2f19bff0991e623ffbeb8a5b365853bf3a/oobleck/execution/dataset.py

-import math
 from typing import Optional, Tuple, Type

 import torch
@@ -51,15 +50,19 @@ def __init__(
         dataset_name: Optional[str] = None,
         split: Optional[str] = "test",
         max_seq_length: Optional[int] = None,
+        micro_batch_size: int = 1,
     ):
         """Initialize the class."""
+        self.micro_batch_size = micro_batch_size
+
         if mmd.model_group == ModelGroup.LANG:
             self.tokenizer, self.dataset = HuggingFaceDataset.create_language_dataset(
                 mmd.name,
                 dataset_path,
                 dataset_name,
                 split,
                 max_seq_length,
+                micro_batch_size,
             )

         def collate_fn(examples):
@@ -108,11 +111,8 @@ def collate_fn(examples):
         self.model_group = mmd.model_group
         self._curr_batch: Tensor = None

-    def configure(
-        self, micro_batch_size: int, device: torch.device, in_memory: bool, replay: int
-    ) -> None:
+    def configure(self, device: torch.device, in_memory: bool, replay: int) -> None:
         """Configure dataset."""
-        self.micro_batch_size = micro_batch_size
         self.device = device
         self._in_memory = in_memory
         self._replay = int(replay)
@@ -243,6 +243,7 @@ def create_language_dataset(
         dataset_name: Optional[str],
         split: Optional[str] = None,
         max_seq_length: Optional[int] = None,
+        micro_batch_size: int = 1,
     ) -> Tuple[Type[PreTrainedTokenizerBase], Dataset]:
         """Create language dataset."""
         tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -272,6 +273,7 @@ def tokenize_function(examples):
             batched=True,
             remove_columns=column_names,
             load_from_cache_file=True,
+            batch_size=micro_batch_size,
         )

         return tokenizer, tokenized_dataset
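The `batch_size=` keyword added to this `Dataset.map()` call is the standard Hugging Face `datasets` argument: with `batched=True`, it caps how many examples each call to `tokenize_function` receives (the library default is 1000), so tokenization now happens one micro-batch at a time. A self-contained sketch of that behavior, using toy data rather than anything from this PR:

```python
from datasets import Dataset

ds = Dataset.from_dict({"text": ["aa", "b", "cccc", "dd", "e"]})

def tokenize_function(examples):
    # With batched=True and batch_size=2, examples["text"] holds at most
    # two items per call: ["aa", "b"], then ["cccc", "dd"], then ["e"].
    return {"length": [len(t) for t in examples["text"]]}

tokenized = ds.map(tokenize_function, batched=True, batch_size=2)
print(tokenized["length"])  # [2, 1, 4, 2, 1]
```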