From e5c034e0da35b43e184f887ddc22c81dc62f8625 Mon Sep 17 00:00:00 2001 From: JoeNan1 <2268346832@qq.com> Date: Tue, 1 Jul 2025 10:37:14 +0800 Subject: [PATCH 1/7] Enable ROCm devices to correctly use keras.layers.CuDNNGRU and keras.layers.CuDNNLSTM --- keras/layers/rnn/cudnn_gru.py | 45 +++++++++++++++++++---------- keras/layers/rnn/cudnn_lstm.py | 53 +++++++++++++++++++++------------- 2 files changed, 62 insertions(+), 36 deletions(-) diff --git a/keras/layers/rnn/cudnn_gru.py b/keras/layers/rnn/cudnn_gru.py index 45c7c91d53e3..e666cf1ea9ff 100644 --- a/keras/layers/rnn/cudnn_gru.py +++ b/keras/layers/rnn/cudnn_gru.py @@ -152,23 +152,34 @@ def _process_batch(self, inputs, initial_state): input_h = initial_state[0] input_h = tf.expand_dims(input_h, axis=0) + + weights = [ + self.kernel[:, : self.units], + self.kernel[:, self.units : self.units * 2], + self.kernel[:, self.units * 2 :], + self.recurrent_kernel[:, : self.units], + self.recurrent_kernel[:, self.units : self.units * 2], + self.recurrent_kernel[:, self.units * 2 :], + ] + + biases = [ + self.bias[: self.units], + self.bias[self.units : self.units * 2], + self.bias[self.units * 2 : self.units * 3], + self.bias[self.units * 3 : self.units * 4], + self.bias[self.units * 4 : self.units * 5], + self.bias[self.units * 5 :], + ] + + if tf.sysconfig.get_build_info()["is_cuda_build"]: + weights[0], weights[1] = weights[1], weights[0] + weights[3], weights[4] = weights[4], weights[3] + biases[0], biases[1] = biases[1], biases[0] + biases[3], biases[4] = biases[4], biases[3] + params = gru_lstm_utils.canonical_to_params( - weights=[ - self.kernel[:, self.units : self.units * 2], - self.kernel[:, : self.units], - self.kernel[:, self.units * 2 :], - self.recurrent_kernel[:, self.units : self.units * 2], - self.recurrent_kernel[:, : self.units], - self.recurrent_kernel[:, self.units * 2 :], - ], - biases=[ - self.bias[self.units : self.units * 2], - self.bias[: self.units], - self.bias[self.units * 2 : self.units * 3], - self.bias[self.units * 4 : self.units * 5], - self.bias[self.units * 3 : self.units * 4], - self.bias[self.units * 5 :], - ], + weights=weights, + biases=biases, shape=self._vector_shape, ) @@ -185,6 +196,7 @@ def _process_batch(self, inputs, initial_state): if self.stateful or self.return_state: h = h[0] + if self.return_sequences: if self.time_major: output = outputs @@ -192,6 +204,7 @@ def _process_batch(self, inputs, initial_state): output = tf.transpose(outputs, perm=(1, 0, 2)) else: output = outputs[-1] + return output, [h] def get_config(self): diff --git a/keras/layers/rnn/cudnn_lstm.py b/keras/layers/rnn/cudnn_lstm.py index 69ae8e96af6b..e07909991cee 100644 --- a/keras/layers/rnn/cudnn_lstm.py +++ b/keras/layers/rnn/cudnn_lstm.py @@ -180,27 +180,38 @@ def _process_batch(self, inputs, initial_state): input_h = tf.expand_dims(input_h, axis=0) input_c = tf.expand_dims(input_c, axis=0) + # Prepare weights & biases + weights = [ + self.kernel[:, : self.units], + self.kernel[:, self.units : self.units * 2], + self.kernel[:, self.units * 2 : self.units * 3], + self.kernel[:, self.units * 3 :], + self.recurrent_kernel[:, : self.units], + self.recurrent_kernel[:, self.units : self.units * 2], + self.recurrent_kernel[:, self.units * 2 : self.units * 3], + self.recurrent_kernel[:, self.units * 3 :], + ] + + biases = [ + self.bias[: self.units], + self.bias[self.units : self.units * 2], + self.bias[self.units * 2 : self.units * 3], + self.bias[self.units * 3 : self.units * 4], + self.bias[self.units * 4 : self.units * 5], + self.bias[self.units * 5 : self.units * 6], + self.bias[self.units * 6 : self.units * 7], + self.bias[self.units * 7 :], + ] + + # If on ROCm, reorder weights/biases: [i, f, c, o] -> [i, f, o, c] + if tf.sysconfig.get_build_info()["is_rocm_build"]: + reorder_idx = (0, 1, 3, 2, 4, 5, 7, 6) + weights = [weights[i] for i in reorder_idx] + biases = [biases[i] for i in reorder_idx] + params = gru_lstm_utils.canonical_to_params( - weights=[ - self.kernel[:, : self.units], - self.kernel[:, self.units : self.units * 2], - self.kernel[:, self.units * 2 : self.units * 3], - self.kernel[:, self.units * 3 :], - self.recurrent_kernel[:, : self.units], - self.recurrent_kernel[:, self.units : self.units * 2], - self.recurrent_kernel[:, self.units * 2 : self.units * 3], - self.recurrent_kernel[:, self.units * 3 :], - ], - biases=[ - self.bias[: self.units], - self.bias[self.units : self.units * 2], - self.bias[self.units * 2 : self.units * 3], - self.bias[self.units * 3 : self.units * 4], - self.bias[self.units * 4 : self.units * 5], - self.bias[self.units * 5 : self.units * 6], - self.bias[self.units * 6 : self.units * 7], - self.bias[self.units * 7 :], - ], + weights=weights, + biases=biases, shape=self._vector_shape, ) @@ -217,6 +228,7 @@ def _process_batch(self, inputs, initial_state): if self.stateful or self.return_state: h = h[0] c = c[0] + if self.return_sequences: if self.time_major: output = outputs @@ -224,6 +236,7 @@ def _process_batch(self, inputs, initial_state): output = tf.transpose(outputs, perm=(1, 0, 2)) else: output = outputs[-1] + return output, [h, c] def get_config(self): From 36132e1b9314979e1ca5f7d3426144d435940e82 Mon Sep 17 00:00:00 2001 From: JoeNan1 <89842438+JoeNan1@users.noreply.github.com> Date: Tue, 1 Jul 2025 11:00:59 +0800 Subject: [PATCH 2/7] Update format.yml --- .github/workflows/format.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 68e0256ba2b3..894ca2725880 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -20,7 +20,7 @@ jobs: python -m pip install --upgrade pip setuptools echo "::set-output name=dir::$(pip cache dir)" - name: pip cache - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} From 3704550d5bddd70997466c0ec0eb36495fbe996b Mon Sep 17 00:00:00 2001 From: JoeNan1 <89842438+JoeNan1@users.noreply.github.com> Date: Tue, 1 Jul 2025 11:01:24 +0800 Subject: [PATCH 3/7] Update lint.yml --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 66388041bc5b..25391159ea76 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -21,7 +21,7 @@ jobs: python -m pip install --upgrade pip setuptools echo "::set-output name=dir::$(pip cache dir)" - name: pip cache - uses: actions/cache@v2 + uses: actions/cache@v4 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} From a765a83c6a71fe0954a5c9dad3235f5b3edd98d3 Mon Sep 17 00:00:00 2001 From: JoeNan1 <89842438+JoeNan1@users.noreply.github.com> Date: Mon, 14 Jul 2025 09:51:32 +0800 Subject: [PATCH 4/7] Update lint.sh --- shell/lint.sh | 2 ++ 1 file changed, 2 insertions(+) diff --git a/shell/lint.sh b/shell/lint.sh index 0f06e65ca391..235797ed5ed9 100755 --- a/shell/lint.sh +++ b/shell/lint.sh @@ -1,4 +1,6 @@ #!/bin/bash +pip install --quiet --upgrade 'flake8>=5.0.0' + isort --check --sl -c keras if ! [ $? -eq 0 ] then From 2b58990a3458c21a5a1a69e36bfd5b474b544a77 Mon Sep 17 00:00:00 2001 From: JoeNan1 <89842438+JoeNan1@users.noreply.github.com> Date: Mon, 14 Jul 2025 09:59:00 +0800 Subject: [PATCH 5/7] Update format.sh --- shell/format.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/shell/format.sh b/shell/format.sh index 234634b3727f..c4f390136192 100755 --- a/shell/format.sh +++ b/shell/format.sh @@ -1,4 +1,4 @@ #!/bin/bash isort --sl keras black --line-length 80 keras -flake8 keras +flake8 keras || echo "flake8 reports issues. Run lint.sh to enforce strict checks." From c854a83a5fb9649745c29f972a548f1fe834ccf5 Mon Sep 17 00:00:00 2001 From: JoeNan1 <89842438+JoeNan1@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:29:59 +0800 Subject: [PATCH 6/7] Update format.yml --- .github/workflows/format.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml index 894ca2725880..68e0256ba2b3 100644 --- a/.github/workflows/format.yml +++ b/.github/workflows/format.yml @@ -20,7 +20,7 @@ jobs: python -m pip install --upgrade pip setuptools echo "::set-output name=dir::$(pip cache dir)" - name: pip cache - uses: actions/cache@v4 + uses: actions/cache@v2 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} From 6243be1838751b7844d1fabcb8d47b7a9a9bc0ba Mon Sep 17 00:00:00 2001 From: JoeNan1 <89842438+JoeNan1@users.noreply.github.com> Date: Tue, 15 Jul 2025 11:30:23 +0800 Subject: [PATCH 7/7] Update lint.yml --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 25391159ea76..66388041bc5b 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -21,7 +21,7 @@ jobs: python -m pip install --upgrade pip setuptools echo "::set-output name=dir::$(pip cache dir)" - name: pip cache - uses: actions/cache@v4 + uses: actions/cache@v2 with: path: ${{ steps.pip-cache.outputs.dir }} key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}