
Commit 877d94c

pref: truncate and pad to max length for TPU
1 parent 2c55192 commit 877d94c

File tree: 5 files changed (+25 -11 lines)


bsmetadata/experiments/with_metadata.py

Lines changed: 7 additions & 3 deletions
@@ -1,9 +1,10 @@
 import functools
 import logging
 
+from accelerate import DistributedType
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import DataCollatorWithPadding, default_data_collator
 
 from bsmetadata.metadata_utils import add_metadata_and_chunk_examples
 
@@ -124,15 +125,18 @@ def create_labels_column(examples):
     val_dataset = lm_datasets["validation"]
 
     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
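
Presumably the reason for the fixed-length padding: XLA compiles a separate graph for every distinct tensor shape, so on TPU dynamically sized batches trigger constant recompilation, while padding and truncating everything to one length keeps shapes static. A small sketch of what the new collator produces; the GPT-2 tokenizer and the 512 limit are placeholders, not taken from this commit:

from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=512)
batch = collator([tokenizer("a short example"), tokenizer("a somewhat longer example sentence")])
print(batch["input_ids"].shape)  # torch.Size([2, 512]) -- the same static shape for every batch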

bsmetadata/experiments/without_metadata.py

Lines changed: 9 additions & 4 deletions
@@ -1,8 +1,9 @@
 import logging
 
+from accelerate import DistributedType
 from datasets import load_dataset
 from torch.utils.data import DataLoader
-from transformers import default_data_collator
+from transformers import DataCollatorWithPadding, default_data_collator
 
 
 logger = logging.getLogger(__name__)
@@ -94,7 +95,8 @@ def get_dataloaders(tokenizer, args):
     text_column_name = "text" if "text" in column_names else column_names[0]
 
     def tokenize_function(examples):
-        return tokenizer(examples[text_column_name])
+        # max_length=None => use the model max length (it's actually the default)
+        return tokenizer(examples[text_column_name], truncation=True, max_length=None)
 
     tokenized_datasets = raw_datasets.map(
         tokenize_function,
@@ -157,15 +159,18 @@ def group_texts(examples):
     val_dataset = lm_datasets["validation"]
 
     # DataLoaders creation:
+    data_collator = default_data_collator
+    if args.distributed_type == DistributedType.TPU:
+        data_collator = DataCollatorWithPadding(tokenizer, padding="max_length", max_length=args.max_seq_len)
     train_dataloader = DataLoader(
         train_dataset,
         shuffle=True,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_train_batch_size,
     )
     val_dataloader1 = DataLoader(
         val_dataset,
-        collate_fn=default_data_collator,
+        collate_fn=data_collator,
         batch_size=args.per_device_eval_batch_size,
     )
     return train_dataloader, {"val1": val_dataloader1}
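
The truncation change is independent of the collator switch: truncation=True with max_length=None caps each document at the tokenizer's model_max_length, so no single example exceeds what the model accepts. A quick illustration with a stand-in GPT-2 tokenizer (not part of this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # model_max_length is 1024 for GPT-2

long_text = "word " * 5000  # far longer than the model can take
encoding = tokenizer(long_text, truncation=True, max_length=None)

# With truncation enabled the encoding is capped at the model limit;
# without it, the tokenizer returns the full sequence and only warns.
assert len(encoding["input_ids"]) <= tokenizer.model_max_length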

bsmetadata/input_pipeline.py

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,8 @@
 from dataclasses import dataclass, field
 from typing import List, Optional
 
+from accelerate import DistributedType
+
 
 @dataclass
 class DataConfig:
@@ -62,6 +64,7 @@ class DataConfig:
     block_size: Optional[int] = field(
         default=None, metadata={"help": "Optional input sequence length after tokenization."}
     )
+    distributed_type: DistributedType = field(default=DistributedType.NO)
 
 
 def get_dataloaders(tokenizer, cfg: DataConfig):
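
accelerate's DistributedType is a string-backed enum, which is what lets it sit on the config dataclass with a plain default and later be compared against the value the Accelerator detects at runtime. A minimal sketch of the intended round trip; everything except the new field is omitted or assumed:

from dataclasses import dataclass, field
from accelerate import Accelerator, DistributedType

@dataclass
class DataConfig:
    # Only the field added by this commit is shown here.
    distributed_type: DistributedType = field(default=DistributedType.NO)

cfg = DataConfig()
cfg.distributed_type = Accelerator().distributed_type  # e.g. DistributedType.TPU on a TPU host

if cfg.distributed_type == DistributedType.TPU:
    print("use the fixed-length padding collator")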

bsmetadata/metadata_utils.py

Lines changed: 2 additions & 2 deletions
@@ -51,7 +51,7 @@ def add_metadata_and_chunk_examples(
     if add_metadata:
         # Get the global metadata prefix that is prepended to each training example.
         global_metadata_prefix = create_global_metadata_prefix(example, cfg)
-        global_metadata_prefix_encoded = tokenizer.encode_plus(global_metadata_prefix).input_ids
+        global_metadata_prefix_encoded = tokenizer.encode_plus(global_metadata_prefix, truncation=True).input_ids
     else:
         global_metadata_prefix_encoded = []
 
@@ -66,7 +66,7 @@ def add_metadata_and_chunk_examples(
     if global_metadata_prefix_encoded:
         text_with_local_metadata = " " + text_with_local_metadata
         char_level_metadata_mask = [False] + char_level_metadata_mask
-    text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata)
+    text_with_local_metadata_encoded = tokenizer.encode_plus(text_with_local_metadata, truncation=True)
 
     def is_metadata(idx: int) -> bool:
         char_span = text_with_local_metadata_encoded.token_to_chars(idx)
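
truncation=True matters here because the encoding feeds token_to_chars just below: over-long metadata no longer produces sequences past the model limit, and the character alignment of the surviving tokens is unchanged. A short illustration with a stand-in fast tokenizer (not from the repo):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # a fast tokenizer, so token_to_chars is available

text = "url: example.com | " + "body text " * 2000
encoding = tokenizer.encode_plus(text, truncation=True)

assert len(encoding.input_ids) <= tokenizer.model_max_length
span = encoding.token_to_chars(0)   # CharSpan(start, end) of the first kept token
print(text[span.start:span.end])    # still maps back into the original string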

bsmetadata/train.py

Lines changed: 4 additions & 2 deletions
@@ -104,12 +104,14 @@ def loss_fn(batch, outputs, metadata_mask=None):
     return loss
 
 
-@hydra.main(config_name="config")
+@hydra.main(config_path=None, config_name="config")
 def main(args: CFG) -> None:
+    accelerator = Accelerator()
+    args.data_config.distributed_type = accelerator.distributed_type
+
     print(OmegaConf.to_yaml(args))
 
     set_seed(args.seed)
-    accelerator = Accelerator()
     is_local_main_process = accelerator.is_local_main_process
     tqdm = partial(original_tqdm, disable=not is_local_main_process, position=0)
 
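
Two things change here: the Accelerator is now created before the config is printed or used, so the detected backend can be written into data_config ahead of dataloader construction, and @hydra.main gets an explicit config_path=None, meaning "config" is resolved from Hydra's ConfigStore rather than from a config directory. A hedged sketch of that wiring; the CFG dataclass and the ConfigStore registration below are simplified assumptions about how the repo sets up its structured config, not part of this diff:

from dataclasses import dataclass, field

import hydra
from accelerate import Accelerator
from hydra.core.config_store import ConfigStore
from omegaconf import OmegaConf

from bsmetadata.input_pipeline import DataConfig

@dataclass
class CFG:
    data_config: DataConfig = field(default_factory=DataConfig)
    seed: int = 42

# Assumption: registering the structured config under the name "config" is what lets
# @hydra.main(config_path=None, config_name="config") find it without a config directory.
ConfigStore.instance().store(name="config", node=CFG)

@hydra.main(config_path=None, config_name="config")
def main(args: CFG) -> None:
    # Create the Accelerator first so the detected backend (NO / MULTI_GPU / TPU / ...)
    # lands in the config before the dataloaders are built.
    accelerator = Accelerator()
    args.data_config.distributed_type = accelerator.distributed_type
    print(OmegaConf.to_yaml(args))

if __name__ == "__main__":
    main()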
