Skip to content

Commit 15ebc2d

Browse files
authored
add augmix (#362)
* add augmix * fix params
1 parent 53563d6 commit 15ebc2d

File tree

4 files changed

+44
-6
lines changed

4 files changed

+44
-6
lines changed

efficientdet/aug/autoaugment.py

Lines changed: 39 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@
2626
import math
2727
from absl import logging
2828
import tensorflow.compat.v1 as tf
29+
import tensorflow_probability as tfp
2930
import hparams_config
31+
import numpy as np
3032

3133
try:
3234
# addon image_ops are simpler, but they have some issues on GPU and TPU.
@@ -1545,9 +1547,30 @@ def select_and_apply_random_policy(policies, image, bboxes):
15451547
lambda: (image, bboxes))
15461548
return (image, bboxes)
15471549

1550+
def select_and_apply_random_policy_augmix(policies, image, bboxes, mixture_width=3, mixture_depth=-1, alpha=1):
  """Apply AugMix-style mixing of augmentation chains to `image`.

  One policy index is drawn at random. `mixture_width` augmentation chains are
  then built; each chain applies the selected policy `depth` times starting
  from the original `image`. The chain outputs are blended with Dirichlet
  weights, and that blend is interpolated with the original image using a
  Beta-distributed coefficient.

  Args:
    policies: Sequence of callables; each is called as `policy(image, bboxes)`
      and must return an `(image, bboxes)` tuple.
    image: Image tensor to augment. It is cast to float32 for mixing and the
      result is cast to uint8 (presumably the input is uint8 as well; confirm
      against callers).
    bboxes: Bounding-box tensor passed to and returned by every policy.
    mixture_width: Number of augmentation chains that are mixed.
    mixture_depth: Number of policy applications per chain. Values <= 0 enable
      stochastic depth, drawn uniformly from [1, 3] for every chain.
    alpha: Concentration parameter shared by the Dirichlet (chain weights) and
      Beta (final interpolation) distributions.

  Returns:
    A tuple `(mixed_image, bboxes)` where `mixed_image` is a uint8 tensor.
  """
  # Upper bound (inclusive) for the stochastic depth of a chain.
  max_random_depth = 3

  policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32)
  ws = tfp.distributions.Dirichlet([alpha] * mixture_width).sample()
  m = tfp.distributions.Beta(alpha, alpha).sample()
  mix = tf.zeros_like(image, dtype=tf.float32)

  # NOTE(review): `bboxes` is threaded through every chain, so the returned
  # boxes reflect all chains applied in sequence while the returned image is a
  # blend of the chains. The same policy index is also shared by all chains.
  # Both behaviors are kept from the original implementation; confirm they are
  # intended.
  for j in range(mixture_width):
    aug_image = image
    if mixture_depth > 0:
      num_steps = mixture_depth
      depth = None
    else:
      # Sample the depth with a TF op so it is re-drawn for every example.
      # A Python/NumPy draw would be evaluated once while the graph is built
      # and would then stay fixed for the lifetime of the input pipeline.
      num_steps = max_random_depth
      depth = tf.random_uniform(
          [], minval=1, maxval=max_random_depth + 1, dtype=tf.int32)
    for step in range(num_steps):
      # Note that using tf.case instead of tf.conds would result in
      # significantly larger graphs and would even break export for some
      # larger policies.
      for (i, policy) in enumerate(policies):
        apply_policy = tf.equal(i, policy_to_select)
        if depth is not None:
          # Steps at or beyond the sampled depth leave the chain untouched.
          apply_policy = tf.logical_and(apply_policy, tf.less(step, depth))
        # Bind the loop variables as defaults so each branch sees the values
        # of this iteration regardless of when the callable is invoked.
        aug_image, bboxes = tf.cond(
            apply_policy,
            lambda selected_policy=policy, img=aug_image, boxes=bboxes: (
                selected_policy(img, boxes)),
            lambda img=aug_image, boxes=bboxes: (img, boxes))
    mix += ws[j] * tf.cast(aug_image, tf.float32)

  mixed = tf.cast((1 - m) * tf.cast(image, tf.float32) + m * mix, tf.uint8)
  return (mixed, bboxes)
15481570

