diff --git a/src/Ops.jl b/src/Ops.jl
index 39c9790dea..20c9640754 100644
--- a/src/Ops.jl
+++ b/src/Ops.jl
@@ -934,6 +934,45 @@ end
     return TracedRArray{T,N}((), MLIR.IR.result(conv), result_size)
 end
 
+"""
+    lapack_symm(A, B, C, alpha, beta; side, uplo, location=...)
+
+Emit an `enzymexla.lapack_symm` operation on `A`, `B`, `C`, `alpha` and `beta`, and return
+its result as a `TracedRArray` with the size of `C`.
+
+`side` is encoded as `1` for `:L` and `0` otherwise; `uplo` as `1` for `:U` and `0`
+otherwise. NOTE(review): confirm these integer codes against the enum definitions of the
+EnzymeXLA dialect.
+"""
+@noinline function lapack_symm(
+    A::TracedRArray{T},
+    B::TracedRArray{T},
+    C::TracedRArray{T},
+    alpha::TracedRNumber{T},
+    beta::TracedRNumber{T};
+    side::Symbol,
+    uplo::Symbol,
+    location=mlir_stacktrace("lapack_symm", @__FILE__, @__LINE__),
+) where {T}
+    ctx = MLIR.IR.context()
+    ressize = size(C)
+    res = MLIR.IR.result(
+        enzymexla.lapack_symm(
+            A.mlir_data,
+            B.mlir_data,
+            C.mlir_data,
+            alpha.mlir_data,
+            beta.mlir_data;
+            output=mlir_type(TracedRArray{eltype(C),length(ressize)}, ressize),
+            side=enzymexlaLapackSideAttrGet(ctx, side == :L ? 1 : 0),
+            uplo=enzymexlaLapackUploAttrGet(ctx, uplo == :U ? 1 : 0),
+            location,
+        ),
+    )
+    # Wrap the raw MLIR value so callers receive a traced array, like the op above.
+    return TracedRArray{T,length(ressize)}((), res, ressize)
+end
+
 Base.@nospecializeinfer @noinline function dot_general(
     @nospecialize(lhs::TracedRArray{T1}),
     @nospecialize(rhs::TracedRArray{T2});
diff --git a/src/stdlibs/LinearAlgebra.jl b/src/stdlibs/LinearAlgebra.jl
index d4820c9966..43fd681197 100644
--- a/src/stdlibs/LinearAlgebra.jl
+++ b/src/stdlibs/LinearAlgebra.jl
@@ -273,6 +273,44 @@ function overloaded_mul!(
     return C
 end
 
+# Multiply with a `Symmetric` left operand: emits `lapack_symm` with `side=:L` and stores
+# the op's result back into `C`. Presumably this computes `C = A*B*α + C*β`, as for the
+# five-argument `mul!` (TODO confirm against the EnzymeXLA op definition).
+function overloaded_mul!(
+    @nospecialize(C::TracedRArray{T,2} where {T}),
+    @nospecialize(A::Symmetric),
+    @nospecialize(B::AbstractMatrix),
+    α::Number=true,
+    β::Number=false,
+)
+    # `Symmetric` only guarantees the triangle named by `A.uplo`; record it before
+    # unwrapping so the op reads the valid half of the parent array.
+    uplo = A.uplo == 'U' ? :U : :L
+
+    # Promote to traced arrays
+    A = call_with_reactant(Reactant.promote_to, TracedRArray, A.data)
+    B = call_with_reactant(Reactant.promote_to, TracedRArray, B)
+
+    # Dimension checks
+    if size(A, 2) != size(B, 1) || size(C) != (size(A, 1), size(B, 2))
+        throw(DimensionMismatch("C=$(size(C)), A=$(size(A)), B=$(size(B))"))
+    end
+
+    T = Reactant.unwrapped_eltype(C)
+    tmp = @opcall lapack_symm(
+        T.(materialize_traced_array(A)),
+        T.(materialize_traced_array(B)),
+        T.(materialize_traced_array(C)),
+        Reactant.promote_to(TracedRNumber{T}, α),
+        Reactant.promote_to(TracedRNumber{T}, β),
+        side=:L,
+        uplo=uplo,
+    )
+
+    set_mlir_data!(C, get_mlir_data(tmp)) # TODO remove later, handling in place ops are weird
+    return C
+end
+
 function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
     iota_1 = @opcall iota(Int64, [size(X)...]; iota_dimension=1)
     iota_2 = @opcall subtract(
diff --git a/test/integration/linear_algebra.jl b/test/integration/linear_algebra.jl
index 5790bfc928..e6bc28913b 100644
--- a/test/integration/linear_algebra.jl
+++ b/test/integration/linear_algebra.jl
@@ -432,3 +432,22 @@ end
             1e-2
     end
 end
+
+@testset "Symmetric Multiplication" begin
+    @testset "$name" for (name, T) in (("F32", Float32), ("F64", Float64))
+        A = Symmetric(rand(T, 10, 10))
+        B = rand(T, 10, 10)
+        C = rand(T, 10, 10)
+        A_ra = Reactant.to_rarray(A)
+        B_ra = Reactant.to_rarray(B)
+        C_ra = Reactant.to_rarray(C)
+
+        alpha = rand(T)
+        beta = rand(T)
+        expected = mul!(copy(C), A, B, alpha, beta)
+
+        # Compare against the host result; `@test` needs a Bool, not an HLO module.
+        @test @jit(A_ra * B_ra) ≈ A * B
+        @test @jit(mul!(C_ra, A_ra, B_ra, alpha, beta)) ≈ expected
+    end
+end