Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ sdist/
target/
var/
venv/
.vscode
1 change: 0 additions & 1 deletion keras_resnet/blocks/_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

This module implements a number of popular one-dimensional residual blocks.
"""

import keras.layers
import keras.regularizers

Expand Down
6 changes: 2 additions & 4 deletions keras_resnet/layers/_batch_normalization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import keras


class BatchNormalization(keras.layers.BatchNormalization):
"""
Identical to keras.layers.BatchNormalization, but adds the option to freeze parameters.
Expand All @@ -12,11 +10,11 @@ def __init__(self, freeze, *args, **kwargs):
# set to non-trainable if freeze is true
self.trainable = not self.freeze

def call(self, *args, **kwargs):
def call(self, inputs, *args, **kwargs):
# Force test mode if frozen, otherwise use default behaviour (i.e., training=None).
if self.freeze:
kwargs['training'] = False
return super(BatchNormalization, self).call(*args, **kwargs)
return super(BatchNormalization, self).call(inputs, *args, **kwargs)

def get_config(self):
config = super(BatchNormalization, self).get_config()
Expand Down
Loading