From 77b05be0d6ee927272f043a5bfe590a1ef7cd251 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Fri, 10 Oct 2025 18:51:30 +0200 Subject: [PATCH 1/3] Don't overlay mul! for sparse arrays --- ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl | 2 ++ src/Overlay.jl | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl b/ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl index b782ad62ad..4344371023 100644 --- a/ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl +++ b/ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl @@ -7,4 +7,6 @@ using SparseArrays: include("Errors.jl") include("ReadOnly.jl") +Reactant.use_overlayed_version(::AbstractSparseArray) = false + end diff --git a/src/Overlay.jl b/src/Overlay.jl index 143e8a4a45..e6d07f7e77 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -114,8 +114,8 @@ end ## `mul!` goes through too many layers of abstractions and we aren't able to overload ## without specializing on every possible combination of types for (cT, aT, bT) in ( - (:AbstractVector, :AbstractMatrix, :AbstractVector), - (:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat), + (:AbstractVector, :DenseMatrix, :AbstractVector), + (:AbstractMatrix, :DenseMatrix, :AbstractVecOrMat), ) @eval begin @reactant_overlay @noinline function LinearAlgebra.mul!( From 7cc2786e2261b829ae87295dea89345e023a56f9 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Sun, 12 Oct 2025 13:34:52 +0200 Subject: [PATCH 2/3] Relav method definition and add check for matrix only --- src/Overlay.jl | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Overlay.jl b/src/Overlay.jl index e6d07f7e77..acc0474810 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -114,17 +114,18 @@ end ## `mul!` goes through too many layers of abstractions and we aren't able to overload ## without specializing on every possible combination of types for (cT, aT, bT) in ( - (:AbstractVector, :DenseMatrix, :AbstractVector), - (:AbstractMatrix, :DenseMatrix, :AbstractVecOrMat), + (:AbstractVector, :AbstractMatrix, :AbstractVector), + (:AbstractMatrix, :AbstractMatrix, :AbstractVecOrMat), ) @eval begin @reactant_overlay @noinline function LinearAlgebra.mul!( C::$cT, A::$aT, B::$bT, α::Number, β::Number ) - A, B = aos_to_soa(A), aos_to_soa(B) + A2, B2 = aos_to_soa(A), aos_to_soa(B) C2 = aos_to_soa(C) - if use_overlayed_version((C2, A, B)) - TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β) + # A2 can also be a SparseMatrix, which should be handled by its own methods + if use_overlayed_version(A2) && use_overlayed_version((C2, A2, B2)) + TracedLinearAlgebra.overloaded_mul!(C2, A2, B2, α, β) if C2 !== C C .= C2 end From 563c4b86fe204efb416a83a2d4db6fd61eccadf4 Mon Sep 17 00:00:00 2001 From: Alberto Mercurio Date: Wed, 22 Oct 2025 13:55:57 +0200 Subject: [PATCH 3/3] Tracing check within KernelAbstractions --- ext/ReactantKernelAbstractionsExt.jl | 9 +++++++++ src/Overlay.jl | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/ext/ReactantKernelAbstractionsExt.jl b/ext/ReactantKernelAbstractionsExt.jl index ee13c3bb3f..a2e9fab614 100644 --- a/ext/ReactantKernelAbstractionsExt.jl +++ b/ext/ReactantKernelAbstractionsExt.jl @@ -1,6 +1,7 @@ module ReactantKernelAbstractionsExt using Reactant: Reactant +using ReactantCore: ReactantCore using Adapt: Adapt using KernelAbstractions: KernelAbstractions @@ -101,6 +102,14 @@ function tokw(ndrange, workgroupsize, obj, args...) end function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing) + # If we're already inside a compilation/tracing context, or if any arguments are traced, + # we should trace through this kernel call instead of trying to compile it again. + if Reactant.within_compile() || any(ReactantCore.is_traced, args) + return Reactant.call_with_reactant( + Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args... + ) + end + if Reactant.precompiling() Reactant.@code_hlo optimize = false tokw(ndrange, workgroupsize, obj, args...) else diff --git a/src/Overlay.jl b/src/Overlay.jl index acc0474810..13fa9a4880 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -133,7 +133,7 @@ for (cT, aT, bT) in ( # Inference barrier is required when calling function recursively within # overload. This is required since otherwise type inference will think this # is a recursive edge rather than a call to the base method - Base.inferencebarrier(LinearAlgebra.mul!)(C, A, B, α, β) + Base.inferencebarrier(LinearAlgebra.mul!)(C2, A2, B2, α, β) end return C end