Skip to content

Commit 693147d

Browse files
committed
WIP: ProjectTo type piracy AD fix
1 parent 57be075 commit 693147d

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

src/ApproximateGPs.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,10 @@ include("deprecations.jl")
2323

2424
include("TestUtils.jl")
2525

26+
import ChainRulesCore: ProjectTo, Tangent
27+
using PDMats: ScalMat
28+
ProjectTo(x::T) where T <: ScalMat = ProjectTo{T}(; dim=x.dim, value=ProjectTo(x.value))
29+
(pr::ProjectTo{<:ScalMat})(dx::ScalMat) = ScalMat(pr.dim, pr.value(dx.value))
30+
(pr::ProjectTo{<:ScalMat})(dx::Tangent{<:ScalMat}) = ScalMat(pr.dim, pr.value(dx.value))
31+
2632
end

0 commit comments

Comments
 (0)