Fix: keras.ops.argpartition failed when axis==None#22560
Fix: keras.ops.argpartition failed when axis==None#22560maitry63 wants to merge 6 commits intokeras-team:masterfrom
Conversation
There was a problem hiding this comment.
Code Review
This pull request implements support for axis=None in the argpartition operation, which now correctly flattens the input tensor before partitioning. It also includes a fix in the TensorFlow backend to handle scalar updates in tensor_scatter_nd_update and improves the error message formatting for the kth argument. A review comment suggests moving a local import math to the top of the file to comply with PEP 8 standards.
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #22560 +/- ##
===========================================
- Coverage 83.26% 62.03% -21.24%
===========================================
Files 596 596
Lines 67828 68091 +263
Branches 10562 10608 +46
===========================================
- Hits 56480 42237 -14243
- Misses 8605 23557 +14952
+ Partials 2743 2297 -446
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
Hi @maitry63, Can you please follow the suggestions.And provide more context about the problem that this PR solves. |
hertschuh
left a comment
There was a problem hiding this comment.
Thanks for looking into this!
| if dim is None | ||
| else dim | ||
| ( | ||
| tf.maximum(x_original_shape[i], indices_original_shape[i]) |
There was a problem hiding this comment.
Spurious reformatting change. Revert.
There was a problem hiding this comment.
Reverted unnecessary reformatting in tensorflow backend.
| tf.reshape(updates, [-1]) if tf.rank(updates) == 0 else updates, | ||
| ) |
There was a problem hiding this comment.
fixed to use updates = tf.zeros(tf.shape(indices)[0]) to avoid rank-0 errors.
keras/src/ops/numpy.py
Outdated
| x = backend.convert_to_tensor(x) | ||
|
|
||
| if axis is None: | ||
| x = backend.numpy.reshape(x, [-1]) | ||
| return backend.numpy.argpartition(x, kth, axis=0) |
There was a problem hiding this comment.
This is not how this should be fixed. There shouldn't be any code here. Each backend has to support this correctly.
There was a problem hiding this comment.
Removed flattening from ops layer; backend handles flattening for eager tensors; symbolic tensors preserve shape and handled axis=None properly in backend functions, removed unnecessary ops-layer reshaping.
Update tests to pass for both symbolic and eager tensors
hertschuh
left a comment
There was a problem hiding this comment.
Can you look at the failing tests? Thanks!
| if dim is None | ||
| else dim | ||
| ( | ||
| tf.maximum(x_original_shape[i], indices_original_shape[i]) |
This PR fixes
keras.ops.argpartitionfail with axis==NoneFixes: #22537