Merged · Changes from 3 commits
ext/TensorOperationsChainRulesCoreExt.jl — 29 changes: 15 additions & 14 deletions
@@ -18,6 +18,7 @@ trivtuple(N) = ntuple(identity, N)
 @non_differentiable TensorOperations.tensorcontract_structure(args...)
 @non_differentiable TensorOperations.tensorcontract_type(args...)
 @non_differentiable TensorOperations.tensoralloc_contract(args...)
+@non_differentiable Base.promote_op(args...)
Member:
Do you know if your example also works if this line isn't included? We can't really do this here, since that is type piracy, so we would have to find an alternative way around it...

Suggested change:
-@non_differentiable Base.promote_op(args...)
Contributor Author:
Unfortunately, the example fails if this line is removed. To avoid type piracy, would this alternative be acceptable?

@non_differentiable TensorOperations.promote_contract(args...)
@non_differentiable TensorOperations.promote_add(args...)
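For context on the ownership rule at play: `@non_differentiable` expands into `frule`/`rrule` method definitions, so applying it to `Base.promote_op` from this extension would add methods that dispatch only on functions and types the extension does not own, whereas marking a function TensorOperations itself defines is unproblematic. A minimal sketch, with `my_promote` as a hypothetical stand-in for an owned function such as `TensorOperations.promote_contract`:

using ChainRulesCore

# Hypothetical owned wrapper; not part of TensorOperations or Base.
my_promote(T, S) = Base.promote_op(+, T, S)

# Fine: the generated frule/rrule methods dispatch on typeof(my_promote),
# which this code owns.
@non_differentiable my_promote(T, S)

# AD machinery now treats my_promote as a constant:
y, pb = rrule(my_promote, Float64, Float64)   # y == Float64
pb(1.0)                                       # (NoTangent(), NoTangent(), NoTangent())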

Member:
Yes, that is completely fine, thank you!

Contributor Author:
Thanks! Updated in the latest commit.


 # Cannot free intermediate tensors when using AD
 # Thus we change the forward passes: `istemp=false` and `tensorfree!` is a no-op
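A rough sketch of the pattern those two comments describe, before the diff resumes below; the argument list of `tensoraloc`'s rrule here is an assumption for illustration, not the extension's exact code:

using ChainRulesCore, TensorOperations

# Sketch only: under AD the pullback may still reference the buffer, so the
# forward pass must not mark the allocation as temporary.
function ChainRulesCore.rrule(::typeof(TensorOperations.tensoralloc),
                              ttype, structure, istemp, allocator...)
    # Ignore the requested `istemp` and allocate persistently (istemp = false).
    output = TensorOperations.tensoralloc(ttype, structure, false, allocator...)
    function tensoralloc_pullback(ΔC)
        # The allocation itself carries no derivative information.
        return (NoTangent(), NoTangent(), NoTangent(), NoTangent(),
                map(_ -> NoTangent(), allocator)...)
    end
    return output, tensoralloc_pullback
end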
@@ -38,11 +39,11 @@ function ChainRulesCore.rrule(
     return output, tensoralloc_pullback
 end

-# TODO: possibly use the non-inplace functions, to avoid depending on Base.copy
 function ChainRulesCore.rrule(::typeof(tensorscalar), C)
+    projectC = ProjectTo(C)
     function tensorscalar_pullback(Δc)
-        ΔC = TensorOperations.tensoralloc(typeof(C), TensorOperations.tensorstructure(C))
-        return NoTangent(), fill!(ΔC, unthunk(Δc))
+        _Δc = unthunk(Δc)
+        return NoTangent(), projectC(_Δc)
     end
     return tensorscalar(C), tensorscalar_pullback
 end
Expand Down Expand Up @@ -95,7 +96,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
ipA = invperm(linearize(pA))
_dA = zerovector(A, VectorInterface.promote_add(ΔC, α))
_dA = tensoradd!(_dA, ΔC, (ipA, ()), conjA, conjA ? α : conj(α), Zero(), ba...)
return projectA(_dA)
projectA(_dA)
end
dα = @thunk let
_dα = tensorscalar(
@@ -105,7 +106,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
                     ((), ()), One(), ba...
                 )
             )
-            return projectα(_dα)
+            projectα(_dα)
         end
         dβ = @thunk let
             # TODO: consider using `inner`
@@ -116,7 +117,7 @@ function _rrule_tensoradd!(C, A, pA, conjA, α, β, ba)
                     ((), ()), One(), ba...
                 )
             )
-            return projectβ(_dβ)
+            projectβ(_dβ)
         end
         dba = map(_ -> NoTangent(), ba)
         return NoTangent(), dC, dA, NoTangent(), NoTangent(), dα, dβ, dba...
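The remaining changes in this file are all the same mechanical cleanup: the value of a `let` block is its last expression, so inside `@thunk let ... end` the explicit `return` is redundant (it returns from the closure that `@thunk` builds, yielding the same value). A tiny self-contained illustration, with `projectα` standing in for the projections used above:

using ChainRulesCore

projectα = ProjectTo(1.0)        # stands in for the rrule's projector

dα = @thunk let
    _dα = 2.0 + 3.0              # stands in for the tensor contraction
    projectα(_dα)                # last expression: the thunk's value
end

unthunk(dα)                      # 5.0, computed only on demand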
@@ -194,7 +195,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
                 ipA,
                 conjA ? α : conj(α), Zero(), ba...
             )
-            return projectA(_dA)
+            projectA(_dA)
         end
         dB = @thunk let
             ipB = (invperm(linearize(pB)), ())
@@ -208,7 +209,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
                 ipB,
                 conjB ? α : conj(α), Zero(), ba...
             )
-            return projectB(_dB)
+            projectB(_dB)
         end
         dα = @thunk let
             C_αβ = tensorcontract(A, pA, conjA, B, pB, conjB, pAB, One(), ba...)
@@ -220,7 +221,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
                     ((), ()), One(), ba...
                 )
             )
-            return projectα(_dα)
+            projectα(_dα)
         end
         dβ = @thunk let
             # TODO: consider using `inner`
@@ -231,7 +232,7 @@ function _rrule_tensorcontract!(C, A, pA, conjA, B, pB, conjB, pAB, α, β, ba)
                     ((), ()), One(), ba...
                 )
             )
-            return projectβ(_dβ)
+            projectβ(_dβ)
         end
         dba = map(_ -> NoTangent(), ba)
         return NoTangent(), dC,
@@ -283,7 +284,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
         dA = @thunk let
             ip = invperm((linearize(p)..., q[1]..., q[2]...))
             Es = map(q[1], q[2]) do i1, i2
-                return one(
+                one(
                     TensorOperations.tensoralloc_add(
                         scalartype(A), A, ((i1,), (i2,)), conjA
                     )
@@ -297,7 +298,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
                 (ip, ()),
                 conjA ? α : conj(α), Zero(), ba...
             )
-            return projectA(_dA)
+            projectA(_dA)
         end
         dα = @thunk let
             C_αβ = tensortrace(A, p, q, false, One(), ba...)
@@ -309,7 +310,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
                     ((), ()), One(), ba...
                 )
             )
-            return projectα(_dα)
+            projectα(_dα)
         end
         dβ = @thunk let
             _dβ = tensorscalar(
@@ -319,7 +320,7 @@ function _rrule_tensortrace!(C, A, p, q, conjA, α, β, ba)
                     ((), ()), One(), ba...
                 )
             )
-            return projectβ(_dβ)
+            projectβ(_dβ)
         end
         dba = map(_ -> NoTangent(), ba)
         return NoTangent(), dC, dA, NoTangent(), NoTangent(), NoTangent(), dα, dβ, dba...