2 changes: 1 addition & 1 deletion .github/workflows/format.yml
@@ -20,7 +20,7 @@ jobs:
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
-      uses: actions/cache@v2
+      uses: actions/cache@v4
Collaborator:
Please remove these 2 files from the PR.

Author:
Reverted to the original state.

with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
2 changes: 1 addition & 1 deletion .github/workflows/lint.yml
@@ -21,7 +21,7 @@ jobs:
python -m pip install --upgrade pip setuptools
echo "::set-output name=dir::$(pip cache dir)"
- name: pip cache
-      uses: actions/cache@v2
+      uses: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
45 changes: 29 additions & 16 deletions keras/layers/rnn/cudnn_gru.py
@@ -152,23 +152,34 @@ def _process_batch(self, inputs, initial_state):
        input_h = initial_state[0]
        input_h = tf.expand_dims(input_h, axis=0)

+        weights = [
+            self.kernel[:, : self.units],
+            self.kernel[:, self.units : self.units * 2],
+            self.kernel[:, self.units * 2 :],
+            self.recurrent_kernel[:, : self.units],
+            self.recurrent_kernel[:, self.units : self.units * 2],
+            self.recurrent_kernel[:, self.units * 2 :],
+        ]
+
+        biases = [
+            self.bias[: self.units],
+            self.bias[self.units : self.units * 2],
+            self.bias[self.units * 2 : self.units * 3],
+            self.bias[self.units * 3 : self.units * 4],
+            self.bias[self.units * 4 : self.units * 5],
+            self.bias[self.units * 5 :],
+        ]
+
+        if tf.sysconfig.get_build_info()["is_cuda_build"]:
Collaborator:
?

Author:
The default layout is compatible with ROCm; for CUDA builds (detected via `is_cuda_build` in `tf.sysconfig.get_build_info()`), the slices are swapped into the layout cuDNN expects.
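
A minimal standalone sketch of the swap (illustrative labels only, not part of the diff; it assumes Keras's canonical GRU gate order [z, r, h] and cuDNN's expected [r, z, h]):

```python
# Gate slices in Keras's canonical GRU order: update (z), reset (r),
# candidate (h); first the input kernel, then the recurrent kernel.
weights = ["W_z", "W_r", "W_h", "R_z", "R_r", "R_h"]

# The swap applied on CUDA builds: cuDNN expects the reset gate first.
weights[0], weights[1] = weights[1], weights[0]
weights[3], weights[4] = weights[4], weights[3]

print(weights)  # ['W_r', 'W_z', 'W_h', 'R_r', 'R_z', 'R_h']
```

On ROCm builds the lists are left in canonical order, which MIOpen consumes directly (per the comment above).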

+            weights[0], weights[1] = weights[1], weights[0]
+            weights[3], weights[4] = weights[4], weights[3]
+            biases[0], biases[1] = biases[1], biases[0]
+            biases[3], biases[4] = biases[4], biases[3]

        params = gru_lstm_utils.canonical_to_params(
-            weights=[
-                self.kernel[:, self.units : self.units * 2],
-                self.kernel[:, : self.units],
-                self.kernel[:, self.units * 2 :],
-                self.recurrent_kernel[:, self.units : self.units * 2],
-                self.recurrent_kernel[:, : self.units],
-                self.recurrent_kernel[:, self.units * 2 :],
-            ],
-            biases=[
-                self.bias[self.units : self.units * 2],
-                self.bias[: self.units],
-                self.bias[self.units * 2 : self.units * 3],
-                self.bias[self.units * 4 : self.units * 5],
-                self.bias[self.units * 3 : self.units * 4],
-                self.bias[self.units * 5 :],
-            ],
+            weights=weights,
+            biases=biases,
            shape=self._vector_shape,
        )

@@ -185,13 +196,15 @@ def _process_batch(self, inputs, initial_state):

        if self.stateful or self.return_state:
            h = h[0]
+
        if self.return_sequences:
            if self.time_major:
                output = outputs
            else:
                output = tf.transpose(outputs, perm=(1, 0, 2))
        else:
            output = outputs[-1]
+
        return output, [h]

def get_config(self):
53 changes: 33 additions & 20 deletions keras/layers/rnn/cudnn_lstm.py
@@ -180,27 +180,38 @@ def _process_batch(self, inputs, initial_state):
input_h = tf.expand_dims(input_h, axis=0)
input_c = tf.expand_dims(input_c, axis=0)

+        # Prepare weights & biases
+        weights = [
+            self.kernel[:, : self.units],
+            self.kernel[:, self.units : self.units * 2],
+            self.kernel[:, self.units * 2 : self.units * 3],
+            self.kernel[:, self.units * 3 :],
+            self.recurrent_kernel[:, : self.units],
+            self.recurrent_kernel[:, self.units : self.units * 2],
+            self.recurrent_kernel[:, self.units * 2 : self.units * 3],
+            self.recurrent_kernel[:, self.units * 3 :],
+        ]
+
+        biases = [
+            self.bias[: self.units],
+            self.bias[self.units : self.units * 2],
+            self.bias[self.units * 2 : self.units * 3],
+            self.bias[self.units * 3 : self.units * 4],
+            self.bias[self.units * 4 : self.units * 5],
+            self.bias[self.units * 5 : self.units * 6],
+            self.bias[self.units * 6 : self.units * 7],
+            self.bias[self.units * 7 :],
+        ]
+
+        # If on ROCm, reorder weights/biases: [i, f, c, o] -> [i, f, o, c]
+        if tf.sysconfig.get_build_info()["is_rocm_build"]:
+            reorder_idx = (0, 1, 3, 2, 4, 5, 7, 6)
+            weights = [weights[i] for i in reorder_idx]
+            biases = [biases[i] for i in reorder_idx]

        params = gru_lstm_utils.canonical_to_params(
-            weights=[
-                self.kernel[:, : self.units],
-                self.kernel[:, self.units : self.units * 2],
-                self.kernel[:, self.units * 2 : self.units * 3],
-                self.kernel[:, self.units * 3 :],
-                self.recurrent_kernel[:, : self.units],
-                self.recurrent_kernel[:, self.units : self.units * 2],
-                self.recurrent_kernel[:, self.units * 2 : self.units * 3],
-                self.recurrent_kernel[:, self.units * 3 :],
-            ],
-            biases=[
-                self.bias[: self.units],
-                self.bias[self.units : self.units * 2],
-                self.bias[self.units * 2 : self.units * 3],
-                self.bias[self.units * 3 : self.units * 4],
-                self.bias[self.units * 4 : self.units * 5],
-                self.bias[self.units * 5 : self.units * 6],
-                self.bias[self.units * 6 : self.units * 7],
-                self.bias[self.units * 7 :],
-            ],
+            weights=weights,
+            biases=biases,
            shape=self._vector_shape,
        )

@@ -217,13 +228,15 @@ def _process_batch(self, inputs, initial_state):
        if self.stateful or self.return_state:
            h = h[0]
            c = c[0]
+
        if self.return_sequences:
            if self.time_major:
                output = outputs
            else:
                output = tf.transpose(outputs, perm=(1, 0, 2))
        else:
            output = outputs[-1]
+
        return output, [h, c]

def get_config(self):
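For reference, a quick standalone check of what `reorder_idx` does, using the gate labels from the comment in the diff above (the `W_`/`R_` names are illustrative):

```python
# Stand-ins for the eight weight slices: four kernel gates, then four
# recurrent gates, in cuDNN's canonical [i, f, c, o] order.
weights = ["W_i", "W_f", "W_c", "W_o", "R_i", "R_f", "R_c", "R_o"]

# Permutation applied on ROCm builds: swap the cell and output gates.
reorder_idx = (0, 1, 3, 2, 4, 5, 7, 6)
print([weights[i] for i in reorder_idx])
# ['W_i', 'W_f', 'W_o', 'W_c', 'R_i', 'R_f', 'R_o', 'R_c']
```

The same permutation is applied to the eight bias slices, which pair up with the kernel and recurrent gates.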
2 changes: 1 addition & 1 deletion shell/format.sh
@@ -1,4 +1,4 @@
#!/bin/bash
isort --sl keras
black --line-length 80 keras
-flake8 keras
+flake8 keras || echo "flake8 reports issues. Run lint.sh to enforce strict checks."
2 changes: 2 additions & 0 deletions shell/lint.sh
@@ -1,4 +1,6 @@
#!/bin/bash
+pip install --quiet --upgrade 'flake8>=5.0.0'
+
isort --check --sl -c keras
if ! [ $? -eq 0 ]
then