2626import math
2727from absl import logging
2828import tensorflow .compat .v1 as tf
29+ import tensorflow_probability as tfp
2930import hparams_config
31+ import numpy as np
3032
3133try :
3234 # addon image_ops are simpler, but they have some issues on GPU and TPU.
@@ -1545,9 +1547,30 @@ def select_and_apply_random_policy(policies, image, bboxes):
15451547 lambda : (image , bboxes ))
15461548 return (image , bboxes )
15471549
def select_and_apply_random_policy_augmix(policies, image, bboxes,
                                          mixture_width=3, mixture_depth=-1,
                                          alpha=1):
  """Applies one randomly selected policy to `image` in AugMix fashion.

  A single policy index is drawn uniformly from `policies`. `mixture_width`
  augmentation chains are then built; each chain starts from the unmodified
  `image` and applies the selected policy several times in sequence. The chain
  outputs are blended with weights sampled from a Dirichlet distribution, and
  that blend is interpolated with the original image using a weight sampled
  from a Beta distribution (AugMix, https://arxiv.org/pdf/1912.02781.pdf).

  Args:
    policies: Sequence of callables `policy(image, bboxes)` returning an
      `(image, bboxes)` pair.
    image: Image tensor. The result is clipped to [0, 255] and cast to
      tf.uint8, so pixel values in that range are assumed.
    bboxes: Bounding-box tensor handed to every policy application.
    mixture_width: Python int, number of augmentation chains to blend.
    mixture_depth: Python int, number of policy applications per chain. If it
      is not positive, the depth is sampled per chain uniformly from
      {1, 2, 3} when the graph runs.
    alpha: Concentration used for both the Dirichlet and the Beta
      distributions.

  Returns:
    A tuple `(mixed_image, bboxes)` where `mixed_image` is a tf.uint8 tensor.
  """
  max_stochastic_depth = 3
  stochastic_depth = mixture_depth <= 0
  num_steps = max_stochastic_depth if stochastic_depth else mixture_depth

  policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32)
  chain_weights = tfp.distributions.Dirichlet(
      [alpha] * mixture_width).sample()
  blend_weight = tfp.distributions.Beta(alpha, alpha).sample()

  original = tf.cast(image, tf.float32)
  mix = tf.zeros_like(image, dtype=tf.float32)

  # NOTE(review): `bboxes` is threaded through every chain, so the returned
  # boxes reflect all policy applications of all chains while the returned
  # image is a blend of the chains and the original. Confirm this is intended.
  for chain in range(mixture_width):
    aug_image = image
    if stochastic_depth:
      # Sample the depth with a TF op so it varies per example. A Python-side
      # random draw would be evaluated once at graph-construction time and then
      # stay fixed for every example.
      depth = tf.random_uniform(
          [], minval=1, maxval=max_stochastic_depth + 1, dtype=tf.int32)
    for step in range(num_steps):
      # Note that using tf.case instead of tf.conds would result in
      # significantly larger graphs and would even break export for some
      # larger policies.
      for (i, policy) in enumerate(policies):
        apply_policy = tf.equal(i, policy_to_select)
        if stochastic_depth:
          # The chain is unrolled to the maximum depth; steps beyond the
          # sampled depth pass the tensors through unchanged.
          apply_policy = tf.logical_and(apply_policy, tf.less(step, depth))
        # Bind the loop variables as defaults so each branch function uses the
        # tensors of this iteration regardless of when it is invoked.
        aug_image, bboxes = tf.cond(
            apply_policy,
            lambda fn=policy, img=aug_image, boxes=bboxes: fn(img, boxes),
            lambda img=aug_image, boxes=bboxes: (img, boxes))
    mix += chain_weights[chain] * tf.cast(aug_image, tf.float32)

  mixed = (1 - blend_weight) * original + blend_weight * mix
  # Clip before the cast so float round-off can never leave the uint8 range.
  mixed = tf.cast(tf.clip_by_value(mixed, 0.0, 255.0), tf.uint8)
  return (mixed, bboxes)
15481570
15491571def build_and_apply_nas_policy (policies , image , bboxes ,
1550- augmentation_hparams ):
1572+ augmentation_hparams , use_augmix = False ,
1573+ mixture_width = 3 , mixture_depth = - 1 , alpha = 1 ):
15511574 """Build a policy from the given policies passed in and apply to image.
15521575
15531576 Args:
@@ -1559,6 +1582,11 @@ def build_and_apply_nas_policy(policies, image, bboxes,
15591582 bboxes: tf.Tensor of shape [N, 4] representing ground truth boxes that are
15601583 normalized between [0, 1].
15611584 augmentation_hparams: Hparams associated with the NAS learned policy.
1585+ use_augmix: whether to use AugMix (https://arxiv.org/pdf/1912.02781.pdf).
1586+ mixture_width: Width of augmentation chain, i.e. how many chains are mixed.
1587+ mixture_depth: Depth of augmentation chain. -1 enables stochastic depth
1588+ uniformly from [1, 3].
1589+ alpha: Probability coefficient for Beta and Dirichlet distributions.
15621590
15631591 Returns:
15641592 A version of image that now has data augmentation applied to it based on
@@ -1592,14 +1620,19 @@ def final_policy(image_, bboxes_):
15921620 return image_ , bboxes_
15931621 return final_policy
15941622 tf_policies .append (make_final_policy (tf_policy ))
1623+ if use_augmix :
1624+ augmented_images , augmented_bboxes = select_and_apply_random_policy_augmix (
1625+ tf_policies , image , bboxes , mixture_width , mixture_depth , alpha )
1626+ else :
1627+ augmented_images , augmented_bboxes = select_and_apply_random_policy (
1628+ tf_policies , image , bboxes )
15951629
1596- augmented_images , augmented_bboxes = select_and_apply_random_policy (
1597- tf_policies , image , bboxes )
15981630 # If no bounding boxes were specified, then just return the images.
15991631 return (augmented_images , augmented_bboxes )
16001632
16011633
1602- def distort_image_with_autoaugment (image , bboxes , augmentation_name ):
1634+ def distort_image_with_autoaugment (image , bboxes , augmentation_name , use_augmix = False ,
1635+ mixture_width = 3 , mixture_depth = - 1 , alpha = 1 ):
16031636 """Applies the AutoAugment policy to `image` and `bboxes`.
16041637
16051638 Args:
@@ -1633,4 +1666,5 @@ def distort_image_with_autoaugment(image, bboxes, augmentation_name):
16331666 cutout_bbox_const = 50 ,
16341667 translate_bbox_const = 120 ))
16351668
1636- return build_and_apply_nas_policy (policy , image , bboxes , augmentation_hparams )
1669+ return build_and_apply_nas_policy (policy , image , bboxes , augmentation_hparams ,
1670+ use_augmix , mixture_width , mixture_depth , alpha )
0 commit comments