diff --git a/infscale/execution/pipeline.py b/infscale/execution/pipeline.py
index 66663eb..e9376b1 100644
--- a/infscale/execution/pipeline.py
+++ b/infscale/execution/pipeline.py
@@ -331,7 +331,6 @@ async def _run_server(self):
         # For this we need to run configure() in a thread so the event loop stays responsive
         await asyncio.to_thread(
             self.dataset.configure,
-            self._micro_batch_size,
             self.device,
             self.spec.reqgen_config.params.in_memory,
             self.spec.reqgen_config.params.replay,
@@ -464,7 +463,13 @@ def _init_assets(self) -> None:
         )
 
         # load dataset
-        self.dataset = HuggingFaceDataset(mmd, path, name, split)
+        self.dataset = HuggingFaceDataset(
+            mmd,
+            path,
+            dataset_name=name,
+            split=split,
+            micro_batch_size=self._micro_batch_size,
+        )
         self.device = torch.device(self.spec.device)
 
         # load model intermediate representation
diff --git a/infscale/module/dataset.py b/infscale/module/dataset.py
index f31dd83..08816e2 100644
--- a/infscale/module/dataset.py
+++ b/infscale/module/dataset.py
@@ -23,7 +23,6 @@
 # This file was modified from
 # https://github.com/SymbioticLab/Oobleck/blob/3b7a0c2f19bff0991e623ffbeb8a5b365853bf3a/oobleck/execution/dataset.py
 
-import math
 from typing import Optional, Tuple, Type
 
 import torch
@@ -51,8 +50,11 @@ def __init__(
         dataset_name: Optional[str] = None,
         split: Optional[str] = "test",
         max_seq_length: Optional[int] = None,
+        micro_batch_size: int = 1,
     ):
         """Initialize the class."""
+        self.micro_batch_size = micro_batch_size
+
         if mmd.model_group == ModelGroup.LANG:
             self.tokenizer, self.dataset = HuggingFaceDataset.create_language_dataset(
                 mmd.name,
@@ -60,6 +62,7 @@ def __init__(
                 dataset_name,
                 split,
                 max_seq_length,
+                micro_batch_size,
             )
 
             def collate_fn(examples):
@@ -108,11 +111,8 @@ def collate_fn(examples):
         self.model_group = mmd.model_group
         self._curr_batch: Tensor = None
 
-    def configure(
-        self, micro_batch_size: int, device: torch.device, in_memory: bool, replay: int
-    ) -> None:
+    def configure(self, device: torch.device, in_memory: bool, replay: int) -> None:
         """Configure dataset."""
-        self.micro_batch_size = micro_batch_size
         self.device = device
         self._in_memory = in_memory
         self._replay = int(replay)
@@ -243,6 +243,7 @@ def create_language_dataset(
         dataset_name: Optional[str],
         split: Optional[str] = None,
         max_seq_length: Optional[int] = None,
+        micro_batch_size: int = 1,
     ) -> Tuple[Type[PreTrainedTokenizerBase], Dataset]:
         """Create language dataset."""
         tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -272,6 +273,7 @@ def tokenize_function(examples):
             batched=True,
             remove_columns=column_names,
             load_from_cache_file=True,
+            batch_size=micro_batch_size,
         )
 
         return tokenizer, tokenized_dataset
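
Reviewer note: below is a minimal caller-side sketch (not part of the patch) of the API after this change. Only the signatures of HuggingFaceDataset.__init__() and configure() come from the diff above; the argument values, and the mmd metadata object, are hypothetical placeholders. The point of the refactor is that micro_batch_size is now bound once at construction in _init_assets(), so configure() is left with runtime-only parameters and the asyncio.to_thread() call in _run_server() no longer has to thread the batch size through.

    import torch

    from infscale.module.dataset import HuggingFaceDataset

    # `mmd` stands in for the model metadata object _init_assets() already
    # holds; its construction is outside the scope of this patch.
    dataset = HuggingFaceDataset(
        mmd,
        "wikitext",                        # dataset path (placeholder value)
        dataset_name="wikitext-2-raw-v1",  # dataset config (placeholder value)
        split="test",
        micro_batch_size=4,                # fixed at construction after this patch
    )

    # configure() no longer takes micro_batch_size; it only sets runtime state.
    dataset.configure(
        device=torch.device("cuda:0"),
        in_memory=True,
        replay=0,
    )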