Skip to content

Commit 3d2db56

Browse files
Modified InputSpec of SpectralNormalization layer (#21335)
* Modified InputSpec of SpectralNormalization layer * Added unit test for for higher dimension
1 parent c219734 commit 3d2db56

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

keras/src/layers/normalization/spectral_normalization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(self, layer, power_iterations=1, **kwargs):
5252

5353
def build(self, input_shape):
5454
super().build(input_shape)
55-
self.input_spec = InputSpec(shape=[None] + list(input_shape[1:]))
55+
self.input_spec = InputSpec(min_ndim=1, axes={-1: input_shape[-1]})
5656

5757
if hasattr(self.layer, "kernel"):
5858
self.kernel = self.layer.kernel

keras/src/layers/normalization/spectral_normalization_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,20 @@ def test_basic_spectralnorm(self):
3535
run_training_check=False,
3636
)
3737

38+
@pytest.mark.requires_trainable_backend
39+
def test_spectralnorm_higher_dim(self):
40+
self.run_layer_test(
41+
layers.SpectralNormalization,
42+
init_kwargs={"layer": layers.Dense(2)},
43+
input_data=np.random.uniform(size=(10, 3, 4, 5)),
44+
expected_output_shape=(10, 3, 4, 2),
45+
expected_num_trainable_weights=2,
46+
expected_num_non_trainable_weights=1,
47+
expected_num_seed_generators=0,
48+
expected_num_losses=0,
49+
supports_masking=False,
50+
)
51+
3852
def test_invalid_power_iterations(self):
3953
with self.assertRaisesRegex(
4054
ValueError, "`power_iterations` should be greater than zero."

0 commit comments

Comments
 (0)