Skip to content

Commit 18a18d3

Browse files
committed
Fix mixed-precision dtype mismatch in torch LSTM
1 parent f2fcac3 commit 18a18d3

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

keras/src/backend/torch/rnn.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,12 @@ def lstm(
568568
if bias is not None:
569569
bias = convert_to_tensor(bias)
570570

571-
inputs = convert_to_tensor(inputs, dtype="float32")
572-
initial_state_h = convert_to_tensor(initial_state_h, dtype="float32")
573-
initial_state_c = convert_to_tensor(initial_state_c, dtype="float32")
571+
# Cast inputs/states to the kernel's dtype instead of a hard-coded float32
572+
# so integer inputs become floats and mixed precision (e.g. float16) works.
573+
compute_dtype = kernel.dtype
574+
inputs = convert_to_tensor(inputs).to(compute_dtype)
575+
initial_state_h = convert_to_tensor(initial_state_h).to(compute_dtype)
576+
initial_state_c = convert_to_tensor(initial_state_c).to(compute_dtype)
574577

575578
# Preprocess for go_backwards by flipping the sequence
576579
if go_backwards:

0 commit comments

Comments
 (0)