@@ -13,6 +13,7 @@
 import hydra
 import torch
 import torch.nn.functional as F
+import wandb
 from accelerate import Accelerator
 from accelerate.utils import DistributedType, DummyOptim, DummyScheduler
 from hydra.core.config_store import ConfigStore
@@ -22,7 +23,6 @@
 from transformers import AddedToken, AutoConfig, AutoModelForCausalLM, AutoTokenizer, get_scheduler, set_seed
 from transformers.trainer_utils import IntervalStrategy
 
-import wandb
 from bsmetadata.input_pipeline import DataConfig, get_dataloaders
 
 
@@ -297,7 +297,10 @@ def main(args: CFG) -> None:
             model, optimizer, dummy_dataloader, scheduler
         )
     else:
-        format_fn = lambda x: x
+
+        def format_fn(x):
+            return x
+
     train_dataloader, eval_dataloaders = get_dataloaders(tokenizer, args.data_config)
 
     # Prepare everything
@@ -409,7 +412,7 @@ def save(path):
     step = 0
     model.train()
     # for epoch in range(args.num_train_epochs):
-    finished = False
+    # finished = False
     if not args.data_config.streaming:
         metrics_logger.log({"train_dataloader_length": len(train_dataloader)})
 
@@ -486,7 +489,7 @@ def get_data_iter():
             evaluate_multiple_dateloaders(eval_dataloaders)
 
         if completed_steps >= args.max_train_steps:
-            finished = True
+            # finished = True
             break
     metrics_logger.close()
     logger.info("Training finished")
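
Note on the hunk at old line 300: binding a lambda to a name is what flake8 flags as E731, so the diff swaps it for an equivalent named function. A minimal sketch of the pattern, taking only the identity behaviour from the diff (the assert is purely illustrative):

def format_fn(x):
    # Identity pass-through: the input is returned unchanged when no
    # formatting is applied.
    return x

# Unlike a lambda bound to a name, a def carries a useful __name__
# for logging and tracebacks.
assert format_fn.__name__ == "format_fn"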