diff --git a/ext/ReactantKernelAbstractionsExt.jl b/ext/ReactantKernelAbstractionsExt.jl index ee13c3bb3f..a2e9fab614 100644 --- a/ext/ReactantKernelAbstractionsExt.jl +++ b/ext/ReactantKernelAbstractionsExt.jl @@ -1,6 +1,7 @@ module ReactantKernelAbstractionsExt using Reactant: Reactant +using ReactantCore: ReactantCore using Adapt: Adapt using KernelAbstractions: KernelAbstractions @@ -101,6 +102,14 @@ function tokw(ndrange, workgroupsize, obj, args...) end function (obj::KA.Kernel{ReactantBackend})(args...; ndrange=nothing, workgroupsize=nothing) + # If we're already inside a compilation/tracing context, or if any arguments are traced, + # we should trace through this kernel call instead of trying to compile it again. + if Reactant.within_compile() || any(ReactantCore.is_traced, args) + return Reactant.call_with_reactant( + Reactant.ka_with_reactant, ndrange, workgroupsize, obj, args... + ) + end + if Reactant.precompiling() Reactant.@code_hlo optimize = false tokw(ndrange, workgroupsize, obj, args...) 
else diff --git a/ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl b/ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl index b782ad62ad..4344371023 100644 --- a/ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl +++ b/ext/ReactantSparseArraysExt/ReactantSparseArraysExt.jl @@ -7,4 +7,6 @@ using SparseArrays: include("Errors.jl") include("ReadOnly.jl") +Reactant.use_overlayed_version(::AbstractSparseArray) = false + end diff --git a/src/Overlay.jl b/src/Overlay.jl index 143e8a4a45..13fa9a4880 100644 --- a/src/Overlay.jl +++ b/src/Overlay.jl @@ -121,10 +121,11 @@ for (cT, aT, bT) in ( @reactant_overlay @noinline function LinearAlgebra.mul!( C::$cT, A::$aT, B::$bT, α::Number, β::Number ) - A, B = aos_to_soa(A), aos_to_soa(B) + A2, B2 = aos_to_soa(A), aos_to_soa(B) C2 = aos_to_soa(C) - if use_overlayed_version((C2, A, B)) - TracedLinearAlgebra.overloaded_mul!(C2, A, B, α, β) + # A2 can also be a SparseMatrix, which should be handled by its own methods + if use_overlayed_version(A2) && use_overlayed_version((C2, A2, B2)) + TracedLinearAlgebra.overloaded_mul!(C2, A2, B2, α, β) if C2 !== C C .= C2 end @@ -132,7 +133,12 @@ for (cT, aT, bT) in ( # Inference barrier is required when calling function recursively within # overload. This is required since otherwise type inference will think this # is a recursive edge rather than a call to the base method - Base.inferencebarrier(LinearAlgebra.mul!)(C, A, B, α, β) + Base.inferencebarrier(LinearAlgebra.mul!)(C2, A2, B2, α, β) + # The fallback wrote into C2; copy it back (as the overlay branch does) + # when aos_to_soa returned a distinct array, so the returned C is not stale + if C2 !== C + C .= C2 + end end return C end