File tree Expand file tree Collapse file tree 1 file changed +6
-3
lines changed
Expand file tree Collapse file tree 1 file changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -568,9 +568,12 @@ def lstm(
568568 if bias is not None :
569569 bias = convert_to_tensor (bias )
570570
571- inputs = convert_to_tensor (inputs , dtype = "float32" )
572- initial_state_h = convert_to_tensor (initial_state_h , dtype = "float32" )
573- initial_state_c = convert_to_tensor (initial_state_c , dtype = "float32" )
571+ # Cast inputs/states to the kernel's dtype so integer inputs are promoted
572+ # to float and mixed-precision dtypes (e.g. float16) are respected.
573+ compute_dtype = kernel .dtype
574+ inputs = convert_to_tensor (inputs ).to (compute_dtype )
575+ initial_state_h = convert_to_tensor (initial_state_h ).to (compute_dtype )
576+ initial_state_c = convert_to_tensor (initial_state_c ).to (compute_dtype )
574577
575578 # Preprocess for go_backwards by flipping the sequence
576579 if go_backwards :
You can’t perform that action at this time.
0 commit comments