15491571
def build_and_apply_nas_policy(policies, image, bboxes,
1550-
augmentation_hparams):
1572+
augmentation_hparams, use_augmix=False,
1573+
mixture_width=3, mixture_depth=-1, alpha=1):
15511574
"""Build a policy from the given policies passed in and apply to image.
15521575
15531576
Args:
@@ -1559,6 +1582,11 @@ def build_and_apply_nas_policy(policies, image, bboxes,
15591582
bboxes: tf.Tensor of shape [N, 4] representing ground truth boxes that are
15601583
normalized between [0, 1].
15611584
augmentation_hparams: Hparams associated with the NAS learned policy.
1585+
use_augmix: Whether to use AugMix (https://arxiv.org/pdf/1912.02781.pdf).
1586+
mixture_width: Width of augmentation chain.
1587+
mixture_depth: Depth of augmentation chain. -1 enables stochastic depth uniformly
1588+
from [1, 3]
1589+
alpha: Probability coefficient for Beta and Dirichlet distributions.
15621590
15631591
Returns:
15641592
A version of image that now has data augmentation applied to it based on
@@ -1592,14 +1620,19 @@ def final_policy(image_, bboxes_):
15921620
return image_, bboxes_
15931621
return final_policy
15941622
tf_policies.append(make_final_policy(tf_policy))
1623+
if use_augmix:
1624+
augmented_images, augmented_bboxes = select_and_apply_random_policy_augmix(
1625+
tf_policies, image, bboxes, mixture_width, mixture_depth, alpha)
1626+
else:
1627+
augmented_images, augmented_bboxes = select_and_apply_random_policy(
1628+
tf_policies, image, bboxes)
15951629

1596-
augmented_images, augmented_bboxes = select_and_apply_random_policy(
1597-
tf_policies, image, bboxes)
15981630
# If no bounding boxes were specified, then just return the images.
15991631
return (augmented_images, augmented_bboxes)
16001632

16011633

1602-
def distort_image_with_autoaugment(image, bboxes, augmentation_name):
1634+
def distort_image_with_autoaugment(image, bboxes, augmentation_name, use_augmix=False,
1635+
mixture_width=3, mixture_depth=-1, alpha=1):
16031636
"""Applies the AutoAugment policy to `image` and `bboxes`.
16041637
16051638
Args:
@@ -1633,4 +1666,5 @@ def distort_image_with_autoaugment(image, bboxes, augmentation_name):
16331666
cutout_bbox_const=50,
16341667
translate_bbox_const=120))
16351668

1636-
return build_and_apply_nas_policy(policy, image, bboxes, augmentation_hparams)
1669+
return build_and_apply_nas_policy(policy, image, bboxes, augmentation_hparams,
1670+
use_augmix, mixture_width, mixture_depth, alpha)

efficientdet/aug/autoaugment_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def test_autoaugment_policy(self):
3030
image = tf.placeholder(tf.uint8, shape=[640, 640, 3])
3131
bboxes = tf.placeholder(tf.float32, shape=[4, 4])
3232
autoaugment.distort_image_with_autoaugment(image, bboxes, 'test')
33+
autoaugment.distort_image_with_autoaugment(image, bboxes, 'test', use_augmix=True)
3334

3435

3536
if __name__ == '__main__':

efficientdet/dataloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ def _dataset_parser(value):
277277
if params.get('autoaugment_policy', None) and self._is_training:
278278
from aug import autoaugment # pylint: disable=g-import-not-at-top
279279
image, boxes = autoaugment.distort_image_with_autoaugment(
280-
image, boxes, params['autoaugment_policy'])
280+
image, boxes, params['autoaugment_policy'], params['use_augmix'], *params['augmix_params'])
281281

282282
input_processor = DetectionInputProcessor(
283283
image, params['image_size'], boxes, classes)

efficientdet/hparams_config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,9 @@ def default_detection_configs():
177177
h.train_scale_min = 0.1
178178
h.train_scale_max = 2.0
179179
h.autoaugment_policy = None
180+
h.use_augmix = False
181+
# mixture_width, mixture_depth, alpha
182+
h.augmix_params = (3, -1, 1)
180183

181184
# dataset specific parameters
182185
h.num_classes = 90

0 commit comments

Comments
 (0)