
Commit 2135298

Fix code quality tests (#174)
* Revert "Filter examples by `num_chars` to include in a batch (#137)". This reverts commit 8021ce7.
* fix test (code quality)
* fix code_quality
* fix
* fix style
1 parent c549da0 commit 2135298

6 files changed: +12 additions, -10 deletions

bsmetadata/experiments/sample.py

Lines changed: 1 addition & 1 deletion
@@ -2,9 +2,9 @@
 
 from datasets import load_dataset
 from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerBase
 
 from bsmetadata.input_pipeline import DataConfig
-from transformers import PreTrainedTokenizerBase
 
 
 @dataclass
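The import moves in this commit all follow the same pattern: the transformers import leaves the first-party bsmetadata block and joins the other third-party imports (datasets, torch, tqdm). This matches the usual isort grouping of standard-library, third-party, and local imports; that the repository enforces it with isort specifically is an assumption here. A minimal sketch of the ordering the diff restores:

# Hypothetical module showing the import grouping the commit enforces.
import logging  # standard library

from datasets import load_dataset  # third party
from transformers import PreTrainedTokenizerBase  # third party, grouped with the rest

from bsmetadata.input_pipeline import DataConfig  # first party (local package)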

bsmetadata/experiments/with_metadata.py

Lines changed: 1 addition & 1 deletion
@@ -6,9 +6,9 @@
 from datasets import config, load_dataset
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
+from transformers import default_data_collator
 
 from bsmetadata.metadata_utils import add_metadata_and_chunk_examples, get_metadata_types, random_sample_metadata
-from transformers import default_data_collator
 
 
 logger = logging.getLogger(__name__)

bsmetadata/experiments/with_metadata_datasetv2.py

Lines changed: 2 additions & 2 deletions
@@ -9,12 +9,12 @@
 from datasets import DatasetDict
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
+from transformers import default_data_collator
 
 from bsmetadata.experiments.datasetv2 import get_files, load_dataset_by_files
 from bsmetadata.experiments.without_metadata import preprocess_no_metadata
 from bsmetadata.metadata_processors import PROCESSORS
 from bsmetadata.metadata_utils import add_metadata_and_chunk_examples, random_sample_metadata_v2
-from transformers import default_data_collator
 
 
 logger = logging.getLogger(__name__)
@@ -123,7 +123,7 @@ def remove_num_proc_kwarg(kwargs):
 
     if args.metadata_config.random_sample_metadata_weights is not None:
         metadata_type_sample_weights = args.metadata_config.random_sample_metadata_weights
-        logger.info(f"using metadata_type_sample_weights proviced in args")
+        logger.info(f"using {metadata_type_sample_weights} proviced in args")
     else:
 
         def get_metadata_types(example):
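The second hunk above is a logging fix: the original f-string had no placeholder, so the literal text "metadata_type_sample_weights" was logged instead of the configured weights. A minimal sketch of the difference, keeping the message text exactly as it appears in the diff (including its "proviced" spelling) and using a made-up weights value:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hypothetical stand-in for args.metadata_config.random_sample_metadata_weights.
metadata_type_sample_weights = {"url": 0.7, "title": 0.3}

# Before: no placeholder, so the variable name itself is logged.
logger.info(f"using metadata_type_sample_weights proviced in args")
# After: the braces interpolate the actual weights into the log message.
logger.info(f"using {metadata_type_sample_weights} proviced in args")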

bsmetadata/experiments/without_metadata.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 
 from datasets import config, load_dataset
 from torch.utils.data import DataLoader
-
 from transformers import default_data_collator
 
 
bsmetadata/metadata_utils.py

Lines changed: 1 addition & 1 deletion
@@ -20,9 +20,9 @@
 from typing import Any, DefaultDict, Dict, List, Optional, Tuple
 
 import numpy as np
+from transformers import PreTrainedTokenizerFast
 
 from bsmetadata.metadata_processors import PROCESSORS, MetadataConfig, MetadataProcessor
-from transformers import PreTrainedTokenizerFast
 
 
 logger = logging.getLogger(__name__)

bsmetadata/train.py

Lines changed: 7 additions & 4 deletions
@@ -13,6 +13,7 @@
 import hydra
 import torch
 import torch.nn.functional as F
+import wandb
 from accelerate import Accelerator
 from accelerate.utils import DistributedType, DummyOptim, DummyScheduler
 from hydra.core.config_store import ConfigStore
@@ -22,7 +23,6 @@
 from transformers import AddedToken, AutoConfig, AutoModelForCausalLM, AutoTokenizer, get_scheduler, set_seed
 from transformers.trainer_utils import IntervalStrategy
 
-import wandb
 from bsmetadata.input_pipeline import DataConfig, get_dataloaders
 
 
@@ -297,7 +297,10 @@ def main(args: CFG) -> None:
             model, optimizer, dummy_dataloader, scheduler
         )
     else:
-        format_fn = lambda x: x
+
+        def format_fn(x):
+            return x
+
     train_dataloader, eval_dataloaders = get_dataloaders(tokenizer, args.data_config)
 
     # Prepare everything
@@ -409,7 +412,7 @@ def save(path):
     step = 0
     model.train()
     # for epoch in range(args.num_train_epochs):
-    finished = False
+    # finished = False
     if not args.data_config.streaming:
         metrics_logger.log({"train_dataloader_length": len(train_dataloader)})
 
@@ -486,7 +489,7 @@ def get_data_iter():
            evaluate_multiple_dateloaders(eval_dataloaders)
 
        if completed_steps >= args.max_train_steps:
-            finished = True
+            # finished = True
            break
     metrics_logger.close()
     logger.info("Training finished")
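The remaining train.py changes are lint fixes rather than behavior changes: binding a lambda to a name is discouraged by flake8 (rule E731) and is rewritten as a plain def, and the finished flag, which appears to be assigned but never read, is commented out. Whether flake8 is the exact checker used here is an assumption. A minimal sketch of the lambda-to-def rewrite, with an illustrative assertion added for this sketch only:

# Before (the style flake8 E731 flags): a lambda bound to a name.
format_fn = lambda x: x  # noqa: E731  (kept only for contrast)


# After (as in the commit): the same identity function written as a def.
def format_fn(x):
    return x


assert format_fn("batch") == "batch"  # both forms simply return their input unchanged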
