Skip to content

Commit 2da2c33

Browse files
committed
Force autoaugment to run on CPU.
Related issue: #378, #161
1 parent 3e7d7b7 commit 2da2c33

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

efficientdet/aug/autoaugment.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1498,31 +1498,31 @@ def _parse_policy_info(name, prob, level, replace_value, augmentation_hparams):
14981498
# Check to see if prob is passed into function. This is used for operations
14991499
# where we alter bboxes independently.
15001500
# pytype:disable=wrong-arg-types
1501-
if 'prob' in inspect.getargspec(func)[0]:
1501+
if 'prob' in inspect.getfullargspec(func)[0]:
15021502
args = tuple([prob] + list(args))
15031503
# pytype:enable=wrong-arg-types
15041504

15051505
# Add in replace arg if it is required for the function that is being called.
1506-
if 'replace' in inspect.getargspec(func)[0]:
1506+
if 'replace' in inspect.getfullargspec(func)[0]:
15071507
# Make sure replace is the final argument
1508-
assert 'replace' == inspect.getargspec(func)[0][-1]
1508+
assert 'replace' == inspect.getfullargspec(func)[0][-1]
15091509
args = tuple(list(args) + [replace_value])
15101510

15111511
# Add bboxes as the second positional argument for the function if it does
15121512
# not already exist.
1513-
if 'bboxes' not in inspect.getargspec(func)[0]:
1513+
if 'bboxes' not in inspect.getfullargspec(func)[0]:
15141514
func = bbox_wrapper(func)
15151515
return (func, prob, args)
15161516

15171517

15181518
def _apply_func_with_prob(func, image, args, prob, bboxes):
15191519
"""Apply `func` to image w/ `args` as input with probability `prob`."""
15201520
assert isinstance(args, tuple)
1521-
assert 'bboxes' == inspect.getargspec(func)[0][1]
1521+
assert 'bboxes' == inspect.getfullargspec(func)[0][1]
15221522

15231523
# If prob is a function argument, then this randomness is being handled
15241524
# inside the function, so make sure it is always called.
1525-
if 'prob' in inspect.getargspec(func)[0]:
1525+
if 'prob' in inspect.getfullargspec(func)[0]:
15261526
prob = 1.0
15271527

15281528
# Apply the function with probability `prob`.
@@ -1666,5 +1666,6 @@ def distort_image_with_autoaugment(image, bboxes, augmentation_name, use_augmix=
16661666
cutout_bbox_const=50,
16671667
translate_bbox_const=120))
16681668

1669-
return build_and_apply_nas_policy(policy, image, bboxes, augmentation_hparams,
1670-
use_augmix, mixture_width, mixture_depth, alpha)
1669+
with tf.device('/cpu:0'):
1670+
return build_and_apply_nas_policy(policy, image, bboxes,
1671+
augmentation_hparams, use_augmix, mixture_width, mixture_depth, alpha)

0 commit comments

Comments
 (0)