
Commit 12e9941

feat: overload LinearAlgebra.lu (#1297)
* feat: overload LinearAlgebra.lu
* feat: some more overloads
* feat: implement batched LU with Ops.lu
* chore: bump jll
* feat: more overloads
* fix: batch overload
* test: unbatched LU
* fix: batch op implementation
* fix: batch ordering for lu
* fix: ambiguity in 1.10
1 parent a099af7 commit 12e9941
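
In short, `LinearAlgebra.lu` and the corresponding `\` solves now trace through Reactant, including trailing-dimension batches. A minimal usage sketch, mirroring the tests added below (`Reactant.to_rarray` and `@jit` are the standard Reactant entry points):

```julia
using LinearAlgebra, Reactant

solve_with_lu(A, b) = lu(A) \ b  # lowers to Ops.lu plus batched triangular solves

A = rand(Float32, 4, 4, 3, 2)  # trailing dims (3, 2) are batch dims
b = rand(Float32, 4, 3, 2)

A_ra = Reactant.to_rarray(A)
b_ra = Reactant.to_rarray(b)

x = @jit solve_with_lu(A_ra, b_ra)  # one compiled, batched LU solve
```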

File tree

4 files changed: +235 −29 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -87,7 +87,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.9"
-Reactant_jll = "0.0.185"
+Reactant_jll = "0.0.186"
 ScopedValues = "1.3.0"
 Scratch = "1.2"
 Sockets = "1.10"

src/Ops.jl

Lines changed: 32 additions & 28 deletions
@@ -2797,6 +2797,7 @@ end
 
     # First we permute and make sure the batch dims are at the beginning
     batch_dims = Int64[i for i in 1:N if i ∉ dims]
+    batch_shape = [size(A, i) for i in batch_dims]
     permutation = zeros(Int64, N)
     for (i, d) in enumerate(batch_dims)
         permutation[i] = d
@@ -2805,51 +2806,54 @@ end
         permutation[i + length(batch_dims)] = d
     end
 
-    A = Ops.transpose(A, permutation; location)
+    res = only(batch(f, [Ops.transpose(A, permutation; location)], batch_shape; location))
+    if ndims(res) != length(permutation)
+        res = Ops.reshape(
+            res,
+            vcat(collect(Int64, size(res)), ones(Int64, length(permutation) - ndims(res))),
+        )
+    end
+    return Ops.transpose(res, invperm(permutation); location)
+end
 
-    sample_input = fill(T(0), [size(A, i) for i in (length(batch_dims) + 1):N]; location)
-    # TODO: detect and forbid internal mutations
+@noinline function batch(
+    f::F,
+    inputs::Vector{<:TracedRArray},
+    batch_shape::Vector{Int64};
+    location=mlir_stacktrace("batch", @__FILE__, @__LINE__),
+) where {F}
+    sample_inputs = [
+        fill(
+            unwrapped_eltype(input)(0),
+            [size(input, i) for i in (length(batch_shape) + 1):ndims(input)]...,
+        ) for input in inputs
+    ]
     mlir_fn_res = Reactant.TracedUtils.make_mlir_fn(
         f,
-        (sample_input,),
+        (sample_inputs...,),
         (),
         "unbatched_" * string(f),
         false;
         args_in_result=:none,
         do_transpose=false,
     )
-
     @assert !mlir_fn_res.fnwrapped "Currently we don't support batching closures."
 
     func = mlir_fn_res.f
     @assert MLIR.IR.nregions(func) == 1
 
-    result = only(mlir_fn_res.linear_results)
-    batch_shape = [size(A, i) for i in 1:length(batch_dims)]
-
-    if result isa TracedRArray
-        @assert ndims(result) == ndims(sample_input)
-        output_type = MLIR.IR.TensorType(
-            vcat(batch_shape, collect(Int64, size(result))),
-            MLIR.IR.Type(unwrapped_eltype(result)),
-        )
-    elseif result isa TracedRNumber
-        output_type = MLIR.IR.TensorType(
-            batch_shape, MLIR.IR.Type(unwrapped_eltype(result))
-        )
-    else
-        error("Unsupported result type $(typeof(result))")
-    end
-
-    batched_result = batch([A], [output_type], batch_shape; fn=func, location)[1]
-
-    if result isa TracedRNumber
-        batched_result = Ops.reshape(
-            batched_result, vcat(batch_shape, ones(Int64, ndims(sample_input))); location
+    output_types = MLIR.IR.Type[]
+    for result in mlir_fn_res.linear_results
+        push!(
+            output_types,
+            MLIR.IR.TensorType(
+                vcat(batch_shape, collect(Int64, size(result))),
+                MLIR.IR.Type(unwrapped_eltype(result)),
+            ),
         )
     end
 
-    return Ops.transpose(batched_result, invperm(permutation); location)
+    return batch(inputs, output_types, batch_shape; fn=func, location)
 end
 
 @noinline function batch(
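
The substantive change above: `batch` previously accepted a single `TracedRArray` and assumed exactly one result. The single-array method now just permutes the batch dims to the front and forwards to a new overload that takes a `Vector{<:TracedRArray}` plus an explicit `batch_shape`, traces `f` once on zero-filled sample inputs with the batch dims stripped, and emits one output type per linear result. A hedged sketch of calling the new overload from traced code (the shapes and the `solve_one` body are illustrative, not from this diff; the argument order matches the `Ops.batch` call added in src/stdlibs/LinearAlgebra.jl):

```julia
using LinearAlgebra

# Inputs with batch dims leading, e.g. a (3, 2) batch:
# factors :: (3, 2, 4, 4), B :: (3, 2, 4, 5), perm :: (3, 2, 4)
solve_one(F, X, p) = UpperTriangular(F) \ (UnitLowerTriangular(F) \ X[Int64.(p), :])

# f is traced once on zero-filled samples of the unbatched shapes
# (4, 4), (4, 5), and (4,), then applied across the (3, 2) batch.
outs = Ops.batch(solve_one, [factors, B, perm], Int64[3, 2])
res = only(outs)  # (3, 2, 4, 5): one triangular solve per batch element
```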

src/stdlibs/LinearAlgebra.jl

Lines changed: 150 additions & 0 deletions
@@ -539,4 +539,154 @@ function LinearAlgebra.generic_mattridiv!(
     return C
 end
 
+# Supports batched factorization
+abstract type GeneralizedFactorization{T} <: Factorization{T} end
+
+function LinearAlgebra.TransposeFactorization(f::GeneralizedFactorization)
+    return LinearAlgebra.TransposeFactorization{eltype(f),typeof(f)}(f)
+end
+
+function LinearAlgebra.AdjointFactorization(f::GeneralizedFactorization)
+    return LinearAlgebra.AdjointFactorization{eltype(f),typeof(f)}(f)
+end
+
+const GeneralizedTransposeFactorization{T} =
+    LinearAlgebra.TransposeFactorization{T,<:GeneralizedFactorization{T}} where {T}
+const GeneralizedAdjointFactorization{T} =
+    LinearAlgebra.AdjointFactorization{T,<:GeneralizedFactorization{T}} where {T}
+
+# LU Factorization
+struct GeneralizedLU{T,S<:AbstractArray,P<:AbstractArray,I<:Union{AbstractArray,Number}} <:
+       GeneralizedFactorization{T}
+    factors::S
+    ipiv::P
+    perm::P
+    info::I
+end
+
+Base.ndims(lu::GeneralizedLU) = ndims(lu.factors)
+
+function GeneralizedLU(factors::S, ipiv::P, perm::P, info::I) where {S,P,I}
+    @assert ndims(ipiv) == ndims(perm) == ndims(factors) - 1
+    @assert ndims(info) == ndims(factors) - 2
+    return GeneralizedLU{eltype(factors),S,P,I}(factors, ipiv, perm, info)
+end
+
+## allow > 2 dimensions as inputs
+function LinearAlgebra.lu(A::AnyTracedRArray{T,2}, ::RowMaximum; kwargs...) where {T}
+    return lu!(copy(A), RowMaximum(); kwargs...)
+end
+function LinearAlgebra.lu(
+    A::AnyTracedRArray{T,N}, ::RowMaximum=RowMaximum(); kwargs...
+) where {T,N}
+    return lu!(copy(A), RowMaximum(); kwargs...)
+end
+
+function LinearAlgebra.lu!(A::AnyTracedRArray{T,2}, ::RowMaximum; kwargs...) where {T}
+    return _lu_overload(A, RowMaximum(); kwargs...)
+end
+function LinearAlgebra.lu!(A::AnyTracedRArray{T,N}, ::RowMaximum; kwargs...) where {T,N}
+    return _lu_overload(A, RowMaximum(); kwargs...)
+end
+
+function _lu_overload(
+    A::AnyTracedRArray{T,N}, ::RowMaximum; check::Bool=false, allowsingular::Bool=false
+) where {T,N}
+    # TODO: don't ignore the check and allowsingular flags
+    # Batching here is in the last dimensions. `Ops.lu` expects the last dimensions
+    permdims = vcat(Int64[N - 1, N], collect(Int64, 1:(N - 2)))
+    A = Ops.transpose(materialize_traced_array(A), permdims)
+    factors, ipiv, perm, info = Reactant.Ops.lu(A)
+
+    # Permute back to the original dimensions
+    perm_perm = vcat(N - 1, collect(Int64, 1:(N - 2)))
+    factors = Ops.transpose(factors, invperm(permdims))
+    ipiv = Ops.transpose(ipiv, perm_perm)
+    perm = Ops.transpose(perm, perm_perm)
+    return GeneralizedLU(factors, ipiv, perm, info)
+end
+
+function LinearAlgebra.ldiv!(
+    lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,M}
+) where {T,P,I,N,M}
+    @assert N == M + 1
+    ldiv!(lu, reshape(B, size(B, 1), 1, size(B)[2:end]...))
+    return B
+end
+
+function LinearAlgebra.ldiv!(
+    lu::GeneralizedLU{T,<:AbstractArray{T,2},P,I}, B::AbstractArray{T,2}
+) where {T,P,I}
+    B .= _lu_solve_core(lu.factors, B, lu.perm)
+    return B
+end
+
+function LinearAlgebra.ldiv!(
+    lu::GeneralizedLU{T,<:AbstractArray{T,N},P,I}, B::AbstractArray{T,N}
+) where {T,P,I,N}
+    batch_shape = size(lu.factors)[3:end]
+    @assert batch_shape == size(B)[3:end]
+
+    permutation = vcat(collect(Int64, 3:N), 1, 2)
+
+    factors = Ops.transpose(materialize_traced_array(lu.factors), permutation)
+    B_permuted = Ops.transpose(materialize_traced_array(B), permutation)
+    perm = Ops.transpose(
+        materialize_traced_array(lu.perm), vcat(collect(Int64, 2:(N - 1)), 1)
+    )
+
+    res = Ops.transpose(
+        only(
+            Ops.batch(
+                _lu_solve_core, [factors, B_permuted, perm], collect(Int64, batch_shape)
+            ),
+        ),
+        invperm(permutation),
+    )
+    B .= res
+    return B
+end
+
+for f_wrapper in (LinearAlgebra.TransposeFactorization, LinearAlgebra.AdjointFactorization),
+    aType in (:AbstractVecOrMat, :AbstractArray)
+
+    @eval function LinearAlgebra.ldiv!(lu::$(f_wrapper){<:Any,<:GeneralizedLU}, B::$aType)
+        # TODO: implement this
+        error("`$(f_wrapper)` is not supported yet for LU.")
+        return nothing
+    end
+end
+
+function _lu_solve_core(factors::AbstractMatrix, B::AbstractMatrix, perm::AbstractVector)
+    permuted_B = B[Int64.(perm), :]
+    return UpperTriangular(factors) \ (UnitLowerTriangular(factors) \ permuted_B)
+end
+
+# Overload \ to support batched factorization
+for T in (
+        :GeneralizedFactorization,
+        :GeneralizedTransposeFactorization,
+        :GeneralizedAdjointFactorization,
+    ),
+    aType in (:AbstractVecOrMat, :AbstractArray)
+
+    @eval Base.:(\)(F::$T, B::$aType) = _overloaded_backslash(F, B)
+end
+
+function _overloaded_backslash(F::GeneralizedFactorization, B::AbstractArray)
+    return ldiv!(
+        F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
+    )
+end
+
+function _overloaded_backslash(F::GeneralizedTransposeFactorization, B::AbstractArray)
+    return conj!(adjoint(F.parent) \ conj.(B))
+end
+
+function _overloaded_backslash(F::GeneralizedAdjointFactorization, B::AbstractArray)
+    return ldiv!(
+        F, LinearAlgebra.copy_similar(B, typeof(oneunit(eltype(F)) \ oneunit(eltype(B))))
+    )
+end
+
 end
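
For orientation, here is the shape convention `GeneralizedLU` encodes, as implied by the constructor asserts (`ndims(ipiv) == ndims(perm) == ndims(factors) - 1`, `ndims(info) == ndims(factors) - 2`). The concrete sizes below are a hypothetical illustration, not output captured from this PR:

```julia
using LinearAlgebra, Reactant

# A hypothetical (3, 2) batch of 4×4 systems, factored under tracing/@jit:
F = lu(Reactant.to_rarray(rand(Float32, 4, 4, 3, 2)))

size(F.factors)  # (4, 4, 3, 2): packed L/U factors, batch dims trailing
size(F.ipiv)     # (4, 3, 2):    pivot indices, one per row, per batch element
size(F.perm)     # (4, 3, 2):    materialized row permutation consumed by ldiv!
size(F.info)     # (3, 2):       one success/singularity flag per system
```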

test/integration/linear_algebra.jl

Lines changed: 52 additions & 0 deletions
@@ -319,3 +319,55 @@ end
         end
     end
 end
+
+solve_with_lu(A, b) = lu(A) \ b
+function solve_with_lu_batched(A::AbstractArray{T,N}, B::AbstractArray{T,N}) where {T,N}
+    A2 = reshape(A, size(A, 1), size(A, 2), prod(size(A)[3:end]))
+    B2 = reshape(B, size(B, 1), size(B, 2), prod(size(B)[3:end]))
+    @assert size(A2, 3) == size(B2, 3)
+    return reshape(
+        stack(lu(view(A2, :, :, i)) \ view(B2, :, :, i) for i in axes(A2, 3)),
+        size(A2, 1),
+        size(B2, 2),
+        size(A)[3:end]...,
+    )
+end
+function solve_with_lu_batched(A::AbstractArray{T,N}, b::AbstractArray{T,M}) where {T,N,M}
+    @assert N == M + 1
+    B = reshape(b, size(b, 1), 1, size(b)[2:end]...)
+    return dropdims(solve_with_lu_batched(A, B); dims=2)
+end
+
+@testset "LU Factorization" begin
+    @testset "Un-batched" begin
+        @testset for T in (Float32, Float64, ComplexF32, ComplexF64)
+            A = rand(T, 4, 4)
+            A_ra = Reactant.to_rarray(A)
+
+            b = rand(T, 4)
+            b_ra = Reactant.to_rarray(b)
+
+            B = rand(T, 4, 3)
+            B_ra = Reactant.to_rarray(B)
+
+            @test @jit(solve_with_lu(A_ra, b_ra)) ≈ solve_with_lu(A, b)
+            @test @jit(solve_with_lu(A_ra, B_ra)) ≈ solve_with_lu(A, B)
+        end
+    end
+
+    @testset "Batched" begin
+        @testset for T in (Float32, Float64, ComplexF32, ComplexF64)
+            A = rand(T, 4, 4, 3, 2)
+            A_ra = Reactant.to_rarray(A)
+
+            b = rand(T, 4, 3, 2)
+            b_ra = Reactant.to_rarray(b)
+
+            B = rand(T, 4, 5, 3, 2)
+            B_ra = Reactant.to_rarray(B)
+
+            @test @jit(solve_with_lu(A_ra, b_ra)) ≈ solve_with_lu_batched(A, b)
+            @test @jit(solve_with_lu(A_ra, B_ra)) ≈ solve_with_lu_batched(A, B)
+        end
+    end
+end
