
Commit 193f8ee

Fix typo in val_batch_size and remove unused imports
1 parent 82c78b5 commit 193f8ee

File tree

1 file changed: +2 -13 lines changed

3.test_cases/pytorch/FSDP/src/train.py

Lines changed: 2 additions & 13 deletions
@@ -1,35 +1,24 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: MIT-0
 
-import datetime
 import functools
 import math
-import re
 import time
 
-import numpy as np
 import torch
 from torch import optim
 import torch.distributed as dist
 import torch.utils.data
 
-import transformers
-from transformers import AutoModelForCausalLM, AutoTokenizer
-from datasets import load_dataset
-
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp import MixedPrecision
 from torch.distributed.fsdp import ShardingStrategy
 from torch.distributed.fsdp import CPUOffload
-from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
-from torch.utils.data import DataLoader
+from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
 
-from model_utils.concat_dataset import ConcatTokensDataset
 from model_utils.train_utils import (get_model_config,
                                      compute_num_params,
                                      get_transformer_layer,
-                                     get_sharding_strategy,
-                                     get_backward_fetch_policy,
                                      apply_activation_checkpoint,
                                      get_param_groups_by_weight_decay,
                                      get_logger,
@@ -268,7 +257,7 @@ def main(args):
     val_dataloader = create_streaming_dataloader(args.dataset,
                                                  args.tokenizer,
                                                  name=args.dataset_config_name,
-                                                 batch_size=args.train_batch_size,
+                                                 batch_size=args.val_batch_size,
                                                  split='validation')
 
     train(model,
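
Note on the import change: the commit narrows the wrap-policy imports to transformer_auto_wrap_policy alone, dropping the unused size_based_auto_wrap_policy. As a minimal, hedged sketch (not the repository's exact code), this is roughly how that policy is typically bound to a transformer layer class and handed to FSDP; TransformerBlock and wrap_model_with_fsdp below are illustrative placeholders, not names taken from train.py, and a process group is assumed to be initialized already (e.g. via torchrun).

# Sketch only: illustrates the retained transformer_auto_wrap_policy import.
# TransformerBlock and wrap_model_with_fsdp are hypothetical names.
import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class TransformerBlock(nn.Module):
    # Placeholder for the model's real transformer layer class.
    def __init__(self, dim: int = 256):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, num_heads=4, batch_first=True)
        self.mlp = nn.Linear(dim, dim)


def wrap_model_with_fsdp(model: nn.Module) -> FSDP:
    # Bind the layer class so FSDP wraps each TransformerBlock as its own
    # shard unit instead of relying on a parameter-count heuristic.
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerBlock},
    )
    return FSDP(model, auto_wrap_policy=auto_wrap_policy)

With the layer class bound this way, the size-based policy removed by this commit is not needed, which is consistent with dropping its import.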

0 commit comments
