diff --git a/keras/layers/rnn/cudnn_gru.py b/keras/layers/rnn/cudnn_gru.py index 45c7c91d53e3..e666cf1ea9ff 100644 --- a/keras/layers/rnn/cudnn_gru.py +++ b/keras/layers/rnn/cudnn_gru.py @@ -152,23 +152,34 @@ def _process_batch(self, inputs, initial_state): input_h = initial_state[0] input_h = tf.expand_dims(input_h, axis=0) + + weights = [ + self.kernel[:, : self.units], + self.kernel[:, self.units : self.units * 2], + self.kernel[:, self.units * 2 :], + self.recurrent_kernel[:, : self.units], + self.recurrent_kernel[:, self.units : self.units * 2], + self.recurrent_kernel[:, self.units * 2 :], + ] + + biases = [ + self.bias[: self.units], + self.bias[self.units : self.units * 2], + self.bias[self.units * 2 : self.units * 3], + self.bias[self.units * 3 : self.units * 4], + self.bias[self.units * 4 : self.units * 5], + self.bias[self.units * 5 :], + ] + + if tf.sysconfig.get_build_info()["is_cuda_build"]: + weights[0], weights[1] = weights[1], weights[0] + weights[3], weights[4] = weights[4], weights[3] + biases[0], biases[1] = biases[1], biases[0] + biases[3], biases[4] = biases[4], biases[3] + params = gru_lstm_utils.canonical_to_params( - weights=[ - self.kernel[:, self.units : self.units * 2], - self.kernel[:, : self.units], - self.kernel[:, self.units * 2 :], - self.recurrent_kernel[:, self.units : self.units * 2], - self.recurrent_kernel[:, : self.units], - self.recurrent_kernel[:, self.units * 2 :], - ], - biases=[ - self.bias[self.units : self.units * 2], - self.bias[: self.units], - self.bias[self.units * 2 : self.units * 3], - self.bias[self.units * 4 : self.units * 5], - self.bias[self.units * 3 : self.units * 4], - self.bias[self.units * 5 :], - ], + weights=weights, + biases=biases, shape=self._vector_shape, ) @@ -185,6 +196,7 @@ def _process_batch(self, inputs, initial_state): if self.stateful or self.return_state: h = h[0] + if self.return_sequences: if self.time_major: output = outputs @@ -192,6 +204,7 @@ def _process_batch(self, inputs, initial_state): output = tf.transpose(outputs, perm=(1, 0, 2)) else: output = outputs[-1] + return output, [h] def get_config(self): diff --git a/keras/layers/rnn/cudnn_lstm.py b/keras/layers/rnn/cudnn_lstm.py index 69ae8e96af6b..e07909991cee 100644 --- a/keras/layers/rnn/cudnn_lstm.py +++ b/keras/layers/rnn/cudnn_lstm.py @@ -180,27 +180,38 @@ def _process_batch(self, inputs, initial_state): input_h = tf.expand_dims(input_h, axis=0) input_c = tf.expand_dims(input_c, axis=0) + # Prepare weights & biases + weights = [ + self.kernel[:, : self.units], + self.kernel[:, self.units : self.units * 2], + self.kernel[:, self.units * 2 : self.units * 3], + self.kernel[:, self.units * 3 :], + self.recurrent_kernel[:, : self.units], + self.recurrent_kernel[:, self.units : self.units * 2], + self.recurrent_kernel[:, self.units * 2 : self.units * 3], + self.recurrent_kernel[:, self.units * 3 :], + ] + + biases = [ + self.bias[: self.units], + self.bias[self.units : self.units * 2], + self.bias[self.units * 2 : self.units * 3], + self.bias[self.units * 3 : self.units * 4], + self.bias[self.units * 4 : self.units * 5], + self.bias[self.units * 5 : self.units * 6], + self.bias[self.units * 6 : self.units * 7], + self.bias[self.units * 7 :], + ] + + # If on ROCm, reorder weights/biases: [i, f, c, o] -> [i, f, o, c] + if tf.sysconfig.get_build_info()["is_rocm_build"]: + reorder_idx = (0, 1, 3, 2, 4, 5, 7, 6) + weights = [weights[i] for i in reorder_idx] + biases = [biases[i] for i in reorder_idx] + params = gru_lstm_utils.canonical_to_params( - weights=[ - self.kernel[:, : self.units], - self.kernel[:, self.units : self.units * 2], - self.kernel[:, self.units * 2 : self.units * 3], - self.kernel[:, self.units * 3 :], - self.recurrent_kernel[:, : self.units], - self.recurrent_kernel[:, self.units : self.units * 2], - self.recurrent_kernel[:, self.units * 2 : self.units * 3], - self.recurrent_kernel[:, self.units * 3 :], - ], - biases=[ - self.bias[: self.units], - self.bias[self.units : self.units * 2], - self.bias[self.units * 2 : self.units * 3], - self.bias[self.units * 3 : self.units * 4], - self.bias[self.units * 4 : self.units * 5], - self.bias[self.units * 5 : self.units * 6], - self.bias[self.units * 6 : self.units * 7], - self.bias[self.units * 7 :], - ], + weights=weights, + biases=biases, shape=self._vector_shape, ) @@ -217,6 +228,7 @@ def _process_batch(self, inputs, initial_state): if self.stateful or self.return_state: h = h[0] c = c[0] + if self.return_sequences: if self.time_major: output = outputs @@ -224,6 +236,7 @@ def _process_batch(self, inputs, initial_state): output = tf.transpose(outputs, perm=(1, 0, 2)) else: output = outputs[-1] + return output, [h, c] def get_config(self): diff --git a/shell/format.sh b/shell/format.sh index 234634b3727f..c4f390136192 100755 --- a/shell/format.sh +++ b/shell/format.sh @@ -1,4 +1,4 @@ #!/bin/bash isort --sl keras black --line-length 80 keras -flake8 keras +flake8 keras || echo "flake8 reports issues. Run lint.sh to enforce strict checks." diff --git a/shell/lint.sh b/shell/lint.sh index 0f06e65ca391..235797ed5ed9 100755 --- a/shell/lint.sh +++ b/shell/lint.sh @@ -1,4 +1,6 @@ #!/bin/bash +pip install --quiet --upgrade 'flake8>=5.0.0' + isort --check --sl -c keras if ! [ $? -eq 0 ] then