
Commit 2b394c6

pctablet505 authored, with co-authors Jyotinder Singh and gemini-code-assist[bot]
Pctablet505 fix map coordinates (#21473)
* Update nn.py
* Update nn.py
* Update nn.py
* Update nn.py
* Update nn.py
  Corrected indentation in docstring.
* Update nn.py
* Update random_grayscale.py
  Fixed an issue with passing a single image without a batch dimension.
* Update keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py
  Co-authored-by: Jyotinder Singh <[email protected]>
* Update random_grayscale_test.py
  Test case for unbatched inputs.
* Code reformat
* Update random_grayscale_test.py
  Test case checking both unbatched and batched single-image inputs.
* Changed compute_output_spec
  There was a bug that caused a cycle in the graph.
* Update random_grayscale.py
  Removed the use of tree.map_structure.
* Reapply "Fixed issue with dot_product_attention when using TPU. (#21254)" (#21329)
  This reverts commit 81821e0.
* Improve error handling in _can_use_flash_attention for better debugging
  Enhanced the _can_use_flash_attention function to provide more detailed error messages when flash attention compatibility checks fail.
  Changes:
  - Replace generic exception catching with specific error propagation
  - When raise_error=True, directly re-raise original exceptions from the check_layout() and check_is_flash_attention() functions
  - Preserve detailed error context from JAX internal validation functions
  - Maintain existing behavior when raise_error=False (returns False)
  This improves the debugging experience by surfacing specific technical details about tensor layout incompatibilities, cuDNN version requirements, and other flash attention compatibility issues.
  Relates to keras-hub PR #2257 and addresses flash attention debugging needs.
* Revert "Improve error handling in _can_use_flash_attention for better debugging"
  This reverts commit 7a0c547.
* Fix JAX API compatibility and improve error handling in `_can_use_flash_attention`
  Changes:
  - Add the missing q_offsets=None and kv_offsets=None parameters to the check_layout() call to match the updated JAX function signature
  - Replace bare `except:` with `except Exception as e:` and `raise e` to preserve detailed error messages from JAX validation functions
  - Maintain existing fallback behavior when raise_error=False
  This resolves compatibility issues with newer JAX versions and improves the debugging experience by surfacing specific technical details about flash attention compatibility failures.
* Updated `dot_product_attention`
  Simplified the check for `flash_attention` by removing redundant checks that are already done in `_can_use_flash_attention`.
* Update nn.py
* Update nn.py
* Update image.py
* Update keras/src/backend/tensorflow/image.py
  Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Revert "Update keras/src/backend/tensorflow/image.py"
  This reverts commit cb7e955.
* Update image.py
* Update image.py

---------

Co-authored-by: Jyotinder Singh <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
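For readers viewing this commit in isolation, the following is a minimal, hedged sketch of how the code path changed in the diff below is reached from the public API. Only the TensorFlow backend's map_coordinates dtype handling is touched here; the exact input/fill_value dtype combination that motivated the fix is not stated in the commit, so the shapes and dtypes below are assumptions chosen for illustration.

# Illustrative only: a hypothetical call that exercises the TF-backend
# map_coordinates path touched by this commit. Shapes and dtypes are
# assumptions, not taken from the PR.
import numpy as np
import keras

image = np.arange(12, dtype="uint8").reshape(3, 4)  # integer-typed input image
coords = np.array([[0.5, 1.5], [1.0, 2.5]])  # float sampling points, shape (rank, N)

# Bilinear interpolation (order=1) with constant fill; internally the gathered
# pixels and the fill value must end up in the interpolation weights' float
# dtype before tf.where and the weighted sum are applied.
out = keras.ops.image.map_coordinates(
    image, coords, order=1, fill_mode="constant", fill_value=0
)
print(out)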
1 parent ec83f72 · commit 2b394c6

File tree

1 file changed: 8 additions, 6 deletions


keras/src/backend/tensorflow/image.py

Lines changed: 8 additions & 6 deletions
@@ -706,15 +706,17 @@ def process_coordinates(coords, size):
 
         gathered = tf.transpose(tf.gather_nd(input_arr, indices))
 
+        # Cast to computation dtype early to avoid type issues
+        dtype = weights[0].dtype
+        gathered = tf.cast(gathered, dtype)
+        gathered = tf.cast(gathered, weights[0].dtype)
+
         if fill_mode == "constant":
             all_valid = tf.reduce_all(validities, axis=0)
-            gathered = tf.where(all_valid, gathered, fill_value)
+            fill_value_typed = tf.cast(fill_value, dtype)
+            gathered = tf.where(all_valid, gathered, fill_value_typed)
 
-        contribution = gathered
-        outputs.append(
-            functools.reduce(operator.mul, weights)
-            * tf.cast(contribution, weights[0].dtype)
-        )
+        outputs.append(functools.reduce(operator.mul, weights) * gathered)
 
     result = functools.reduce(operator.add, outputs)

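As a sanity check of the new control flow in isolation, here is a hedged sketch that mirrors the post-change lines with stand-in tensors. The variable names follow the diff, but the dummy values are invented for illustration, and the redundant second cast from the diff is omitted since it is a no-op.

# Hedged sketch of the post-change logic in isolation (TensorFlow only).
# The tensors below are made-up stand-ins for one interpolation corner.
import functools
import operator
import tensorflow as tf

gathered = tf.constant([10, 20, 30], dtype=tf.uint8)  # stand-in for the gather_nd output
weights = [tf.constant([0.25, 0.5, 0.25], dtype=tf.float32)]  # interpolation weights
validities = [tf.constant([True, True, False])]  # per-axis in-bounds masks
fill_value = tf.constant(0, dtype=tf.uint8)
fill_mode = "constant"

# Cast to the computation dtype early so tf.where and the weighted product
# operate on matching float tensors.
dtype = weights[0].dtype
gathered = tf.cast(gathered, dtype)

if fill_mode == "constant":
    all_valid = tf.reduce_all(validities, axis=0)
    fill_value_typed = tf.cast(fill_value, dtype)
    gathered = tf.where(all_valid, gathered, fill_value_typed)

contribution = functools.reduce(operator.mul, weights) * gathered
print(contribution)  # float32 output; invalid positions replaced by the fill value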