@@ -1,35 +1,24 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: MIT-0
 
-import datetime
 import functools
 import math
-import re
 import time
 
-import numpy as np
 import torch
 from torch import optim
 import torch.distributed as dist
 import torch.utils.data
 
-import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from datasets import load_dataset
-
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed.fsdp import CPUOffload
-from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
-from torch.utils.data import DataLoader
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
-from model_utils.concat_dataset import ConcatTokensDataset
 from model_utils.train_utils import (get_model_config,
                                      compute_num_params,
                                      get_transformer_layer,
-                                     get_sharding_strategy,
-                                     get_backward_fetch_policy,
                                      apply_activation_checkpoint,
                                      get_param_groups_by_weight_decay,
                                      get_logger,
@@ -268,7 +257,7 @@ def main(args):
     val_dataloader = create_streaming_dataloader(args.dataset,
                                                  args.tokenizer,
                                                  name=args.dataset_config_name,
-                                                 batch_size=args.train_batch_size,
+                                                 batch_size=args.val_batch_size,
                                                  split='validation')
 
     train(model,