Draft
Conversation
e80a40a to
a7ccbb8
Compare
Contributor
|
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/ext/TensorKitEnzymeExt/indexmanipulations.jl b/ext/TensorKitEnzymeExt/indexmanipulations.jl
index 1d9a7a9..d3a634f 100644
--- a/ext/TensorKitEnzymeExt/indexmanipulations.jl
+++ b/ext/TensorKitEnzymeExt/indexmanipulations.jl
@@ -59,8 +59,8 @@ for transform in (:permute, :transpose)
TK.$add_transform!(A.dval, C.dval, pΔA, conj(α.val), One(), bavs...)
end
end
- Δα = pullback_dα(α, C, Ap)
- Δβ = pullback_dβ(β, C, Cval)
+ Δα = pullback_dα(α, C, Ap)
+ Δβ = pullback_dβ(β, C, Cval)
!isa(C, Const) && pullback_dC!(C.dval, β.val)
return nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)...
end
@@ -128,8 +128,8 @@ function EnzymeRules.reverse(
TK.add_braid!(A.dval, C.dval, pΔA, ilevels, conj(α.val), One(), bavs...)
end
end
- Δαr = pullback_dα(α, C, Ap)
- Δβr = pullback_dβ(β, C, Cval)
+ Δαr = pullback_dα(α, C, Ap)
+ Δβr = pullback_dβ(β, C, Cval)
!isa(C, Const) && pullback_dC!(C.dval, β.val)
return nothing, nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)...
end
diff --git a/ext/TensorKitEnzymeExt/linalg.jl b/ext/TensorKitEnzymeExt/linalg.jl
index b9b0e45..1a2b9f1 100644
--- a/ext/TensorKitEnzymeExt/linalg.jl
+++ b/ext/TensorKitEnzymeExt/linalg.jl
@@ -51,8 +51,8 @@ function EnzymeRules.reverse(
!isa(A, Const) && !isa(C, Const) && project_mul!(A.dval, C.dval, Bval', conj(α.val))
!isa(B, Const) && !isa(C, Const) && project_mul!(B.dval, Aval', C.dval, conj(α.val))
- Δαr = pullback_dα(α, C, AB)
- Δβr = pullback_dβ(β, C, Cval)
+ Δαr = pullback_dα(α, C, AB)
+ Δβr = pullback_dβ(β, C, Cval)
!isa(C, Const) && pullback_dC!(C.dval, β.val)
return (nothing, nothing, nothing, Δαr, Δβr)
diff --git a/ext/TensorKitEnzymeExt/tensoroperations.jl b/ext/TensorKitEnzymeExt/tensoroperations.jl
index a3b598c..69a48e3 100644
--- a/ext/TensorKitEnzymeExt/tensoroperations.jl
+++ b/ext/TensorKitEnzymeExt/tensoroperations.jl
@@ -62,8 +62,8 @@ function EnzymeRules.reverse(
Aval = something(cacheA, A.val)
Bval = something(cacheB, B.val)
- Δα = pullback_dα(α, C, AB)
- Δβ = pullback_dβ(β, C, Cval)
+ Δα = pullback_dα(α, C, AB)
+ Δβ = pullback_dβ(β, C, Cval)
if !isa(A, Const)
blas_contract_pullback_ΔA!(
@@ -179,8 +179,8 @@ function EnzymeRules.reverse(
Aval = something(A_cache, A.val)
Cval = something(C_cache, C.val)
!isa(A, Const) && !isa(C, Const) && trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val)
- Δαr = pullback_dα(α, C, At)
- Δβr = pullback_dβ(β, C, Cval)
+ Δαr = pullback_dα(α, C, At)
+ Δβr = pullback_dβ(β, C, Cval)
!isa(C, Const) && pullback_dC!(C.dval, β.val)
return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing
end
diff --git a/ext/TensorKitEnzymeExt/vectorinterface.jl b/ext/TensorKitEnzymeExt/vectorinterface.jl
index a0dd107..977f981 100644
--- a/ext/TensorKitEnzymeExt/vectorinterface.jl
+++ b/ext/TensorKitEnzymeExt/vectorinterface.jl
@@ -21,7 +21,7 @@ function EnzymeRules.reverse(
α::Annotation{<:Number},
) where {RT}
Cval = something(cache, C.val)
- Δα = pullback_dα(α, C, Cval)
+ Δα = pullback_dα(α, C, Cval)
!isa(C, Const) && scale!(C.dval, conj(α.val))
return (nothing, Δα)
end
@@ -52,7 +52,7 @@ function EnzymeRules.reverse(
α::Annotation{<:Number},
) where {RT}
Aval = something(cache, A.val)
- Δα = pullback_dα(α, C, Aval)
+ Δα = pullback_dα(α, C, Aval)
!isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(α.val))
!isa(C, Const) && zerovector!(C.dval)
return (nothing, nothing, Δα)
@@ -89,8 +89,8 @@ function EnzymeRules.reverse(
A_cache, C_cache = cache
Aval = something(A_cache, A.val)
Cval = something(C_cache, C.val)
- Δα = pullback_dα(α, C, Aval)
- Δβ = pullback_dβ(β, C, Cval)
+ Δα = pullback_dα(α, C, Aval)
+ Δβ = pullback_dβ(β, C, Cval)
!isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(α.val))
!isa(C, Const) && scale!(C.dval, conj(β.val))
return (nothing, nothing, Δα, Δβ) |
d308db9 to
6ac8e67
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Building on top of #356
First I need to get all the tests working. I am getting some odd space mismatch errors I need to dig into further as well as a few Enzyme-related segfaults.
Then, I'd like to refactor the joint code shared between Mooncake and Enzyme into the main
srcfolder so it can be reused (much as we've done with MatrixAlgebraKit).