
Commit 6adcdd0

gortizji authored and copybara-github committed
Allow modifying the width of the WideResNet models using a fractional value.
PiperOrigin-RevId: 475772983
1 parent b91d032 commit 6adcdd0

3 files changed: +29 −17 lines changed


uncertainty_baselines/datasets/augmix.py

Lines changed: 10 additions & 10 deletions

@@ -140,7 +140,6 @@ def mixup(batch_size, aug_params, images, labels):
     aug_params: Dict of data augmentation hyper parameters.
     images: A batch of images of shape [batch_size, ...]
     labels: A batch of labels of shape [batch_size, num_classes]
-
   Returns:
     A tuple of (images, labels) with the same dimensions as the input with
       Mixup regularization applied.

@@ -192,15 +191,16 @@ def mixup(batch_size, aug_params, images, labels):
     labels = tf.reshape(
         tf.tile(labels, [1, aug_count + 1]), [batch_size, aug_count + 1, -1])
     labels_mix = (
-        labels * mix_weight +
-        tf.gather(labels, mixup_index) * (1. - mix_weight))
+        labels * mix_weight + tf.gather(labels, mixup_index) *
+        (1. - mix_weight))
     labels_mix = tf.reshape(
         tf.transpose(labels_mix, [1, 0, 2]), [batch_size * (aug_count + 1), -1])
   else:
     labels_mix = (
-        labels * mix_weight +
-        tf.gather(labels, mixup_index) * (1. - mix_weight))
-  return images_mix, labels_mix
+        labels * mix_weight + tf.gather(labels, mixup_index) *
+        (1. - mix_weight))
+
+  return images_mix, labels_mix


 def adaptive_mixup(batch_size, aug_params, images, labels):
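
Note for readers skimming the re-wrapped lines above: the label mixing itself is unchanged. It is an ordinary Mixup step that blends each label with the label of a partner example selected by `mixup_index`, using a Beta-sampled weight. A minimal self-contained sketch of just that step (the batch, the alpha value, and the reversed partner index are illustrative, not values taken from augmix.py):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch_size, num_classes = 4, 3
alpha = 0.2  # illustrative Mixup concentration, not a value from this commit
labels = tf.one_hot([0, 1, 2, 0], num_classes)

# One Beta-sampled weight per example, broadcast across the class dimension.
mix_weight = tfd.Beta(alpha, alpha).sample([batch_size, 1])
# Partner examples; augmix.py uses a shuffled index, reversing is enough here.
mixup_index = tf.range(batch_size)[::-1]

# Same expression as the re-wrapped lines in the diff.
labels_mix = (
    labels * mix_weight + tf.gather(labels, mixup_index) *
    (1. - mix_weight))
print(labels_mix.shape)  # (4, 3)
```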
@@ -215,7 +215,6 @@ def adaptive_mixup(batch_size, aug_params, images, labels):
     aug_params: Dict of data augmentation hyper parameters.
     images: A batch of images of shape [batch_size, ...]
     labels: A batch of labels of shape [batch_size, num_classes]
-
   Returns:
     A tuple of (images, labels) with the same dimensions as the input with
       Mixup regularization applied.

@@ -229,8 +228,8 @@ def adaptive_mixup(batch_size, aug_params, images, labels):
   # Need to filter out elements in alpha which equal to 0.
   greater_zero_indicator = tf.cast(alpha > 0, alpha.dtype)
   less_one_indicator = tf.cast(alpha < 1, alpha.dtype)
-  valid_alpha_indicator = tf.cast(
-      greater_zero_indicator * less_one_indicator, tf.bool)
+  valid_alpha_indicator = tf.cast(greater_zero_indicator * less_one_indicator,
+                                  tf.bool)
   sampled_alpha = tf.where(valid_alpha_indicator, alpha, 0.1)
   mix_weight = tfd.Beta(sampled_alpha, sampled_alpha).sample()
   mix_weight = tf.where(valid_alpha_indicator, mix_weight, alpha)

@@ -253,4 +252,5 @@ def adaptive_mixup(batch_size, aug_params, images, labels):
   images_mix = (
       images * images_mix_weight + images[::-1] * (1. - images_mix_weight))
   labels_mix = labels * mix_weight + labels[::-1] * (1. - mix_weight)
-  return images_mix, labels_mix
+
+  return images_mix, labels_mix
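
The adaptive_mixup hunks above only re-wrap the alpha-filtering lines; the behavior stays the same: alphas strictly between 0 and 1 parameterize a Beta sample, while alphas of exactly 0 or 1 are passed through unchanged as the mix weight itself. A self-contained illustration of those lines (the alpha vector is made up for the example):

```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

alpha = tf.constant([0.0, 0.3, 1.0, 0.7])  # illustrative per-example alphas

greater_zero_indicator = tf.cast(alpha > 0, alpha.dtype)
less_one_indicator = tf.cast(alpha < 1, alpha.dtype)
valid_alpha_indicator = tf.cast(greater_zero_indicator * less_one_indicator,
                                tf.bool)
# Replace invalid alphas with a dummy 0.1 so Beta() never sees 0 or 1 ...
sampled_alpha = tf.where(valid_alpha_indicator, alpha, 0.1)
mix_weight = tfd.Beta(sampled_alpha, sampled_alpha).sample()
# ... then restore the original alpha as the weight for those entries.
mix_weight = tf.where(valid_alpha_indicator, mix_weight, alpha)
print(mix_weight.numpy())  # first and third entries are exactly 0.0 and 1.0
```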

uncertainty_baselines/datasets/base.py

Lines changed: 15 additions & 3 deletions

@@ -267,6 +267,7 @@ def _add_example_id(enumerate_id, example):
   def _load(self,
             *,
             preprocess_fn: Optional[PreProcessFn] = None,
+            process_batch_fn: Optional[PreProcessFn] = None,
             batch_size: int = -1) -> tf.data.Dataset:
     """Transforms the dataset from builder.as_dataset() to batch, repeat, etc.

@@ -278,6 +279,9 @@ def _load(self,
       preprocess_fn: an optional preprocessing function, if not provided then a
         subclass must define _create_process_example_fn() which will be used to
         preprocess the data.
+      process_batch_fn: an optional processing batch function, if not
+        provided then _create_process_batch_fn() will be used to generate the
+        function that will process a batch of data.
       batch_size: the batch size to use.

     Returns:

@@ -372,7 +376,8 @@ def _load(self,
     else:
       dataset = dataset.batch(batch_size, drop_remainder=self._drop_remainder)

-    process_batch_fn = self._create_process_batch_fn(batch_size)  # pylint: disable=assignment-from-none
+    if process_batch_fn is None:
+      process_batch_fn = self._create_process_batch_fn(batch_size)
     if process_batch_fn:
       dataset = dataset.map(
           process_batch_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)

@@ -406,6 +411,7 @@ def load(
       self,
       *,
       preprocess_fn: Optional[PreProcessFn] = None,
+      process_batch_fn: Optional[PreProcessFn] = None,
       batch_size: int = -1,
       strategy: Optional[tf.distribute.Strategy] = None) -> tf.data.Dataset:
     """Function definition to support multi-host dataset sharding.

@@ -431,11 +437,13 @@ def load(

     Args:
       preprocess_fn: see `load()`.
+      process_batch_fn: see `load()`.
       batch_size: the *global* batch size to use. This should equal
         `per_replica_batch_size * num_replica_in_sync`.
       strategy: the DistributionStrategy used to shard the dataset. Note that
         this is only required if TensorFlow for training, otherwise it can be
         ignored.
+
     Returns:
       A sharded dataset, with its seed combined with the per-host id.
     """

@@ -445,11 +453,15 @@ def _load_distributed(ctx: tf.distribute.InputContext):
             self._seed, ctx.input_pipeline_id)
         per_replica_batch_size = ctx.get_per_replica_batch_size(batch_size)
         return self._load(
-            preprocess_fn=preprocess_fn, batch_size=per_replica_batch_size)
+            preprocess_fn=preprocess_fn,
+            process_batch_fn=process_batch_fn,
+            batch_size=per_replica_batch_size)

       return strategy.distribute_datasets_from_function(_load_distributed)
     else:
-      return self._load(preprocess_fn=preprocess_fn, batch_size=batch_size)
+      return self._load(preprocess_fn=preprocess_fn,
+                        process_batch_fn=process_batch_fn,
+                        batch_size=batch_size)

 _BaseDatasetClass = Type[TypeVar('B', bound=BaseDataset)]
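
The base.py change threads an optional `process_batch_fn` through both `load()` and `_load()`: a caller-supplied batch function now takes precedence over the subclass hook `_create_process_batch_fn()`. A toy, self-contained sketch of that precedence follows; the standalone `load`, the toy dataset, and the `rescale` function are illustrative stand-ins, not the repository's API:

```python
import tensorflow as tf


def _create_process_batch_fn(batch_size):
  # Stand-in for the subclass hook; BaseDataset returns None here unless a
  # subclass overrides it.
  del batch_size
  return None


def load(process_batch_fn=None, batch_size=8):
  """Toy version of the changed precedence: caller argument, then hook."""
  dataset = tf.data.Dataset.from_tensor_slices(
      {'features': tf.random.uniform([32, 4], maxval=255, dtype=tf.int32)})
  dataset = dataset.batch(batch_size, drop_remainder=True)
  # Mirrors the changed lines in _load(): only fall back to the hook when the
  # caller did not pass a function.
  if process_batch_fn is None:
    process_batch_fn = _create_process_batch_fn(batch_size)
  if process_batch_fn:
    dataset = dataset.map(
        process_batch_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  return dataset


def rescale(batch):
  # Example caller-supplied batch function: cast and rescale after batching.
  return {'features': tf.cast(batch['features'], tf.float32) / 255.}


for batch in load(process_batch_fn=rescale).take(1):
  print(batch['features'].shape, batch['features'].dtype)  # (8, 4) float32
```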

uncertainty_baselines/models/wide_resnet.py

Lines changed: 4 additions & 4 deletions

@@ -156,7 +156,7 @@ def _parse_hyperparameters(l2: float, hps: Dict[str, float]):
 def wide_resnet(
     input_shape: Iterable[int],
     depth: int,
-    width_multiplier: int,
+    width_multiplier: float,
     num_classes: int,
     l2: float,
     version: int = 2,

@@ -207,7 +207,7 @@ def wide_resnet(
   x = tf.keras.layers.Activation('relu')(x)
   x = group(
       x,
-      filters=16 * width_multiplier,
+      filters=round(16 * width_multiplier),
       strides=1,
       num_blocks=num_blocks,
       conv_l2=hps['group_1_conv_l2'],

@@ -216,7 +216,7 @@ def wide_resnet(
       seed=seeds[1])
   x = group(
       x,
-      filters=32 * width_multiplier,
+      filters=round(32 * width_multiplier),
       strides=2,
       num_blocks=num_blocks,
       conv_l2=hps['group_2_conv_l2'],

@@ -225,7 +225,7 @@ def wide_resnet(
       seed=seeds[2])
   x = group(
       x,
-      filters=64 * width_multiplier,
+      filters=round(64 * width_multiplier),
       strides=2,
       num_blocks=num_blocks,
       conv_l2=hps['group_3_conv_l2'],
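
The wide_resnet.py hunks are the substance of the commit message: `width_multiplier` is now typed as a float and each group's filter count is wrapped in `round()`, so non-integer widths become expressible. A small sketch of the resulting per-group channel counts; the multiplier values below are examples, not defaults from the repository:

```python
# Each group's filters are rounded to the nearest integer, as in the diff.
for width_multiplier in (1, 1.5, 2.5, 10):
  filters = [round(base * width_multiplier) for base in (16, 32, 64)]
  print(width_multiplier, filters)
# 1   -> [16, 32, 64]
# 1.5 -> [24, 48, 96]
# 2.5 -> [40, 80, 160]
# 10  -> [160, 320, 640]
```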

0 commit comments
