Skip to content

Incremental work on Enzyme support#380

Draft
kshyatt wants to merge 5 commits intomainfrom
ksh/enzyme
Draft

Incremental work on Enzyme support#380
kshyatt wants to merge 5 commits intomainfrom
ksh/enzyme

Conversation

@kshyatt
Copy link
Member

@kshyatt kshyatt commented Mar 12, 2026

Building on top of #356

First I need to get all the tests working. I am getting some odd space mismatch errors I need to dig into further as well as a few Enzyme-related segfaults.

Then, I'd like to refactor the joint code shared between Mooncake and Enzyme into the main src folder so it can be reused (much as we've done with MatrixAlgebraKit).

@kshyatt kshyatt force-pushed the ksh/enzyme branch 2 times, most recently from e80a40a to a7ccbb8 Compare March 18, 2026 12:46
@github-actions
Copy link
Contributor

github-actions bot commented Mar 18, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic main) to apply these changes.

Click here to view the suggested changes.
diff --git a/ext/TensorKitEnzymeExt/indexmanipulations.jl b/ext/TensorKitEnzymeExt/indexmanipulations.jl
index 1d9a7a9..d3a634f 100644
--- a/ext/TensorKitEnzymeExt/indexmanipulations.jl
+++ b/ext/TensorKitEnzymeExt/indexmanipulations.jl
@@ -59,8 +59,8 @@ for transform in (:permute, :transpose)
                 TK.$add_transform!(A.dval, C.dval, pΔA, conj(α.val), One(), bavs...)
             end
         end
-        Δα = pullback_dα(α, C, Ap) 
-        Δβ = pullback_dβ(β, C, Cval) 
+        Δα = pullback_dα(α, C, Ap)
+        Δβ = pullback_dβ(β, C, Cval)
         !isa(C, Const) && pullback_dC!(C.dval, β.val)
         return nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)...
     end
