Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2694,6 +2694,15 @@ def compute_output_spec(self, x):
f"`x` is of shape {x.shape}."
)

ndim = len(x_shape)
ax1 = self.axis1 if self.axis1 >= 0 else self.axis1 + ndim
ax2 = self.axis2 if self.axis2 >= 0 else self.axis2 + ndim
Comment on lines +2698 to +2699
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use ax1 = canonicalize_axis(self.axis1, ndim) which does some validation.

if ax1 == ax2:
raise ValueError(
"`axis1` and `axis2` cannot be the same. "
f"Received: axis1={self.axis1}, axis2={self.axis2}"
)

shape_2d = [x_shape[self.axis1], x_shape[self.axis2]]
x_shape[self.axis1] = -1
x_shape[self.axis2] = -1
Expand Down Expand Up @@ -2767,6 +2776,14 @@ def diagonal(x, offset=0, axis1=0, axis2=1):
axis1=axis1,
axis2=axis2,
).symbolic_call(x)
x_ndim = len(x.shape)
ax1 = axis1 if axis1 >= 0 else axis1 + x_ndim
ax2 = axis2 if axis2 >= 0 else axis2 + x_ndim
if ax1 == ax2:
raise ValueError(
"`axis1` and `axis2` cannot be the same. "
f"Received: axis1={axis1}, axis2={axis2}"
)
Comment on lines +2779 to +2786
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current pattern used is to not have any code in these functions, they're supposed to delegate to the backend functions.

So please move to whichever backend doesn't throw a proper exception (tensorflow is sounds like).

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, never mind. Validation is fine. But:

  • use canonicalize_axis here
  • move it before the if any_symbolic_tensor
  • remove the other one in compute_output_shape, it's not longer needed

return backend.numpy.diagonal(
x,
offset=offset,
Expand Down
Loading