Skip to content

Commit d3fce33

Browse files
authored
Merge pull request #112 from andreped/batch-norm-fix
Added method to replace BN layers [no ci]
2 parents c90109b + 068a507 commit d3fce33

File tree

7 files changed

+196
-2
lines changed

7 files changed

+196
-2
lines changed

.github/workflows/codecov.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ jobs:
7070
--cov=gradient_accumulator tests/test_expected_result.py \
7171
--cov=gradient_accumulator tests/test_mp_batch_norm.py \
7272
--cov=gradient_accumulator tests/test_bn_convnd.py \
73+
--cov=gradient_accumulator tests/test_bn_pretrained_swap.py \
7374
--cov=gradient_accumulator tests/test_model_distribute.py
7475
7576
- name: Lint with flake8

.github/workflows/test.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ jobs:
6565
- name: Install tensorflow-datasets
6666
run: |
6767
if [[ ${{ matrix.tf-version }} == 2.12 ]]; then
68-
pip install tensorflow-datasets --upgrade
68+
pip install "tensorflow-datasets<=4.9.2"
6969
else
7070
pip install tensorflow==${{ matrix.tf-version }} "tensorflow-datasets<=4.8.2"
7171
pip install "protobuf<=3.20" --force-reinstall
@@ -96,6 +96,7 @@ jobs:
9696
pytest -v tests/test_adaptive_gradient_clipping.py
9797
pytest -v tests/test_batch_norm.py
9898
pytest -v tests/test_bn_convnd.py
99+
pytest -v tests/test_bn_pretrained_swap.py
99100
pytest -v tests/test_mp_batch_norm.py
100101
pytest -v tests/test_optimizer_distribute.py
101102
pytest -v tests/test_model_distribute.py

docs/examples/batch_normalization.rst

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ the *vanilla* batch normalization layer is the most used.
1414

1515
.. code-block:: python
1616
17+
import tensorflow as tf
1718
from gradient_accumulator import GradientAccumulateModel, AccumBatchNormalization
1819
1920
# sets it here as we will set it for both the layer and model wrapper
@@ -32,6 +33,27 @@ the *vanilla* batch normalization layer is the most used.
3233
model = GradientAccumulateModel(accum_steps=accum_steps, inputs=model.input, outputs=model.output)
3334
3435
36+
You can also easily replace the existing Batch Norm layers in a
37+
pretrained model, e.g., MobileNetV2. Below is an example of how to do that:
38+
39+
40+
.. code-block:: python
41+
42+
import tensorflow as tf
43+
from gradient_accumulator import GradientAccumulateModel
44+
from gradient_accumulator.layers import AccumBatchNormalization
45+
from gradient_accumulator.utils import replace_batchnorm_layers
46+
47+
accum_steps = 4
48+
49+
# replace BN layer with AccumBatchNormalization
50+
model = tf.keras.applications.MobileNetV2(input_shape=(32, 32, 3), include_top=False)
51+
model = replace_batchnorm_layers(model, accum_steps=accum_steps)
52+
53+
# add gradient accumulation to existing model
54+
model = GradientAccumulateModel(accum_steps=accum_steps, inputs=model.input, outputs=model.output)
55+
56+
3557
Note that Batch Normalization is a unique layer in Keras.
3658
It has two sets of variables. The first two `mean` and
3759
`variance` are updated during the *forward pass*, whereas

gradient_accumulator/utils.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import tensorflow as tf
2+
3+
from .layers import AccumBatchNormalization
4+
5+
6+
def replace_batchnorm_layers(model, accum_steps, position="replace"):
    """Rebuild ``model`` with every BatchNormalization layer swapped out.

    The layers of ``model`` are visited in ``model.layers`` order and
    re-applied to freshly wired tensors. Each
    ``tf.keras.layers.BatchNormalization`` instance is substituted by a newly
    built ``AccumBatchNormalization`` that receives the old layer's moving
    mean and moving variance.

    Args:
        model: Keras model to convert. ``model.layers[0]`` is treated as the
            input layer and ``model.input`` as the tensor it produces.
        accum_steps: Forwarded to every ``AccumBatchNormalization`` created.
        position: Insertion mode; only ``"replace"`` is supported.

    Returns:
        A new ``tf.keras.Model`` built on ``model.inputs`` whose outputs are
        the re-wired tensors of the layers listed in ``model.output_names``
        (in that order).

    Raises:
        ValueError: If a batch-norm layer is encountered while ``position``
            is anything other than ``"replace"``.
    """
    # Map every layer name to the names of the layers that feed it.
    input_layers_of = {}
    for layer in model.layers:
        for node in layer._outbound_nodes:
            consumer_name = node.outbound_layer.name
            input_layers_of.setdefault(consumer_name, []).append(layer.name)

    # Map every layer name to its output tensor in the rebuilt graph,
    # seeded with the input layer.
    new_output_tensor_of = {model.layers[0].name: model.input}

    # Fallback output when no layer follows the input layer.
    x = model.input
    bn_count = 0

    # Iterate over all layers after the input
    for layer in model.layers[1:]:
        # Collect the already rebuilt tensors feeding this layer.
        layer_input = [
            new_output_tensor_of[producer_name]
            for producer_name in input_layers_of[layer.name]
        ]
        if len(layer_input) == 1:
            layer_input = layer_input[0]

        if isinstance(layer, tf.keras.layers.BatchNormalization):
            if position != "replace":
                raise ValueError("position must be: replace")

            # build new layer
            new_layer = AccumBatchNormalization(
                accum_steps=accum_steps,
                name="AccumBatchNormalization_" + str(bn_count),
            )
            new_layer.build(input_shape=layer.input_shape)
            bn_count += 1

            # Hand the old layer's moving statistics to the new layer.
            # NOTE(review): only the moving mean/variance are carried over;
            # confirm whether gamma/beta (and epsilon/momentum) of the
            # original layer should be transferred as well.
            new_layer.accum_mean = layer.moving_mean
            new_layer.moving_mean = layer.moving_mean

            new_layer.accum_variance = layer.moving_variance
            new_layer.moving_variance = layer.moving_variance

            # forward step
            x = new_layer(layer_input)
        else:
            x = layer(layer_input)

        # Register the tensor under the ORIGINAL layer name so downstream
        # layers (and the output lookup below) can find it.
        new_output_tensor_of[layer.name] = x

    # Gather the rebuilt tensors of the original output layers, keeping the
    # order of the initial model's outputs.
    model_outputs = [
        new_output_tensor_of[name]
        for name in model.output_names
        if name in new_output_tensor_of
    ]
    if not model_outputs:
        # No output name matched a layer name: fall back to the last tensor.
        outputs = x
    elif len(model_outputs) == 1:
        outputs = model_outputs[0]
    else:
        outputs = model_outputs

    return tf.keras.Model(inputs=model.inputs, outputs=outputs)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="gradient-accumulator",
