Skip to content

Commit f5b3cc2

Browse files
authored
feat: triangular_solve lowering (#1291)
* feat: triangular_solve lowering * feat: more overloads * feat: overload generic_trimatdiv! * feat: add rdiv! support * test: add tests
1 parent 7d3d6a0 commit f5b3cc2

File tree

3 files changed

+153
-0
lines changed

3 files changed

+153
-0
lines changed

src/Ops.jl

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2846,4 +2846,66 @@ end
28462846
]
28472847
end
28482848

2849+
function triangular_solve(
2850+
a::TracedRArray{T,N},
2851+
b::TracedRArray{T,M};
2852+
left_side::Bool,
2853+
location=mlir_stacktrace("triangular_solve", @__FILE__, @__LINE__),
2854+
kwargs...,
2855+
) where {T,N,M}
2856+
@assert M == N - 1
2857+
2858+
if left_side
2859+
b = reshape(b, size(b)..., 1)
2860+
else
2861+
b = reshape(b, size(b)[1:(M - 1)]..., 1, size(b, M))
2862+
end
2863+
2864+
return dropdims(
2865+
triangular_solve(a, b; location, left_side, kwargs...); dims=(N - 1 + left_side)
2866+
)
2867+
end
2868+
2869+
function triangular_solve(
2870+
a::TracedRArray{T,N},
2871+
b::TracedRArray{T,N};
2872+
left_side::Bool,
2873+
lower::Bool,
2874+
transpose_a::Char,
2875+
unit_diagonal::Bool,
2876+
location=mlir_stacktrace("triangular_solve", @__FILE__, @__LINE__),
2877+
) where {T,N}
2878+
@assert N >= 2
2879+
@assert size(a, N - 1) == size(a, N) == size(b, N - left_side)
2880+
@assert size(a)[1:(N - 2)] == size(b)[1:(N - 2)] "a and b must have the same leading \
2881+
dimensions"
2882+
2883+
@assert transpose_a in ('N', 'T', 'C') "transpose_a must be one of 'N', 'T', or 'C'"
2884+
transpose_attr = MLIR.API.stablehloTransposeAttrGet(
2885+
MLIR.IR.context(),
2886+
if transpose_a == 'N'
2887+
"NO_TRANSPOSE"
2888+
elseif transpose_a == 'T'
2889+
"TRANSPOSE"
2890+
else
2891+
"ADJOINT"
2892+
end,
2893+
)
2894+
2895+
res = MLIR.IR.result(
2896+
MLIR.Dialects.stablehlo.triangular_solve(
2897+
a.mlir_data,
2898+
b.mlir_data;
2899+
left_side=left_side,
2900+
lower=lower,
2901+
transpose_a=transpose_attr,
2902+
unit_diagonal=unit_diagonal,
2903+
location,
2904+
),
2905+
1,
2906+
)
2907+
2908+
return TracedRArray{T,N}((), res, size(res))
2909+
end
2910+
28492911
end # module Ops

src/stdlibs/LinearAlgebra.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,4 +469,55 @@ function LinearAlgebra.dot(x::AnyTracedRVector, y::AnyTracedRVector)
469469
return TracedRNumber{unwrapped_eltype(res)}((), res.mlir_data)
470470
end
471471

472+
# ldiv & rdiv interfaces
473+
tfun_to_char(::typeof(identity)) = 'N'
474+
tfun_to_char(::typeof(transpose)) = 'T'
475+
tfun_to_char(::typeof(adjoint)) = 'C'
476+
477+
function LinearAlgebra.generic_trimatdiv!(
478+
C::AbstractVecOrMat{TracedRNumber{T}},
479+
uploc,
480+
isunitc,
481+
tfun::Function,
482+
A::AbstractMatrix,
483+
B::AbstractVecOrMat,
484+
) where {T}
485+
@assert uploc in ('L', 'U')
486+
@assert isunitc in ('N', 'U')
487+
488+
res = Ops.triangular_solve(
489+
TracedUtils.promote_to(TracedRArray{T,2}, materialize_traced_array(A)),
490+
TracedUtils.promote_to(TracedRArray{T,ndims(B)}, materialize_traced_array(B));
491+
left_side=true,
492+
lower=(uploc == 'L'),
493+
transpose_a=tfun_to_char(tfun),
494+
unit_diagonal=(isunitc == 'U'),
495+
)
496+
set_mlir_data!(C, get_mlir_data(res))
497+
return C
498+
end
499+
500+
function LinearAlgebra.generic_mattridiv!(
501+
C::AbstractMatrix{TracedRNumber{T}},
502+
uploc,
503+
isunitc,
504+
tfun::Function,
505+
A::AbstractMatrix,
506+
B::AbstractMatrix,
507+
) where {T}
508+
@assert uploc in ('L', 'U')
509+
@assert isunitc in ('N', 'U')
510+
511+
res = Ops.triangular_solve(
512+
TracedUtils.promote_to(TracedRArray{T,2}, materialize_traced_array(B)),
513+
TracedUtils.promote_to(TracedRArray{T,2}, materialize_traced_array(A));
514+
left_side=false,
515+
lower=(uploc == 'L'),
516+
transpose_a=tfun_to_char(tfun),
517+
unit_diagonal=(isunitc == 'U'),
518+
)
519+
set_mlir_data!(C, get_mlir_data(res))
520+
return C
521+
end
522+
472523
end

test/integration/linear_algebra.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,3 +279,43 @@ end
279279

280280
@test @jit(dot(x_ra, y_ra)) dot(x, y)
281281
end
282+
283+
@testset "Triangular ldiv and rdiv" begin
284+
fn1(A, b) = A \ b
285+
fn2(A, b) = A' \ b
286+
fn3(A, b) = transpose(A) \ b
287+
288+
fn4(A, B) = B / A
289+
fn5(A, B) = B / A'
290+
fn6(A, B) = B / transpose(A)
291+
292+
@testset for T in (Float32, Float64, ComplexF32, ComplexF64)
293+
A = rand(T, 6, 6)
294+
B = rand(T, 6, 6)
295+
b = rand(T, 6)
296+
b_ra = Reactant.to_rarray(b)
297+
B_ra = Reactant.to_rarray(B)
298+
299+
@testset for wT in (
300+
UnitLowerTriangular, UnitUpperTriangular, LowerTriangular, UpperTriangular
301+
)
302+
A_wrapped = wT(A)
303+
A_ra = Reactant.to_rarray(A_wrapped)
304+
305+
@testset "no_tranpose" begin
306+
@test @jit(fn1(A_ra, b_ra)) fn1(A_wrapped, b)
307+
@test @jit(fn4(A_ra, B_ra)) fn4(A_wrapped, B)
308+
end
309+
310+
@testset "adjoint" begin
311+
@test @jit(fn2(A_ra, b_ra)) fn2(A_wrapped, b)
312+
@test @jit(fn5(A_ra, B_ra)) fn5(A_wrapped, B)
313+
end
314+
315+
@testset "transpose" begin
316+
@test @jit(fn3(A_ra, b_ra)) fn3(A_wrapped, b)
317+
@test @jit(fn6(A_ra, B_ra)) fn6(A_wrapped, B)
318+
end
319+
end
320+
end
321+
end

0 commit comments

Comments
 (0)