Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions ext/MatrixAlgebraKitChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,22 @@ function ChainRulesCore.rrule(::typeof(right_polar!), A, PWᴴ, alg)
return PWᴴ, right_polar_pullback
end

# Reverse-mode rule for `project_hermitian(A, alg)`.
#
# Returns the primal result and a pullback. The pullback maps the output cotangent to the
# cotangent of `A` by applying `project_hermitian` (called without `alg`) to it; the
# function itself and `alg` receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(project_hermitian), A, alg)
    Aₕ = project_hermitian(A, alg)
    function project_hermitian_pullback(ΔAₕ)
        Δ = unthunk(ΔAₕ)
        # A zero cotangent (`ZeroTangent`/`NoTangent`) carries no array to project;
        # hand it back unchanged instead of calling the projection on it.
        Δ isa ChainRulesCore.AbstractZero && return NoTangent(), Δ, NoTangent()
        return NoTangent(), project_hermitian(Δ), NoTangent()
    end
    return Aₕ, project_hermitian_pullback
end

# Reverse-mode rule for `project_antihermitian(A, alg)`.
#
# Returns the primal result and a pullback. The pullback maps the output cotangent to the
# cotangent of `A` by applying `project_antihermitian` (called without `alg`) to it; the
# function itself and `alg` receive `NoTangent()`.
function ChainRulesCore.rrule(::typeof(project_antihermitian), A, alg)
    Aₐ = project_antihermitian(A, alg)
    function project_antihermitian_pullback(ΔAₐ)
        Δ = unthunk(ΔAₐ)
        # A zero cotangent (`ZeroTangent`/`NoTangent`) carries no array to project;
        # hand it back unchanged instead of calling the projection on it.
        Δ isa ChainRulesCore.AbstractZero && return NoTangent(), Δ, NoTangent()
        return NoTangent(), project_antihermitian(Δ), NoTangent()
    end
    return Aₐ, project_antihermitian_pullback
end

end
47 changes: 47 additions & 0 deletions ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -778,4 +778,51 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_trunc_no_error)}, A_dA::CoDual, al
return USVᴴtrunc_dUSVᴴtrunc, svd_trunc_adjoint
end

# single-output projections: project_hermitian!, project_antihermitian!
#
# For each projection this loop generates two Mooncake reverse-mode primitives:
#   * the in-place form   `f!(A, arg, alg)`, and
#   * the out-of-place form `f(A, alg)`.
for (f!, f, adj) in (
        (:project_hermitian!, :project_hermitian, :project_hermitian_adjoint),
        (:project_antihermitian!, :project_antihermitian, :project_antihermitian_adjoint),
    )
    @eval begin
        @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm}
        function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
            A, dA = arrayify(A_dA)
            # If the very same CoDual is passed as input and output, reuse its arrays so
            # that the aliasing check `dA !== darg` in the adjoint is meaningful.
            arg, darg = A_dA === arg_darg ? (A, dA) : arrayify(arg_darg)

            # Only the output buffer is saved for restoration in the adjoint; no separate
            # copy of `A` is taken (the projections are taken not to mutate their input).
            argc = copy(arg)
            # Bind the result to a fresh name: reassigning `arg` and then capturing it in
            # the adjoint closure would make it a boxed captured variable.
            out = $f!(A, arg, Mooncake.primal(alg_dalg))

            function $adj(::NoRData)
                # Project the output cotangent in place.
                $f!(darg)
                if dA !== darg
                    # Distinct buffers: accumulate into the input cotangent and reset
                    # the output cotangent.
                    dA .+= darg
                    zero!(darg)
                end
                # Restore the output buffer to its contents from before the forward pass.
                copy!(out, argc)
                return ntuple(Returns(NoRData()), 4)
            end

            return arg_darg, $adj
        end

        @is_primitive DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm}
        function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm})
            A, dA = arrayify(A_dA)
            output = $f(A, Mooncake.primal(alg_dalg))
            output_doutput = Mooncake.zero_fcodual(output)

            doutput = last(arrayify(output_doutput))
            function $adj(::NoRData)
                # TODO: need accumulating projection to avoid intermediate here
                dA .+= $f(doutput)
                zero!(doutput)
                return ntuple(Returns(NoRData()), 3)
            end

            return output_doutput, $adj
        end
    end
