Skip to content

Commit 08b8321

Browse files
author
Katharine Hyatt
committed
Restore argmax and fix AD test
1 parent e15ccf8 commit 08b8321

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

src/auxiliary/linalg.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,5 @@ end
5454

5555
safesign(s::Real) = ifelse(s < zero(s), -one(s), +one(s))
5656
safesign(s::Complex) = ifelse(iszero(s), one(s), s / abs(s))
57+
58+
_argmax(f, domain) = argmax(f, domain)

test/ad.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ Vlist = ((ℂ^2, (ℂ^3)', ℂ^3, ℂ^2, (ℂ^2)'),
447447
T <: Complex && remove_svdgauge_depence!(ΔU, ΔV, U, S, V)
448448
test_rrule(tsvd, C; atol, output_tangent=(ΔU, ΔS, ΔV))
449449

450-
c, = TensorKit.MatrixAlgebra._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])),
450+
c, = TensorKit._argmax(x -> sqrt(dim(x[1])) * maximum(diag(x[2])),
451451
blocks(S))
452452
trunc = truncdim(round(Int, 2 * dim(c)))
453453
U, S, V = tsvd(C; trunc)

0 commit comments

Comments
 (0)