Skip to content

Commit 53b68c0

Browse files
committed
Create datasets.py
1 parent e106be4 commit 53b68c0

File tree

1 file changed

+65
-0
lines changed

1 file changed

+65
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
# --------------------------------------------------------
7+
# References:
8+
# DeiT: https://github.com/facebookresearch/deit
9+
# --------------------------------------------------------
10+
11+
import os
12+
import PIL
13+
14+
from torchvision import datasets, transforms
15+
16+
from timm.data import create_transform
17+
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
18+
19+
20+
def build_dataset(is_train, args):
21+
transform = build_transform(is_train, args)
22+
23+
root = os.path.join(args.data_path, "train" if is_train else "val")
24+
dataset = datasets.ImageFolder(root, transform=transform)
25+
26+
return dataset
27+
28+
29+
def build_transform(is_train, args):
30+
mean = IMAGENET_DEFAULT_MEAN
31+
std = IMAGENET_DEFAULT_STD
32+
# train transform
33+
if is_train:
34+
# this should always dispatch to transforms_imagenet_train
35+
transform = create_transform(
36+
input_size=args.input_size,
37+
is_training=True,
38+
color_jitter=args.color_jitter,
39+
auto_augment=args.aa,
40+
interpolation="bicubic",
41+
re_prob=args.reprob,
42+
re_mode=args.remode,
43+
re_count=args.recount,
44+
mean=mean,
45+
std=std,
46+
)
47+
return transform
48+
49+
# eval transform
50+
t = []
51+
if args.input_size <= 224:
52+
crop_pct = 224 / 232
53+
else:
54+
crop_pct = 1.0
55+
size = int(args.input_size / crop_pct)
56+
t.append(
57+
transforms.Resize(
58+
size, interpolation=PIL.Image.BICUBIC
59+
), # to maintain same ratio w.r.t. 224 images
60+
)
61+
t.append(transforms.CenterCrop(args.input_size))
62+
63+
t.append(transforms.ToTensor())
64+
t.append(transforms.Normalize(mean, std))
65+
return transforms.Compose(t)

0 commit comments

Comments
 (0)