8-
version="0.5.1",
8+
version="0.5.2",
99
author="André Pedersen and David Bouget and Javier Pérez de Frutos and Tor-Arne Schmidt Nordmo",
1010
author_email="[email protected]",
1111
description="Package for gradient accumulation in TensorFlow",

tests/test_bn_pretrained_swap.py

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
import os
2+
import random as python_random
3+
4+
import numpy as np
5+
import tensorflow as tf
6+
import tensorflow_datasets as tfds
7+
from tensorflow.keras.models import load_model
8+
9+
from gradient_accumulator import GradientAccumulateModel
10+
from gradient_accumulator.layers import AccumBatchNormalization
11+
from gradient_accumulator.utils import replace_batchnorm_layers
12+
13+
from .utils import gray2rgb
14+
from .utils import normalize_img
15+
from .utils import reset
16+
from .utils import resizeImage
17+
18+
19+
def test_swap_layer(
    custom_bn: bool = True, bs: int = 100, accum_steps: int = 1, epochs: int = 1
):
    """Train, save, reload and evaluate a MobileNetV2 with swapped BN layers.

    Args:
        custom_bn: Currently unused by this test.
        bs: Batch size used for both the train and test pipelines.
        accum_steps: Passed to ``replace_batchnorm_layers``; when greater
            than 1 the model is also wrapped in ``GradientAccumulateModel``.
        epochs: Number of training epochs.

    Returns:
        The value returned by ``evaluate`` on the reloaded model.
    """
    # load dataset
    (ds_train, ds_test), ds_info = tfds.load(
        "mnist",
        split=["train", "test"],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )

    # build train pipeline
    ds_train = (
        ds_train.map(normalize_img)
        .map(gray2rgb)
        .map(resizeImage)
        .shuffle(ds_info.splits["train"].num_examples)
        .batch(bs)
        .prefetch(1)
    )

    # build test pipeline
    ds_test = (
        ds_test.map(normalize_img)
        .map(gray2rgb)
        .map(resizeImage)
        .batch(bs)
        .prefetch(1)
    )

    # create model: pretrained backbone with its BN layers replaced
    base_model = tf.keras.applications.MobileNetV2(
        input_shape=(32, 32, 3), weights="imagenet", include_top=False
    )
    base_model = replace_batchnorm_layers(base_model, accum_steps=accum_steps)

    input_ = tf.keras.layers.Input(shape=(32, 32, 3))
    x = base_model(input_)
    x = tf.keras.layers.Dense(10, activation="softmax")(x)
    model = tf.keras.Model(inputs=input_, outputs=x)

    # wrap model to use gradient accumulation
    if accum_steps > 1:
        model = GradientAccumulateModel(
            accum_steps=accum_steps, inputs=model.input, outputs=model.output
        )

    # compile model
    # The Dense head above already applies softmax, so the loss consumes
    # probabilities rather than logits.
    model.compile(
        optimizer=tf.keras.optimizers.SGD(1e-2),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

    # train model
    model.fit(
        ds_train,
        epochs=epochs,
        validation_data=ds_test,
        steps_per_epoch=4,
        validation_steps=4,
    )

    model.save("./trained_model")

    # load trained model and test
    del model
    trained_model = load_model("./trained_model", compile=True)

    result = trained_model.evaluate(ds_test, verbose=1)
    print(result)
    return result

tests/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,16 @@ def normalize_img(image, label):
8686
return tf.cast(image, tf.float32) / 255.0, label
8787

8888

89+
def gray2rgb(image, label):
    """Stack the image three times along its last axis; pass the label through."""
    channels = [image] * 3
    rgb_image = tf.concat(channels, axis=-1)
    return rgb_image, label
92+
93+
94+
def resizeImage(image, label, output_shape=(32, 32)):
    """Resize the image to ``output_shape`` (nearest neighbour); keep the label."""
    resized = tf.image.resize(image, size=output_shape, method="nearest")
    return resized, label
97+
98+
8999
def run_experiment(bs=50, accum_steps=2, epochs=1, modeloropt="opt"):
90100
# load dataset
91101
(ds_train, ds_test), ds_info = tfds.load(

0 commit comments

Comments
 (0)