Skip to content

Fix: keras.ops.argpartition failed when axis==None#22560

Open
maitry63 wants to merge 6 commits intokeras-team:masterfrom
maitry63:argpart_error_fix
Open

Fix: keras.ops.argpartition failed when axis==None#22560
maitry63 wants to merge 6 commits intokeras-team:masterfrom
maitry63:argpart_error_fix

Conversation

@maitry63
Copy link
Copy Markdown
Contributor

This PR fixes keras.ops.argpartition fail with axis==None
Fixes: #22537

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request implements support for axis=None in the argpartition operation, which now correctly flattens the input tensor before partitioning. It also includes a fix in the TensorFlow backend to handle scalar updates in tensor_scatter_nd_update and improves the error message formatting for the kth argument. A review comment suggests moving a local import math to the top of the file to comply with PEP 8 standards.

@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented Mar 27, 2026

Codecov Report

❌ Patch coverage is 66.66667% with 1 line in your changes missing coverage. Please review.
✅ Project coverage is 62.03%. Comparing base (a2e97e1) to head (72210d6).
⚠️ Report is 42 commits behind head on master.

Files with missing lines Patch % Lines
keras/src/backend/torch/numpy.py 0.00% 1 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (a2e97e1) and HEAD (72210d6). Click for more details.

HEAD has 8 uploads less than BASE
Flag BASE (a2e97e1) HEAD (72210d6)
keras 6 2
keras-jax 2 1
keras-numpy 1 0
keras-openvino 1 0
keras-torch 1 0
Additional details and impacted files
@@             Coverage Diff             @@
##           master   #22560       +/-   ##
===========================================
- Coverage   83.26%   62.03%   -21.24%     
===========================================
  Files         596      596               
  Lines       67828    68091      +263     
  Branches    10562    10608       +46     
===========================================
- Hits        56480    42237    -14243     
- Misses       8605    23557    +14952     
+ Partials     2743     2297      -446     
Flag Coverage Δ
keras 61.98% <66.66%> (-21.11%) ⬇️
keras-jax 24.78% <0.00%> (-35.05%) ⬇️
keras-numpy ?
keras-openvino ?
keras-tensorflow 61.04% <66.66%> (-0.10%) ⬇️
keras-torch ?

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@keerthanakadiri
Copy link
Copy Markdown
Contributor

keerthanakadiri commented Mar 27, 2026

Code Review

This pull request implements support for axis=None in the argpartition operation, which now correctly flattens the input tensor before partitioning. It also includes a fix in the TensorFlow backend to handle scalar updates in tensor_scatter_nd_update and improves the error message formatting for the kth argument. A review comment suggests moving a local import math to the top of the file to comply with PEP 8 standards.

Hi @maitry63, Can you please follow the suggestions.And provide more context about the problem that this PR solves.

@keerthanakadiri keerthanakadiri added the stat:awaiting keras-eng Awaiting response from Keras engineer label Mar 27, 2026
Copy link
Copy Markdown
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for looking into this!

if dim is None
else dim
(
tf.maximum(x_original_shape[i], indices_original_shape[i])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spurious reformatting change. Revert.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reverted unnecessary reformatting in tensorflow backend.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's still here.

Comment on lines +3715 to +3716
tf.reshape(updates, [-1]) if tf.rank(updates) == 0 else updates,
)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this change needed?

Copy link
Copy Markdown
Contributor Author

@maitry63 maitry63 Apr 1, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed to use updates = tf.zeros(tf.shape(indices)[0]) to avoid rank-0 errors.

Comment on lines +9228 to +9232
x = backend.convert_to_tensor(x)

if axis is None:
x = backend.numpy.reshape(x, [-1])
return backend.numpy.argpartition(x, kth, axis=0)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not how this should be fixed. There shouldn't be any code here. Each backend has to support this correctly.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed flattening from ops layer; backend handles flattening for eager tensors; symbolic tensors preserve shape and handled axis=None properly in backend functions, removed unnecessary ops-layer reshaping.

@hertschuh hertschuh added stat:awaiting response from contributor and removed stat:awaiting keras-eng Awaiting response from Keras engineer labels Mar 31, 2026
Update tests to pass for both symbolic and eager tensors
Copy link
Copy Markdown
Collaborator

@hertschuh hertschuh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you look at the failing tests? Thanks!

if dim is None
else dim
(
tf.maximum(x_original_shape[i], indices_original_shape[i])
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's still here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

keras.ops.argpartition failed when axis==None

5 participants