* Update nn.py
* Update nn.py
* Update nn.py
* Update nn.py
* Update nn.py
Corrected indentation in docstring
* Update nn.py
* Update random_grayscale.py
Fixed issue with passing a single image without batch dimension.
* Update keras/src/layers/preprocessing/image_preprocessing/random_grayscale.py
Co-authored-by: Jyotinder Singh <[email protected]>
* Update random_grayscale_test.py
Test case for unbatched inputs
* code reformat
* Update random_grayscale_test.py
Test case checking both unbatched and batched single-image inputs.
* changed compute_output_spec
There was a bug causing a cycle in the graph.
* Update random_grayscale.py
removed the use of tree.map_structure
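The unbatched-input fix above can be sketched as follows. This is a minimal illustration of the pattern (promote a single HWC image to a batch of one, process, then squeeze back), not the actual Keras `RandomGrayscale` implementation; the function name and luma weights are assumptions.

```python
import numpy as np

def apply_grayscale(images):
    # Hypothetical sketch: accept either a single (H, W, C) image or a
    # batch (N, H, W, C), mirroring the fix described in the commits above.
    unbatched = images.ndim == 3
    if unbatched:
        # Promote single image to a batch of one.
        images = np.expand_dims(images, axis=0)
    # Standard ITU-R BT.601 luma weights for an RGB -> grayscale mix.
    weights = np.array([0.299, 0.587, 0.114])
    gray = np.tensordot(images, weights, axes=([-1], [0]))[..., None]
    # Broadcast the single gray channel back to the input channel count.
    gray = np.repeat(gray, images.shape[-1], axis=-1)
    if unbatched:
        # Restore the original unbatched rank.
        gray = np.squeeze(gray, axis=0)
    return gray
```

Both call shapes then round-trip correctly: a `(4, 4, 3)` input returns `(4, 4, 3)`, and a `(2, 4, 4, 3)` batch returns `(2, 4, 4, 3)`.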
* Reapply "Fixed issue with dot_product_attention when using TPU. (#21254)" (#21329)
This reverts commit 81821e0.
* Improve error handling in _can_use_flash_attention for better debugging
Enhanced the _can_use_flash_attention function to provide more detailed
error messages when flash attention compatibility checks fail.
Changes:
- Replace generic exception catching with specific error propagation
- When raise_error=True, directly re-raise original exceptions from
check_layout() and check_is_flash_attention() functions
- Preserve detailed error context from JAX internal validation functions
- Maintain existing behavior when raise_error=False (returns False)
This improves debugging experience by surfacing specific technical details
about tensor layout incompatibilities, cuDNN version requirements, and
other flash attention compatibility issues.
Relates to keras-hub PR #2257 and addresses flash attention debugging needs.
* Revert "Improve error handling in _can_use_flash_attention for better debugging"
This reverts commit 7a0c547.
* Fix JAX API compatibility and improve error handling in `_can_use_flash_attention`
Changes:
- Add missing q_offsets=None and kv_offsets=None parameters to check_layout()
call to match updated JAX function signature
- Replace bare `except:` with `except Exception as e:` and `raise e` to
preserve detailed error messages from JAX validation functions
- Maintain existing fallback behavior when raise_error=False
This resolves compatibility issues with newer JAX versions and improves
debugging experience by surfacing specific technical details about
flash attention compatibility failures.
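The error-propagation pattern described above can be sketched as below. The function and helper names are illustrative stand-ins, not Keras's actual code; the point is replacing a bare `except:` that silently returns `False` with explicit re-raising when `raise_error=True`, so JAX's detailed validation message reaches the caller.

```python
def check_layout(query, key, value, q_offsets=None, kv_offsets=None):
    # Stand-in for JAX's internal layout validation, which raises a
    # descriptive error on incompatible inputs. The q_offsets/kv_offsets
    # keyword parameters mirror the updated signature mentioned above.
    if query.ndim != 4:
        raise ValueError("query must be 4D (batch, heads, seq, head_dim)")

def can_use_flash_attention(query, key, value, raise_error=False):
    try:
        check_layout(query, key, value, q_offsets=None, kv_offsets=None)
        return True
    except Exception as e:
        if raise_error:
            # Surface the original, detailed error instead of swallowing it.
            raise e
        # Preserve the existing fallback behavior.
        return False
```

With `raise_error=False` the function keeps returning `False` on incompatibility, so existing callers are unaffected.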
* Updated `dot_product_attention`
Simplified the check for `flash_attention` by removing redundant checks that are already done in `_can_use_flash_attention`.
* Update nn.py
* Update nn.py
* Update image.py
* Update keras/src/backend/tensorflow/image.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
* Revert "Update keras/src/backend/tensorflow/image.py"
This reverts commit cb7e955.
* Update image.py
* Update image.py
---------
Co-authored-by: Jyotinder Singh <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>