
Commit 9da7e3a

Add crop_mode for pretrained config / image transforms. Add support for dynamo compilation to benchmark/train/validate
1 parent 8fca002 commit 9da7e3a

File tree

10 files changed: +310 −55 lines


benchmark.py

Lines changed: 30 additions & 8 deletions
@@ -56,6 +56,13 @@
 except ImportError as e:
     has_functorch = False
 
+try:
+    import torch._dynamo
+    has_dynamo = True
+except ImportError:
+    has_dynamo = False
+    pass
+
 
 if torch.cuda.is_available():
     torch.backends.cuda.matmul.allow_tf32 = True
@@ -106,13 +113,19 @@
                     help='Numeric precision. One of (amp, float32, float16, bfloat16, tf32)')
 parser.add_argument('--fuser', default='', type=str,
                     help="Select jit fuser. One of ('', 'te', 'old', 'nvfuser')")
+parser.add_argument('--dynamo-backend', default=None, type=str,
+                    help="Select dynamo backend. Default: None")
+parser.add_argument('--fast-norm', default=False, action='store_true',
+                    help='enable experimental fast-norm')
+
+# codegen (model compilation) options
 scripting_group = parser.add_mutually_exclusive_group()
 scripting_group.add_argument('--torchscript', dest='torchscript', action='store_true',
-                    help='convert model torchscript for inference')
+                             help='convert model torchscript for inference')
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
-                    help="Enable AOT Autograd support. (It's recommended to use this option with `--fuser nvfuser` together)")
-scripting_group.add_argument('--fast-norm', default=False, action='store_true',
-                    help='enable experimental fast-norm')
+                             help="Enable AOT Autograd optimization.")
+scripting_group.add_argument('--dynamo', default=False, action='store_true',
+                             help="Enable Dynamo optimization.")
 
 # train optimizer parameters
 parser.add_argument('--opt', default='sgd', type=str, metavar='OPTIMIZER',
@@ -206,6 +219,8 @@ def __init__(
             device='cuda',
             torchscript=False,
             aot_autograd=False,
+            dynamo=False,
+            dynamo_backend=None,
             precision='float32',
             fuser='',
             num_warm_iter=10,
@@ -241,14 +256,21 @@ def __init__(
         _logger.info('Model %s created, param count: %d' % (model_name, self.param_count))
 
         data_config = resolve_data_config(kwargs, model=self.model, use_test_size=not use_train_size)
+        self.input_size = data_config['input_size']
+        self.batch_size = kwargs.pop('batch_size', 256)
+
         self.scripted = False
         if torchscript:
             self.model = torch.jit.script(self.model)
             self.scripted = True
-        self.input_size = data_config['input_size']
-        self.batch_size = kwargs.pop('batch_size', 256)
-
-        if aot_autograd:
+        elif dynamo:
+            assert has_dynamo, "torch._dynamo is needed for --dynamo"
+            torch._dynamo.reset()
+            if dynamo_backend is not None:
+                self.model = torch._dynamo.optimize(dynamo_backend)(self.model)
+            else:
+                self.model = torch._dynamo.optimize()(self.model)
+        elif aot_autograd:
            assert has_functorch, "functorch is needed for --aot-autograd"
            self.model = memory_efficient_fusion(self.model)
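The dynamo branch is a runtime wrapper rather than an ahead-of-time compile: `torch._dynamo.optimize()` returns a decorator that is applied to the model, with an optional backend string supplied via `--dynamo-backend`. A minimal standalone sketch of the same pattern (the toy model and the 'inductor' backend choice are illustrative assumptions, not from the commit):

import torch
import torch._dynamo  # assumes a dynamo-era PyTorch build, mirroring the has_dynamo guard

model = torch.nn.Linear(128, 10).eval()  # placeholder model, not from the commit

torch._dynamo.reset()  # clear prior compilation state, as BenchmarkRunner does
# With no argument, optimize() uses dynamo's default backend; a string such as
# 'inductor' selects a specific backend (what --dynamo-backend feeds through).
opt_model = torch._dynamo.optimize('inductor')(model)

with torch.no_grad():
    out = opt_model(torch.randn(8, 128))
print(out.shape)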

timm/data/config.py

Lines changed: 17 additions & 3 deletions
@@ -5,9 +5,15 @@
 _logger = logging.getLogger(__name__)
 
 
-def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, verbose=False):
+def resolve_data_config(
+        args,
+        default_cfg=None,
+        model=None,
+        use_test_size=False,
+        verbose=False
+):
     new_config = {}
-    default_cfg = default_cfg
+    default_cfg = default_cfg or {}
     if not default_cfg and model is not None and hasattr(model, 'default_cfg'):
         default_cfg = model.default_cfg
 
@@ -63,7 +69,7 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v
     elif default_cfg.get('std', None):
         new_config['std'] = default_cfg['std']
 
-    # resolve default crop percentage
+    # resolve default inference crop
     crop_pct = DEFAULT_CROP_PCT
     if args.get('crop_pct', None):
         crop_pct = args['crop_pct']
@@ -74,6 +80,14 @@ def resolve_data_config(args, default_cfg={}, model=None, use_test_size=False, v
         crop_pct = default_cfg['crop_pct']
     new_config['crop_pct'] = crop_pct
 
+    # resolve default crop mode
+    crop_mode = DEFAULT_CROP_MODE
+    if args.get('crop_mode', None):
+        crop_mode = args['crop_mode']
+    elif default_cfg.get('crop_mode', None):
+        crop_mode = default_cfg['crop_mode']
+    new_config['crop_mode'] = crop_mode
+
     if verbose:
         _logger.info('Data processing configuration for current model + dataset:')
         for n, v in new_config.items():
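`crop_mode` resolves with the same precedence as `crop_pct`: an explicit argument wins, then the model's pretrained `default_cfg`, then the library default (`'center'`). A hedged usage sketch (the model name and the `'squash'` mode value are illustrative assumptions):

from timm import create_model
from timm.data import resolve_data_config

model = create_model('resnet50', pretrained=False)  # any timm model name works here

# No override: crop_mode falls back to the model default_cfg if present,
# otherwise DEFAULT_CROP_MODE ('center').
cfg = resolve_data_config({}, model=model)
print(cfg['crop_pct'], cfg['crop_mode'])

# An explicit argument takes precedence over the pretrained config.
cfg = resolve_data_config({'crop_mode': 'squash'}, model=model)
assert cfg['crop_mode'] == 'squash'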

timm/data/constants.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 DEFAULT_CROP_PCT = 0.875
+DEFAULT_CROP_MODE = 'center'
 IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
 IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
 IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)

timm/data/loader.py

Lines changed: 2 additions & 0 deletions
@@ -211,6 +211,7 @@ def create_loader(
         num_workers=1,
         distributed=False,
         crop_pct=None,
+        crop_mode=None,
         collate_fn=None,
         pin_memory=False,
         fp16=False,  # deprecated, use img_dtype
@@ -240,6 +241,7 @@ def create_loader(
         mean=mean,
         std=std,
         crop_pct=crop_pct,
+        crop_mode=crop_mode,
         tf_preprocessing=tf_preprocessing,
         re_prob=re_prob,
         re_mode=re_mode,
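`create_loader` simply forwards the new argument into the eval transform pipeline alongside `crop_pct`. A hedged sketch of the call (the dataset root and size values are placeholders):

from timm.data import create_dataset, create_loader

# Hypothetical ImageFolder-style dataset; the root path is a placeholder.
dataset = create_dataset('', root='/data/imagenet', split='validation')

loader = create_loader(
    dataset,
    input_size=(3, 224, 224),
    batch_size=64,
    is_training=False,
    use_prefetcher=False,  # keep the example CPU-only
    crop_pct=0.95,
    crop_mode='center',  # the new pass-through added by this commit
)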

timm/data/tf_preprocessing.py

Lines changed: 2 additions & 1 deletion
@@ -22,12 +22,13 @@
 # limitations under the License.
 # ==============================================================================
 """ImageNet preprocessing for MnasNet."""
-import tensorflow as tf
+import tensorflow.compat.v1 as tf
 import numpy as np
 
 IMAGE_SIZE = 224
 CROP_PADDING = 32
 
+tf.compat.v1.disable_eager_execution()
 
 def distorted_bounding_box_crop(image_bytes,
                                 bbox,
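Switching the import to `tensorflow.compat.v1` and disabling eager execution keeps this file's TF1-style graph preprocessing working under a TF2 install. A minimal sketch of the graph-mode pattern the file depends on (not from the commit):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # placeholders and sessions require graph mode

x = tf.placeholder(tf.float32, shape=(None,))
y = x * 2.0  # stands in for the image decode/crop graph built in this file

with tf.Session() as sess:
    print(sess.run(y, feed_dict={x: [1.0, 2.0]}))  # [2.0, 4.0]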

timm/data/transforms.py

Lines changed: 148 additions & 3 deletions
@@ -1,3 +1,9 @@
+import math
+import numbers
+import random
+import warnings
+from typing import List, Sequence
+
 import torch
 import torchvision.transforms.functional as F
 try:
@@ -6,9 +12,6 @@
 except ImportError:
     has_interpolation_mode = False
 from PIL import Image
-import warnings
-import math
-import random
 import numpy as np
 
 
@@ -96,6 +99,19 @@ def interp_mode_to_str(mode):
 _RANDOM_INTERPOLATION = (str_to_interp_mode('bilinear'), str_to_interp_mode('bicubic'))
 
 
+def _setup_size(size, error_msg):
+    if isinstance(size, numbers.Number):
+        return int(size), int(size)
+
+    if isinstance(size, Sequence) and len(size) == 1:
+        return size[0], size[0]
+
+    if len(size) != 2:
+        raise ValueError(error_msg)
+
+    return size
+
+
 class RandomResizedCropAndInterpolation:
     """Crop the given PIL Image to random size and aspect ratio with random interpolation.
 
@@ -195,3 +211,132 @@ def __repr__(self):
         format_string += ', ratio={0}'.format(tuple(round(r, 4) for r in self.ratio))
         format_string += ', interpolation={0})'.format(interpolate_str)
         return format_string
+
+
+def center_crop_or_pad(img: torch.Tensor, output_size: List[int], fill=0) -> torch.Tensor:
+    """Center crops and/or pads the given image.
+
+    If the image is a torch Tensor, it is expected to have [..., H, W] shape,
+    where ... means an arbitrary number of leading dimensions. If image size is
+    smaller than output size along any edge, the image is padded with `fill`
+    and then center cropped.
+
+    Args:
+        img (PIL Image or Tensor): Image to be cropped.
+        output_size (sequence or int): (height, width) of the crop box. If int
+            or sequence with a single int, it is used for both directions.
+        fill (int, Tuple[int]): Padding color.
+
+    Returns:
+        PIL Image or Tensor: Cropped image.
+    """
+    if isinstance(output_size, numbers.Number):
+        output_size = (int(output_size), int(output_size))
+    elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
+        output_size = (output_size[0], output_size[0])
+
+    _, image_height, image_width = F.get_dimensions(img)
+    crop_height, crop_width = output_size
+
+    if crop_width > image_width or crop_height > image_height:
+        padding_ltrb = [
+            (crop_width - image_width) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height) // 2 if crop_height > image_height else 0,
+            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+        ]
+        img = F.pad(img, padding_ltrb, fill=fill)
+        _, image_height, image_width = F.get_dimensions(img)
+        if crop_width == image_width and crop_height == image_height:
+            return img
+
+    crop_top = int(round((image_height - crop_height) / 2.0))
+    crop_left = int(round((image_width - crop_width) / 2.0))
+    return F.crop(img, crop_top, crop_left, crop_height, crop_width)
+
+
+class CenterCropOrPad(torch.nn.Module):
+    """Crops the given image at the center.
+
+    If the image is a torch Tensor, it is expected to have [..., H, W] shape,
+    where ... means an arbitrary number of leading dimensions. If image size is
+    smaller than output size along any edge, the image is padded with `fill`
+    and then center cropped.
+
+    Args:
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made. If provided a sequence of length 1, it will be interpreted as
+            (size[0], size[0]).
+    """
+
+    def __init__(self, size, fill=0):
+        super().__init__()
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+        self.fill = fill
+
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be cropped.
+
+        Returns:
+            PIL Image or Tensor: Cropped image.
+        """
+        return center_crop_or_pad(img, self.size, fill=self.fill)
+
+    def __repr__(self) -> str:
+        return f"{self.__class__.__name__}(size={self.size})"
+
+
+class ResizeKeepRatio:
+    """ Resize and Keep Ratio
+    """
+
+    def __init__(
+            self,
+            size,
+            longest=0.,
+            interpolation='bilinear',
+            fill=0,
+    ):
+        if isinstance(size, (list, tuple)):
+            self.size = tuple(size)
+        else:
+            self.size = (size, size)
+        self.interpolation = str_to_interp_mode(interpolation)
+        self.longest = float(longest)
+        self.fill = fill
+
+    @staticmethod
+    def get_params(img, target_size, longest):
+        """Get parameters
+
+        Args:
+            img (PIL Image): Image to be resized.
+            target_size (Tuple[int, int]): Size of output
+        Returns:
+            list: size (h, w) to be passed to ``resize``
+        """
+        source_size = img.size[::-1]  # h, w
+        h, w = source_size
+        target_h, target_w = target_size
+        ratio_h = h / target_h
+        ratio_w = w / target_w
+        # blend between matching the target's longest edge (ratio max) and its
+        # shortest edge (ratio min), controlled by `longest` in [0, 1]
+        ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (1. - longest)
+        size = [round(x / ratio) for x in source_size]
+        return size
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be resized.
+
+        Returns:
+            PIL Image: Resized image.
+        """
+        size = self.get_params(img, self.size, self.longest)
+        img = F.resize(img, size, self.interpolation)
+        return img
+
+    def __repr__(self):
+        interpolate_str = interp_mode_to_str(self.interpolation)
+        format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
+        format_string += f', interpolation={interpolate_str}'
+        format_string += f', longest={self.longest:.3f})'
+        return format_string
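`ResizeKeepRatio` and `CenterCropOrPad` together give an aspect-preserving alternative to plain resize + center crop: with `longest=1.0` the whole image is scaled to fit inside the target and the short side is padded out; with `longest=0.0` the short edge is matched and the overflow cropped. A small usage sketch (the composition and sizes are illustrative, not taken from the commit):

from PIL import Image
from torchvision import transforms

from timm.data.transforms import ResizeKeepRatio, CenterCropOrPad

eval_tf = transforms.Compose([
    ResizeKeepRatio(224, longest=1.),  # fit entire image within 224x224, keeping aspect
    CenterCropOrPad(224, fill=0),      # pad the short side out to exactly 224x224
    transforms.ToTensor(),
])

img = Image.new('RGB', (320, 180))  # placeholder image
x = eval_tf(img)
print(x.shape)  # torch.Size([3, 224, 224])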