end

end
6 changes: 3 additions & 3 deletions src/implementations/projections.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,11 @@ end

# `Diagonal` specialization of the native projection: reads the diagonal of the input `A`
# and writes the projected diagonal into the output `B`, which is returned.
#
# With `anti == true` each entry is mapped through `_imimag` (the anti-Hermitian branch);
# otherwise the real part of each entry is kept. `A` is left untouched unless `A === B`.
# Extra keyword arguments are accepted and ignored.
function project_hermitian_native!(A::Diagonal, B::Diagonal, ::Val{anti}; kwargs...) where {anti}
    if anti
        diagview(B) .= _imimag.(diagview(A))
    else
        diagview(B) .= real.(diagview(A))
    end
    return B
end

function project_hermitian_native!(A::AbstractMatrix, B::AbstractMatrix, anti::Val; blocksize = 32)
Expand Down
4 changes: 2 additions & 2 deletions src/pullbacks/polar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
!iszerotangent(ΔW) && mul!(M, W', ΔW, 1, 1)
!iszerotangent(ΔP) && mul!(M, ΔP, P, -1, 1)
C = _sylvester(P, P, M' - M)
C .+= ΔP
!iszerotangent(ΔP) && (C .+= ΔP)
ΔA = mul!(ΔA, W, C, 1, 1)
if !iszerotangent(ΔW)
ΔWP = ΔW / P
Expand Down Expand Up @@ -47,7 +47,7 @@ function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...
!iszerotangent(ΔWᴴ) && mul!(M, ΔWᴴ, Wᴴ', 1, 1)
!iszerotangent(ΔP) && mul!(M, P, ΔP, -1, 1)
C = _sylvester(P, P, M' - M)
C .+= ΔP
!iszerotangent(ΔP) && (C .+= ΔP)
ΔA = mul!(ΔA, C, Wᴴ, 1, 1)
if !iszerotangent(ΔWᴴ)
PΔWᴴ = P \ ΔWᴴ
Expand Down
21 changes: 21 additions & 0 deletions test/mooncake/projections.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
using MatrixAlgebraKit
using Test
using LinearAlgebra: Diagonal
using CUDA, AMDGPU

BLASFloats = (Float32, ComplexF64) # full suite is too expensive on CI
GenericFloats = ()
@isdefined(TestSuite) || include("../testsuite/TestSuite.jl")
using .TestSuite

is_buildkite = get(ENV, "BUILDKITE", "false") == "true"

m = 19
for T in (BLASFloats..., GenericFloats...)
    TestSuite.seed_rng!(123)
    # Tolerances scale with the number of matrix entries.
    atol = rtol = m * m * TestSuite.precision(T)
    # The Mooncake projection tests are only run outside of Buildkite.
    is_buildkite && continue
    # Dense element type first, then the `Diagonal` matrix type.
    for MT in (T, Diagonal{T, Vector{T}})
        TestSuite.test_mooncake_projections(MT, (m, m); atol, rtol)
    end
end
1 change: 1 addition & 0 deletions test/testsuite/TestSuite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ include("mooncake/eigh.jl")
include("mooncake/svd.jl")
include("mooncake/polar.jl")
include("mooncake/orthnull.jl")
include("mooncake/projections.jl")

include("enzyme.jl")
include("chainrules.jl")
Expand Down
23 changes: 23 additions & 0 deletions test/testsuite/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ function test_chainrules(T::Type, sz; kwargs...)
test_chainrules_svd(T, sz; kwargs...)
test_chainrules_polar(T, sz; kwargs...)
test_chainrules_orthnull(T, sz; kwargs...)
test_chainrules_projections(T, sz; kwargs...)
end
end

Expand Down Expand Up @@ -587,3 +588,25 @@ function test_chainrules_orthnull(
)
end
end

# Check the ChainRules `rrule`s of `project_hermitian` and `project_antihermitian` with
# `test_rrule` on a matrix built by `instantiate_matrix(T, sz)`.
#
# The checks only run when the matrix is square. `atol`/`rtol` are forwarded to
# `test_rrule`; any other keyword arguments are accepted but not used here.
function test_chainrules_projections(
        T::Type, sz;
        atol::Real = 0, rtol::Real = precision(T),
        kwargs...
    )
    label = testargs_summary(T, sz)
    return @testset "Projections Chainrules AD rules $label" begin
        A = instantiate_matrix(T, sz)
        nrows, ncols = size(A)
        if nrows == ncols
            projections = (
                ("project_hermitian", project_hermitian),
                ("project_antihermitian", project_antihermitian),
            )
            for (name, projection) in projections
                @testset "$name" begin
                    # Both projections are tested with the default Hermitian algorithm.
                    alg = MatrixAlgebraKit.default_hermitian_algorithm(A)
                    test_rrule(projection, A, alg; atol, rtol)
                end
            end
        end
    end
end
69 changes: 69 additions & 0 deletions test/testsuite/mooncake/projections.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
"""
test_mooncake_projections(T, sz; kwargs...)

Run all Mooncake AD tests for hermitian and anti-hermitian projections of element type `T`
and size `sz`.
"""
function test_mooncake_projections(T::Type, sz; kwargs...)
    label = testargs_summary(T, sz)
    return @testset "Mooncake projection $label" begin
        # Hermitian group first, then anti-Hermitian; all keywords are forwarded as-is.
        for run_group in (test_mooncake_project_hermitian, test_mooncake_project_antihermitian)
            run_group(T, sz; kwargs...)
        end
    end
end

"""
test_mooncake_project_hermitian(T, sz; rng, atol, rtol)

Test the Mooncake reverse-mode AD rule for `project_hermitian` and its in-place variant.
"""
function test_mooncake_project_hermitian(
        T, sz;
        rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T)
    )
    return @testset "project_hermitian" begin
        A = instantiate_matrix(T, sz)
        B = instantiate_matrix(T, sz)
        alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A)

        # Shared invocation of Mooncake's rule tester; `alg` is always the last argument.
        check_rule(f, args...) = Mooncake.TestUtils.test_rule(
            rng, f, args..., alg;
            mode = Mooncake.ReverseMode, atol, rtol
        )

        check_rule(project_hermitian, A)      # out-of-place
        check_rule(project_hermitian!, A, A)  # in-place, same array as input and output
        check_rule(project_hermitian!, A, B)  # in-place, separate output array
    end
end

"""
test_mooncake_project_antihermitian(T, sz; rng, atol, rtol)

Test the Mooncake reverse-mode AD rule for `project_antihermitian` and its in-place variant.
"""
function test_mooncake_project_antihermitian(
        T, sz;
        rng = Random.default_rng(), atol::Real = 0, rtol::Real = precision(T)
    )
    return @testset "project_antihermitian" begin
        A = instantiate_matrix(T, sz)
        B = instantiate_matrix(T, sz)
        # NOTE(review): the algorithm is selected via `project_hermitian`, not
        # `project_antihermitian` — presumably both share one algorithm; confirm.
        alg = MatrixAlgebraKit.select_algorithm(project_hermitian, A)

        # Shared invocation of Mooncake's rule tester; `alg` is always the last argument.
        check_rule(f, args...) = Mooncake.TestUtils.test_rule(
            rng, f, args..., alg;
            mode = Mooncake.ReverseMode, atol, rtol
        )

        check_rule(project_antihermitian, A)      # out-of-place
        check_rule(project_antihermitian!, A, A)  # in-place, same array as input and output
        check_rule(project_antihermitian!, A, B)  # in-place, separate output array
    end
end
Loading