Commit dcc104d

fix: setting batch size correctly for language dataset (#314)
The map function used to generate the language dataset takes a batch_size argument whose default value is 1000. This causes incorrect padding during tokenization, since pad tokens are applied per mapped batch rather than per micro batch. This commit sets the batch size to the micro batch size explicitly.
1 parent 627409d commit dcc104d
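Why this matters: Dataset.map(..., batched=True) feeds the tokenize function chunks of batch_size examples (1000 by default), so if the tokenizer pads each chunk to its longest sequence, the pad width is determined by the map chunk rather than by the micro batch actually served. A minimal, self-contained sketch of the effect; the model, the toy texts, and padding="longest" are illustrative assumptions, not the repository's exact code:

from datasets import Dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

ds = Dataset.from_dict({"text": ["a", "a b c", "a b c d e f", "a b"]})

def tokenize_function(examples):
    # pads every example to the longest sequence in the mapped chunk
    return tokenizer(examples["text"], padding="longest")

# default batch_size (1000): all four rows land in one chunk, so every
# row is padded to the overall longest sequence
wide = ds.map(tokenize_function, batched=True, remove_columns=["text"])

# batch_size=2: rows are padded only against their own chunk
narrow = ds.map(
    tokenize_function, batched=True, remove_columns=["text"], batch_size=2
)

print([len(x) for x in wide["input_ids"]])    # [6, 6, 6, 6]
print([len(x) for x in narrow["input_ids"]])  # [3, 3, 6, 6]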

File tree

infscale/execution/pipeline.py
infscale/module/dataset.py

2 files changed: +14 -7 lines

infscale/execution/pipeline.py

7 additions & 2 deletions

@@ -331,7 +331,6 @@ async def _run_server(self):
         # For this we need to run configure() in a thread so the event loop stays responsive
         await asyncio.to_thread(
             self.dataset.configure,
-            self._micro_batch_size,
             self.device,
             self.spec.reqgen_config.params.in_memory,
             self.spec.reqgen_config.params.replay,
@@ -464,7 +463,13 @@ def _init_assets(self) -> None:
         )

         # load dataset
-        self.dataset = HuggingFaceDataset(mmd, path, name, split)
+        self.dataset = HuggingFaceDataset(
+            mmd,
+            path,
+            dataset_name=name,
+            split=split,
+            micro_batch_size=self._micro_batch_size,
+        )
         self.device = torch.device(self.spec.device)

         # load model intermediate representation
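Note the pattern preserved by the first hunk: configure() blocks, so the pipeline offloads it with asyncio.to_thread to keep the event loop responsive. A minimal sketch of that pattern with a stand-in blocking function (names and the sleep are illustrative):

import asyncio
import time

def configure(device: str, in_memory: bool, replay: int) -> None:
    # stand-in for the blocking dataset.configure() call
    time.sleep(1)

async def main():
    # positional arguments are forwarded to the function, mirroring
    # the call shape in the hunk above
    await asyncio.to_thread(configure, "cpu", True, 0)

asyncio.run(main())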

infscale/module/dataset.py

7 additions & 5 deletions

@@ -23,7 +23,6 @@
 # This file was modified from
 # https://github.com/SymbioticLab/Oobleck/blob/3b7a0c2f19bff0991e623ffbeb8a5b365853bf3a/oobleck/execution/dataset.py

-import math
 from typing import Optional, Tuple, Type

 import torch
@@ -51,15 +50,19 @@ def __init__(
         dataset_name: Optional[str] = None,
         split: Optional[str] = "test",
         max_seq_length: Optional[int] = None,
+        micro_batch_size: int = 1,
     ):
         """Initialize the class."""
+        self.micro_batch_size = micro_batch_size
+
         if mmd.model_group == ModelGroup.LANG:
             self.tokenizer, self.dataset = HuggingFaceDataset.create_language_dataset(
                 mmd.name,
                 dataset_path,
                 dataset_name,
                 split,
                 max_seq_length,
+                micro_batch_size,
             )

         def collate_fn(examples):
@@ -108,11 +111,8 @@ def collate_fn(examples):
         self.model_group = mmd.model_group
         self._curr_batch: Tensor = None

-    def configure(
-        self, micro_batch_size: int, device: torch.device, in_memory: bool, replay: int
-    ) -> None:
+    def configure(self, device: torch.device, in_memory: bool, replay: int) -> None:
         """Configure dataset."""
-        self.micro_batch_size = micro_batch_size
         self.device = device
         self._in_memory = in_memory
         self._replay = int(replay)
@@ -243,6 +243,7 @@ def create_language_dataset(
         dataset_name: Optional[str],
         split: Optional[str] = None,
         max_seq_length: Optional[int] = None,
+        micro_batch_size: int = 1,
     ) -> Tuple[Type[PreTrainedTokenizerBase], Dataset]:
         """Create language dataset."""
         tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -272,6 +273,7 @@ def tokenize_function(examples):
             batched=True,
             remove_columns=column_names,
             load_from_cache_file=True,
+            batch_size=micro_batch_size,
         )

         return tokenizer, tokenized_dataset
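A consequence of aligning map's batch_size with the micro batch size is that rows within one served micro batch share a padded width, so a collate function can stack them into a single tensor without re-padding. A self-contained sketch under the same illustrative assumptions as above (gpt2, toy texts, padding="longest"); the repository's actual collate_fn may differ:

import torch
from datasets import Dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

micro_batch_size = 2
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

ds = Dataset.from_dict({"text": ["a", "a b c", "a b c d e f", "a b"]})
ds = ds.map(
    lambda ex: tokenizer(ex["text"], padding="longest"),
    batched=True,
    remove_columns=["text"],
    batch_size=micro_batch_size,  # the fix: align padding with the micro batch
)

# rows inside one micro batch now share a padded width and stack cleanly
loader = DataLoader(
    ds,
    batch_size=micro_batch_size,
    collate_fn=lambda rows: torch.tensor([r["input_ids"] for r in rows]),
)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 3]) then torch.Size([2, 6])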
