|
23 | 23 | # This file was modified from |
24 | 24 | # https://github.com/SymbioticLab/Oobleck/blob/3b7a0c2f19bff0991e623ffbeb8a5b365853bf3a/oobleck/execution/dataset.py |
25 | 25 |
|
26 | | -import math |
27 | 26 | from typing import Optional, Tuple, Type |
28 | 27 |
|
29 | 28 | import torch |
@@ -51,15 +50,19 @@ def __init__( |
51 | 50 | dataset_name: Optional[str] = None, |
52 | 51 | split: Optional[str] = "test", |
53 | 52 | max_seq_length: Optional[int] = None, |
| 53 | + micro_batch_size: int = 1, |
54 | 54 | ): |
55 | 55 | """Initialize the class.""" |
| 56 | + self.micro_batch_size = micro_batch_size |
| 57 | + |
56 | 58 | if mmd.model_group == ModelGroup.LANG: |
57 | 59 | self.tokenizer, self.dataset = HuggingFaceDataset.create_language_dataset( |
58 | 60 | mmd.name, |
59 | 61 | dataset_path, |
60 | 62 | dataset_name, |
61 | 63 | split, |
62 | 64 | max_seq_length, |
| 65 | + micro_batch_size, |
63 | 66 | ) |
64 | 67 |
|
65 | 68 | def collate_fn(examples): |
@@ -108,11 +111,8 @@ def collate_fn(examples): |
108 | 111 | self.model_group = mmd.model_group |
109 | 112 | self._curr_batch: Tensor = None |
110 | 113 |
|
111 | | - def configure( |
112 | | - self, micro_batch_size: int, device: torch.device, in_memory: bool, replay: int |
113 | | - ) -> None: |
| 114 | + def configure(self, device: torch.device, in_memory: bool, replay: int) -> None: |
114 | 115 | """Configure dataset.""" |
115 | | - self.micro_batch_size = micro_batch_size |
116 | 116 | self.device = device |
117 | 117 | self._in_memory = in_memory |
118 | 118 | self._replay = int(replay) |
@@ -243,6 +243,7 @@ def create_language_dataset( |
243 | 243 | dataset_name: Optional[str], |
244 | 244 | split: Optional[str] = None, |
245 | 245 | max_seq_length: Optional[int] = None, |
| 246 | + micro_batch_size: int = 1, |
246 | 247 | ) -> Tuple[Type[PreTrainedTokenizerBase], Dataset]: |
247 | 248 | """Create language dataset.""" |
248 | 249 | tokenizer = AutoTokenizer.from_pretrained(model_name) |
@@ -272,6 +273,7 @@ def tokenize_function(examples): |
272 | 273 | batched=True, |
273 | 274 | remove_columns=column_names, |
274 | 275 | load_from_cache_file=True, |
| 276 | + batch_size=micro_batch_size, |
275 | 277 | ) |
276 | 278 |
|
277 | 279 | return tokenizer, tokenized_dataset |
0 commit comments