@@ -1,11 +1,14 @@
+import numpy as np
 import tensorflow as tf
 from tensorflow.keras.models import load_model
+
 from gradient_accumulator import GradientAccumulateModel
 from gradient_accumulator.layers import AccumBatchNormalization
-import numpy as np
 
 
-def test_bn_conv2d(custom_bn:bool = True, accum_steps:int = 1, epochs:int = 1):
+def test_bn_conv2d(
+    custom_bn: bool = True, accum_steps: int = 1, epochs: int = 1
+):
     # make toy dataset
     data = np.random.randint(2, size=(16, 8, 8, 1))
     gt = np.expand_dims(np.random.randint(2, size=16), axis=-1)
@@ -19,20 +22,24 @@ def test_bn_conv2d(custom_bn:bool = True, accum_steps:int = 1, epochs:int = 1):
         normalization_layer = tf.keras.layers.Activation("linear")
 
     # create model
-    model = tf.keras.models.Sequential([
-        tf.keras.layers.Conv2D(4, 3, input_shape=(8, 8, 1)),
-        normalization_layer,
-        tf.keras.layers.Activation("relu"),
-        tf.keras.layers.Flatten(),
-        tf.keras.layers.Dense(4),
-        normalization_layer,  # @TODO: BN before or after ReLU? Leads to different performance
-        tf.keras.layers.Activation("relu"),
-        tf.keras.layers.Dense(1, activation="sigmoid"),
-    ])
+    model = tf.keras.models.Sequential(
+        [
+            tf.keras.layers.Conv2D(4, 3, input_shape=(8, 8, 1)),
+            normalization_layer,
+            tf.keras.layers.Activation("relu"),
+            tf.keras.layers.Flatten(),
+            tf.keras.layers.Dense(4),
+            normalization_layer,  # @TODO: BN before or after ReLU? Leads to different performance
+            tf.keras.layers.Activation("relu"),
+            tf.keras.layers.Dense(1, activation="sigmoid"),
+        ]
+    )
 
     # wrap model to use gradient accumulation
     if accum_steps > 1:
-        model = GradientAccumulateModel(accum_steps=accum_steps, inputs=model.input, outputs=model.output)
+        model = GradientAccumulateModel(
+            accum_steps=accum_steps, inputs=model.input, outputs=model.output
+        )
 
     # compile model
     model.compile(
@@ -60,7 +67,9 @@ def test_bn_conv2d(custom_bn:bool = True, accum_steps:int = 1, epochs:int = 1):
     return result
 
 
-def test_bn_conv3d(custom_bn:bool = True, accum_steps:int = 1, epochs:int = 1):
+def test_bn_conv3d(
+    custom_bn: bool = True, accum_steps: int = 1, epochs: int = 1
+):
     # make toy dataset
     data = np.random.randint(2, size=(16, 8, 8, 8, 1))
     gt = np.expand_dims(np.random.randint(2, size=16), axis=-1)
@@ -74,20 +83,24 @@ def test_bn_conv3d(custom_bn:bool = True, accum_steps:int = 1, epochs:int = 1):
         normalization_layer = tf.keras.layers.Activation("linear")
 
     # create model
-    model = tf.keras.models.Sequential([
-        tf.keras.layers.Conv3D(4, 3, input_shape=(8, 8, 8, 1)),
-        normalization_layer,
-        tf.keras.layers.Activation("relu"),
-        tf.keras.layers.Flatten(),
-        tf.keras.layers.Dense(4),
-        normalization_layer,  # @TODO: BN before or after ReLU? Leads to different performance
-        tf.keras.layers.Activation("relu"),
-        tf.keras.layers.Dense(1, activation="sigmoid"),
-    ])
+    model = tf.keras.models.Sequential(
+        [
+            tf.keras.layers.Conv3D(4, 3, input_shape=(8, 8, 8, 1)),
+            normalization_layer,
+            tf.keras.layers.Activation("relu"),
+            tf.keras.layers.Flatten(),
+            tf.keras.layers.Dense(4),
+            normalization_layer,  # @TODO: BN before or after ReLU? Leads to different performance
+            tf.keras.layers.Activation("relu"),
+            tf.keras.layers.Dense(1, activation="sigmoid"),
+        ]
+    )
 
     # wrap model to use gradient accumulation
     if accum_steps > 1:
-        model = GradientAccumulateModel(accum_steps=accum_steps, inputs=model.input, outputs=model.output)
+        model = GradientAccumulateModel(
+            accum_steps=accum_steps, inputs=model.input, outputs=model.output
+        )
 
     # compile model
     model.compile(
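
For reference, a minimal standalone sketch of the two APIs this test exercises. Only the GradientAccumulateModel(accum_steps=..., inputs=..., outputs=...) call mirrors the diff above; the AccumBatchNormalization constructor, the layer sizes, and the compile/fit arguments are illustrative assumptions, not part of this commit.

import numpy as np
import tensorflow as tf

from gradient_accumulator import GradientAccumulateModel
from gradient_accumulator.layers import AccumBatchNormalization

accum_steps = 4

# Use AccumBatchNormalization in place of BatchNormalization so that the
# normalization statistics are updated once per accumulated batch rather
# than once per mini-batch (constructor signature assumed).
model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Dense(4, input_shape=(8,)),
        AccumBatchNormalization(accum_steps=accum_steps),
        tf.keras.layers.Activation("relu"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
)

# Wrap the graph so gradients are accumulated over accum_steps mini-batches
# before each weight update, emulating a larger effective batch size.
model = GradientAccumulateModel(
    accum_steps=accum_steps, inputs=model.input, outputs=model.output
)

model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])

# toy data in the same spirit as the tests above
data = np.random.randint(2, size=(16, 8)).astype("float32")
gt = np.expand_dims(np.random.randint(2, size=16), axis=-1)
model.fit(data, gt, batch_size=2, epochs=1)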