@@ -128,8 +128,8 @@ function EnzymeRules.reverse(
             TK.add_braid!(A.dval, C.dval, pΔA, ilevels, conj(α.val), One(), bavs...)
         end
     end
-    Δαr = pullback_dα(α, C, Ap) 
-    Δβr = pullback_dβ(β, C, Cval) 
+    Δαr = pullback_dα(α, C, Ap)
+    Δβr = pullback_dβ(β, C, Cval)
     !isa(C, Const) && pullback_dC!(C.dval, β.val)
     return nothing, nothing, nothing, nothing, Δαr, Δβr, map(Returns(nothing), ba)...
 end
diff --git a/ext/TensorKitEnzymeExt/linalg.jl b/ext/TensorKitEnzymeExt/linalg.jl
index b9b0e45..1a2b9f1 100644
--- a/ext/TensorKitEnzymeExt/linalg.jl
+++ b/ext/TensorKitEnzymeExt/linalg.jl
@@ -51,8 +51,8 @@ function EnzymeRules.reverse(
 
     !isa(A, Const) && !isa(C, Const) && project_mul!(A.dval, C.dval, Bval', conj(α.val))
     !isa(B, Const) && !isa(C, Const) && project_mul!(B.dval, Aval', C.dval, conj(α.val))
-    Δαr = pullback_dα(α, C, AB) 
-    Δβr = pullback_dβ(β, C, Cval) 
+    Δαr = pullback_dα(α, C, AB)
+    Δβr = pullback_dβ(β, C, Cval)
     !isa(C, Const) && pullback_dC!(C.dval, β.val)
 
     return (nothing, nothing, nothing, Δαr, Δβr)
diff --git a/ext/TensorKitEnzymeExt/tensoroperations.jl b/ext/TensorKitEnzymeExt/tensoroperations.jl
index a3b598c..69a48e3 100644
--- a/ext/TensorKitEnzymeExt/tensoroperations.jl
+++ b/ext/TensorKitEnzymeExt/tensoroperations.jl
@@ -62,8 +62,8 @@ function EnzymeRules.reverse(
     Aval = something(cacheA, A.val)
     Bval = something(cacheB, B.val)
 
-    Δα = pullback_dα(α, C, AB) 
-    Δβ = pullback_dβ(β, C, Cval) 
+    Δα = pullback_dα(α, C, AB)
+    Δβ = pullback_dβ(β, C, Cval)
 
     if !isa(A, Const)
         blas_contract_pullback_ΔA!(
@@ -179,8 +179,8 @@ function EnzymeRules.reverse(
     Aval = something(A_cache, A.val)
     Cval = something(C_cache, C.val)
     !isa(A, Const) && !isa(C, Const) && trace_permute_pullback_ΔA!(A.dval, C.dval, Aval, p.val, q.val, α.val, backend.val)
-    Δαr = pullback_dα(α, C, At) 
-    Δβr = pullback_dβ(β, C, Cval) 
+    Δαr = pullback_dα(α, C, At)
+    Δβr = pullback_dβ(β, C, Cval)
     !isa(C, Const) && pullback_dC!(C.dval, β.val)
     return nothing, nothing, nothing, nothing, Δαr, Δβr, nothing
 end
diff --git a/ext/TensorKitEnzymeExt/vectorinterface.jl b/ext/TensorKitEnzymeExt/vectorinterface.jl
index a0dd107..977f981 100644
--- a/ext/TensorKitEnzymeExt/vectorinterface.jl
+++ b/ext/TensorKitEnzymeExt/vectorinterface.jl
@@ -21,7 +21,7 @@ function EnzymeRules.reverse(
         α::Annotation{<:Number},
     ) where {RT}
     Cval = something(cache, C.val)
-    Δα = pullback_dα(α, C, Cval) 
+    Δα = pullback_dα(α, C, Cval)
     !isa(C, Const) && scale!(C.dval, conj(α.val))
     return (nothing, Δα)
 end
@@ -52,7 +52,7 @@ function EnzymeRules.reverse(
         α::Annotation{<:Number},
     ) where {RT}
     Aval = something(cache, A.val)
-    Δα = pullback_dα(α, C, Aval) 
+    Δα = pullback_dα(α, C, Aval)
     !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(α.val))
     !isa(C, Const) && zerovector!(C.dval)
     return (nothing, nothing, Δα)
@@ -89,8 +89,8 @@ function EnzymeRules.reverse(
     A_cache, C_cache = cache
     Aval = something(A_cache, A.val)
     Cval = something(C_cache, C.val)
-    Δα = pullback_dα(α, C, Aval) 
-    Δβ = pullback_dβ(β, C, Cval) 
+    Δα = pullback_dα(α, C, Aval)
+    Δβ = pullback_dβ(β, C, Cval)
     !isa(A, Const) && !isa(C, Const) && add!(A.dval, C.dval, conj(α.val))
     !isa(C, Const) && scale!(C.dval, conj(β.val))
     return (nothing, nothing, Δα, Δβ)

@codecov
Copy link

codecov bot commented Mar 18, 2026

Codecov Report

❌ Patch coverage is 5.41045% with 507 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
ext/TensorKitEnzymeExt/indexmanipulations.jl 0.00% 140 Missing ⚠️
ext/TensorKitEnzymeExt/tensoroperations.jl 0.00% 86 Missing ⚠️
ext/TensorKitEnzymeExt/linalg.jl 0.00% 82 Missing ⚠️
ext/TensorKitEnzymeExt/vectorinterface.jl 0.00% 75 Missing ⚠️
ext/TensorKitEnzymeExt/planaroperations.jl 0.00% 47 Missing ⚠️
ext/TensorKitEnzymeExt/utility.jl 3.03% 32 Missing ⚠️
ext/TensorKitEnzymeExt/factorizations.jl 0.00% 23 Missing ⚠️
ext/TensorKitEnzymeTestUtilsExt.jl 77.14% 8 Missing ⚠️
src/factorizations/matrixalgebrakit.jl 0.00% 8 Missing ⚠️
src/factorizations/adjoint.jl 0.00% 6 Missing ⚠️
Files with missing lines Coverage Δ
ext/TensorKitEnzymeExt/TensorKitEnzymeExt.jl 100.00% <100.00%> (ø)
src/factorizations/adjoint.jl 65.51% <0.00%> (-7.56%) ⬇️
ext/TensorKitEnzymeTestUtilsExt.jl 77.14% <77.14%> (ø)
src/factorizations/matrixalgebrakit.jl 90.00% <0.00%> (-7.06%) ⬇️
ext/TensorKitEnzymeExt/factorizations.jl 0.00% <0.00%> (ø)
ext/TensorKitEnzymeExt/utility.jl 3.03% <3.03%> (ø)
ext/TensorKitEnzymeExt/planaroperations.jl 0.00% <0.00%> (ø)
ext/TensorKitEnzymeExt/vectorinterface.jl 0.00% <0.00%> (ø)
ext/TensorKitEnzymeExt/linalg.jl 0.00% <0.00%> (ø)
ext/TensorKitEnzymeExt/tensoroperations.jl 0.00% <0.00%> (ø)
... and 1 more

... and 2 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@kshyatt kshyatt force-pushed the ksh/enzyme branch 2 times, most recently from d308db9 to 6ac8e67 Compare March 19, 2026 11:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant