Commit b990e54

Fix: Keep use_cudnn in nested RNNs in Bidirectional layer (#21534)
- The use_cudnn attribute of RNNs such as GRU/LSTM is not serialized, and Bidirectional lost it when creating nested copies for the forward and backward RNN layers.
- Pass it explicitly to the nested layers, as it should not be serialized (see the usage sketch below).
- Do nothing for other RNN layers where it is not defined.
1 parent b9619c1 commit b990e54
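
A minimal usage sketch of the behavior this commit restores (assuming the Keras 3 public API; the layer size and the explicit False setting are arbitrary choices for illustration):

from keras import layers

# Hypothetical example: explicitly disable cuDNN on the wrapped LSTM.
rnn = layers.LSTM(8, use_cudnn=False)
bidi = layers.Bidirectional(rnn)

# With this fix, the nested forward/backward copies keep the setting
# instead of falling back to the default ("auto").
assert bidi.forward_layer.use_cudnn is False
assert bidi.backward_layer.use_cudnn is False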

File tree

2 files changed (+19 lines, −0 lines)


keras/src/layers/rnn/bidirectional.py

Lines changed: 4 additions & 0 deletions
@@ -125,6 +125,10 @@ def __init__(
             )
         else:
             self.backward_layer = backward_layer
+        # Keep the use_cudnn attribute if defined (not serialized).
+        if hasattr(layer, "use_cudnn"):
+            self.forward_layer.use_cudnn = layer.use_cudnn
+            self.backward_layer.use_cudnn = layer.use_cudnn
         self._verify_layer_config()

         def force_zero_output_for_mask(layer):
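
For context, Bidirectional builds its nested copies by round-tripping the wrapped layer through its serialized config, and use_cudnn is intentionally not part of that config, so the copies reverted to the default. A minimal sketch of that failure mode (an illustration of config round-tripping, not the exact cloning code used by Bidirectional):

from keras import layers

# Clone a GRU through its config, as serialization-based copying does.
original = layers.GRU(4, use_cudnn=False)
clone = layers.GRU.from_config(original.get_config())

# use_cudnn is not serialized, so the clone falls back to the default.
print(original.use_cudnn)  # False
print(clone.use_cudnn)     # "auto"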

keras/src/layers/rnn/bidirectional_test.py

Lines changed: 15 additions & 0 deletions
@@ -260,3 +260,18 @@ def test_output_shape(self):
         output_shape = layer.compute_output_shape(x.shape)
         for out, shape in zip(output, output_shape):
             self.assertEqual(out.shape, shape)
+
+    def test_keeps_use_cudnn(self):
+        # keep use_cudnn if the layer has it
+        for rnn_class in [layers.GRU, layers.LSTM]:
+            for use_cudnn in [True, False, "auto"]:
+                rnn = rnn_class(1, use_cudnn=use_cudnn)
+                bidi = layers.Bidirectional(rnn)
+                self.assertEqual(bidi.forward_layer.use_cudnn, use_cudnn)
+                self.assertEqual(bidi.backward_layer.use_cudnn, use_cudnn)
+
+        # otherwise ignore it
+        rnn = layers.SimpleRNN(1)
+        bidi = layers.Bidirectional(rnn)
+        self.assertFalse(hasattr(bidi.forward_layer, "use_cudnn"))
+        self.assertFalse(hasattr(bidi.backward_layer, "use_cudnn"))
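
The new test can be exercised on its own, assuming a standard pytest setup in the Keras repository: pytest keras/src/layers/rnn/bidirectional_test.py -k test_keeps_use_cudnn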
