Skip to content

Commit 4531bd8

Browse files
authored
feat: factorizations (#1234)
* fix: mlir regeneration * feat: factorizations * chore: fmt * fix: symname * fix: flags * feat: lapack integration working 🎉 * feat: Ops.lu * feat: add triangular_solve op * chore: rename * test: add tests for lu fact * chore: fmt * test: fix * chore: bump jll
1 parent 4025e2a commit 4531bd8

File tree

5 files changed

+115
-1
lines changed

5 files changed

+115
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ PythonCall = "0.9"
8787
Random = "1.10"
8888
Random123 = "1.7"
8989
ReactantCore = "0.1.9"
90-
Reactant_jll = "0.0.182"
90+
Reactant_jll = "0.0.183"
9191
ScopedValues = "1.3.0"
9292
Scratch = "1.2"
9393
Sockets = "1.10"

src/Compiler.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module Compiler
22

33
using Reactant_jll
44
using Libdl: dlsym
5+
using LinearAlgebra: BLAS
56

67
import ..Reactant:
78
Reactant,
@@ -1273,6 +1274,10 @@ function compile_mlir!(
12731274
"canonicalize"
12741275
end
12751276

1277+
blas_int_width = sizeof(BLAS.BlasInt) * 8
1278+
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
1279+
blas_int_width=$blas_int_width}"
1280+
12761281
if optimize === :all
12771282
run_pass_pipeline!(
12781283
mod,
@@ -1291,6 +1296,7 @@ function compile_mlir!(
12911296
"remove-unnecessary-enzyme-ops",
12921297
"enzyme-simplify-math",
12931298
opt_passes2,
1299+
lower_enzymexla_linalg_pass,
12941300
jit,
12951301
]
12961302
else
@@ -1307,6 +1313,7 @@ function compile_mlir!(
13071313
opt_passes2,
13081314
kern,
13091315
raise_passes,
1316+
lower_enzymexla_linalg_pass,
13101317
jit,
13111318
]
13121319
end,
@@ -1453,6 +1460,7 @@ function compile_mlir!(
14531460
"remove-unnecessary-enzyme-ops",
14541461
"enzyme-simplify-math",
14551462
opt_passes2,
1463+
lower_enzymexla_linalg_pass,
14561464
jit,
14571465
]
14581466
else
@@ -1466,6 +1474,7 @@ function compile_mlir!(
14661474
opt_passes2,
14671475
kern,
14681476
raise_passes,
1477+
lower_enzymexla_linalg_pass,
14691478
jit,
14701479
]
14711480
end,
@@ -1487,6 +1496,7 @@ function compile_mlir!(
14871496
opt_passes2,
14881497
enzyme_pass,
14891498
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
1499+
lower_enzymexla_linalg_pass,
14901500
jit,
14911501
]
14921502
else
@@ -1499,6 +1509,7 @@ function compile_mlir!(
14991509
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
15001510
kern,
15011511
raise_passes,
1512+
lower_enzymexla_linalg_pass,
15021513
jit,
15031514
]
15041515
end,

src/Ops.jl

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2930,4 +2930,48 @@ function triangular_solve(
29302930
return TracedRArray{T,N}((), res, size(res))
29312931
end
29322932

"""
    lu(
        x::TracedRArray{T},
        ::Type{pT}=Int32;
        location=mlir_stacktrace("lu", @__FILE__, @__LINE__)
    ) where {T,pT}

Compute the row maximum pivoted LU factorization of `x` and return the tuple
`(LU, ipiv, permutation, info)`.

The trailing two dimensions of `x` (sizes `m` and `n`) are the matrix dimensions; all
leading dimensions are treated as batch dimensions. `x` must have at least two dimensions.

# Arguments
- `x`: the traced array to factorize.
- `pT`: element type used for the `ipiv`, `permutation` and `info` results
  (defaults to `Int32`).

# Keywords
- `location`: MLIR location attached to the emitted `enzymexla.linalg_lu` op.

# Returns
- `LU`: first result of the op, a `TracedRArray{T}` with the same size as `x`.
- `ipiv`: second result, element type `pT`, size `(batch..., min(m, n))`.
- `permutation`: third result, element type `pT`, size `(batch..., m)`.
- `info`: fourth result, element type `pT`; a `TracedRNumber` when `x` is a single matrix,
  otherwise a `TracedRArray` of size `(batch...)`.
"""
@noinline function lu(
    x::TracedRArray{T},
    ::Type{pT}=Int32;
    location=mlir_stacktrace("lu", @__FILE__, @__LINE__),
) where {T,pT}
    @assert ndims(x) >= 2

    nd = ndims(x)
    nrows, ncols = size(x, nd - 1), size(x, nd)

    # Everything in front of the trailing (nrows, ncols) matrix is a batch dimension.
    full_dims = collect(Int64, size(x))
    batch_dims = full_dims[1:(end - 2)]
    pivot_dims = vcat(batch_dims, min(nrows, ncols))
    perm_dims = vcat(batch_dims, nrows)

    # Small helper so each result type of the op is spelled the same way.
    tensor_of(dims, elty) = MLIR.IR.TensorType(dims, MLIR.IR.Type(elty))

    lu_op = MLIR.Dialects.enzymexla.linalg_lu(
        x.mlir_data;
        output=tensor_of(full_dims, unwrapped_eltype(T)),
        pivots=tensor_of(pivot_dims, pT),
        permutation=tensor_of(perm_dims, pT),
        info=tensor_of(batch_dims, pT),
        location,
    )

    factors = TracedRArray{T,nd}((), MLIR.IR.result(lu_op, 1), size(x))
    ipiv = TracedRArray{pT,nd - 1}((), MLIR.IR.result(lu_op, 2), pivot_dims)
    perm = TracedRArray{pT,nd - 1}((), MLIR.IR.result(lu_op, 3), perm_dims)

    # With no batch dimensions the info result is rank-0, so wrap it as a traced scalar.
    info_value = MLIR.IR.result(lu_op, 4)
    info = if nd == 2
        TracedRNumber{pT}((), info_value)
    else
        TracedRArray{pT,nd - 2}((), info_value, batch_dims)
    end

    return (factors, ipiv, perm, info)
end
2976+
29332977
end # module Ops

src/stdlibs/LinearAlgebra.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,25 @@ using ReactantCore: materialize_traced_array
1818
using ..TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!
1919

2020
using LinearAlgebra
21+
using Libdl: Libdl
22+
23+
# Module initializer: opens the library named by `LinearAlgebra.BLAS.libblas`, looks up
# the `getrf_` routine for each of the `s`, `d`, `c` and `z` prefixes, and registers every
# resolved address with `EnzymeJaXMapSymbol` under the matching
# `enzymexla_lapack_*getrf_` name.
function __init__()
    blas_handle = Libdl.dlopen(LinearAlgebra.BLAS.libblas)

    # (symbol to look up in the opened library, name it is registered under).
    # NOTE(review): `@blasfunc` presumably adjusts the symbol name for the active BLAS
    # build (e.g. an ILP64 suffix) -- confirm against the LinearAlgebra sources.
    symbol_table = (
        (LinearAlgebra.BLAS.@blasfunc(sgetrf_), :enzymexla_lapack_sgetrf_),
        (LinearAlgebra.BLAS.@blasfunc(dgetrf_), :enzymexla_lapack_dgetrf_),
        (LinearAlgebra.BLAS.@blasfunc(cgetrf_), :enzymexla_lapack_cgetrf_),
        (LinearAlgebra.BLAS.@blasfunc(zgetrf_), :enzymexla_lapack_zgetrf_),
    )

    for (lapack_name, registered_name) in symbol_table
        fptr = Libdl.dlsym(blas_handle, lapack_name)
        @ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(
            registered_name::Cstring, fptr::Ptr{Cvoid}
        )::Cvoid
    end

    return nothing
end
2140

2241
# Various Wrapper Arrays defined in LinearAlgebra
2342
function ReactantCore.materialize_traced_array(

test/ops.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,3 +1180,43 @@ end
11801180
@test Array(y_ra) == ones(Float32, 2, 3)
11811181
end
11821182
end
1183+
1184+
# Batched reconstruction: the first two dimensions of `lu_res` index the batch, and each
# trailing matrix slice `lu_res[a, b, :, :]` is passed to the matrix method of
# `recon_from_lu`. The results are written into a freshly allocated array of the same
# shape as `lu_res`.
function recon_from_lu(lu_res::AbstractArray{T,4}) where {T}
    rebuilt = similar(lu_res)
    for a in 1:size(lu_res, 1)
        for b in 1:size(lu_res, 2)
            rebuilt[a, b, :, :] .= recon_from_lu(lu_res[a, b, :, :])
        end
    end
    return rebuilt
end
1191+
1192+
# Reorder the rows (third dimension) of every trailing matrix slice of `x`.
# `perm[a, b, :]` holds the row order applied to the slice `x[a, b, :, :]`; the result is
# a new array with the same shape as `x`.
function apply_permutation(x::AbstractArray{T,4}, perm) where {T}
    permuted = similar(x)
    for a in 1:size(x, 1)
        for b in 1:size(x, 2)
            row_order = perm[a, b, :]
            permuted[a, b, :, :] .= x[a, b, row_order, :]
        end
    end
    return permuted
end
1199+
1200+
# Rebuild `L * U` from a single packed LU matrix: the strictly lower part of `lu_res`
# (with an implicit unit diagonal) is `L`, and the upper triangle including the diagonal
# is `U`.
function recon_from_lu(lu_res::AbstractMatrix)
    lower = UnitLowerTriangular(lu_res)
    upper = UpperTriangular(lu_res)
    return lower * upper
end
1203+
1204+
# Checks `Ops.lu` by rebuilding `L * U` from the packed factors and comparing it with the
# input whose rows have been reordered by the returned permutation.
@testset "lu factorization" begin
    @testset "unbatched" begin
        x_ra = Reactant.to_rarray(randn(6, 6))
        lu_ra, ipiv, perm, info = @jit Ops.lu(x_ra)

        # Floating-point reconstruction, so compare approximately rather than exactly.
        @test @jit(recon_from_lu(lu_ra)) ≈ @jit(getindex(x_ra, perm, :))
    end

    @testset "batched" begin
        x_ra = Reactant.to_rarray(randn(4, 3, 6, 6))
        lu_ra, ipiv, perm, info = @jit Ops.lu(x_ra)

        # The two leading dimensions are batch dimensions and carry through to every output.
        @test size(lu_ra) == (4, 3, 6, 6)
        @test size(ipiv) == (4, 3, 6)
        @test size(perm) == (4, 3, 6)
        @test size(info) == (4, 3)

        # Each batch entry is permuted with its own row order before the comparison.
        @test @jit(recon_from_lu(lu_ra)) ≈ @jit(apply_permutation(x_ra, perm))
    end
end

0 commit comments

Comments
 (0)