
Commit 0e66d08

perf: pad to max length for TPU

1 parent aa57ad6

File tree

4 files changed: +21 −8 lines


bsmetadata/experiments/with_metadata.py

Lines changed: 7 additions & 3 deletions
@@ -1,9 +1,10 @@
 import functools
 import logging
 
+from accelerate import DistributedType
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import DataCollatorWithPadding, default_data_collator
 
 from bsmetadata.metadata_utils import add_metadata_and_chunk_examples
 
@@ -124,15 +125,18 @@ def create_labels_column(examples):
     val_dataset = lm_datasets["validation"]
 
     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
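
The rationale, which the commit message only hints at with "pad to max length for TPU": XLA on TPU compiles a separate graph for every distinct tensor shape it sees, so batches whose padded length varies from step to step trigger repeated, expensive recompilation. Padding every batch to a fixed max_length keeps the shapes static. A minimal sketch of the selection logic, reusing the tokenizer, args.max_seq_len, and args.distributed_type names from the diff (the helper function itself is illustrative, not part of the commit):

from accelerate import DistributedType
from transformers import DataCollatorWithPadding, default_data_collator

def pick_collator(tokenizer, args):
    """Illustrative helper mirroring the inline logic added by this commit."""
    if args.distributed_type == DistributedType.TPU:
        # Same shape on every step: no XLA recompilation between batches.
        return DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
    # On GPU/CPU, varying shapes are cheap; the default collator just
    # stacks the already-tokenized examples.
    return default_data_collator

The trade-off is wasted compute on batches shorter than max_seq_len, which is presumably why the fixed-length collator is gated to TPU runs only.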

bsmetadata/experiments/without_metadata.py

Lines changed: 7 additions & 3 deletions
@@ -1,8 +1,9 @@
 import logging
 
+from accelerate import DistributedType
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import DataCollatorWithPadding, default_data_collator
 
 
 logger = logging.getLogger(__name__)
@@ -157,15 +158,18 @@ def group_texts(examples):
     val_dataset = lm_datasets["validation"]
 
     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
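
The same collator swap as in with_metadata.py. To see the effect of max_length padding concretely, a toy run (the model name and the max_length of 8 are arbitrary choices for the demo, not values from this repository):

from transformers import AutoTokenizer, DataCollatorWithPadding

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token  # GPT-2 defines no pad token by default

collator = DataCollatorWithPadding(tok, padding="max_length", max_length=8)
batch = collator([tok("short"), tok("a slightly longer sentence")])
print(batch["input_ids"].shape)  # torch.Size([2, 8]): every batch, same shape

With padding=True instead, the batch would only be padded to its own longest member, so the second dimension would change from batch to batch, which is exactly what TPUs penalize.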

bsmetadata/input_pipeline.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
 from typing import List, Optional
 
+from accelerate import DistributedType
+
 
 @dataclass
 class DataConfig:
@@ -62,6 +64,7 @@ class DataConfig:
     block_size: Optional[int] = field(
         default=None, metadata={"help": "Optional input sequence length after tokenization."}
     )
+    distributed_type: DistributedType = field(default=DistributedType.NO)
 
 
 def get_dataloaders(tokenizer, cfg: DataConfig):
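
Since DataConfig is a plain dataclass, the new field simply carries the runtime value from the trainer into get_dataloaders, and the DistributedType.NO default leaves single-device runs untouched. A trimmed sketch (all fields except the new one are elided):

from dataclasses import dataclass, field
from accelerate import DistributedType

@dataclass
class DataConfig:
    # ...existing fields elided...
    distributed_type: DistributedType = field(default=DistributedType.NO)

cfg = DataConfig()
assert cfg.distributed_type == DistributedType.NO   # default: no special handling
cfg.distributed_type = DistributedType.TPU          # overwritten at runtime in train.py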

bsmetadata/train.py

Lines changed: 4 additions & 2 deletions
@@ -104,12 +104,14 @@ def loss_fn(batch, outputs, metadata_mask=None):
     return loss
 
 
-@hydra.main(config_name="config")
+@hydra.main(config_path=None, config_name="config")
 def main(args: CFG) -> None:
+    accelerator = Accelerator()
+    args.data_config.distributed_type = accelerator.distributed_type
+
     print(OmegaConf.to_yaml(args))
 
     set_seed(args.seed)
-    accelerator = Accelerator()
     is_local_main_process = accelerator.is_local_main_process
     tqdm = partial(original_tqdm, disable=not is_local_main_process)
 
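
Two changes land here. First, @hydra.main gains config_path=None, presumably because the CFG schema is registered programmatically (e.g. via Hydra's ConfigStore) rather than loaded from a config directory. Second, the Accelerator is now constructed first thing in main(): it is what detects the backend, and its distributed_type must land in args.data_config before the dataloaders are built. A sketch of the required ordering (names beyond those in the diff are illustrative):

from accelerate import Accelerator

def main(args):
    accelerator = Accelerator()  # inspects the environment: TPU, multi-GPU, or single process
    # Must run before get_dataloaders(...) reads the config, otherwise the
    # TPU padding branch in the experiment modules would never fire:
    args.data_config.distributed_type = accelerator.distributed_type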
