Skip to content

Commit 500d12f

Browse files
authored
Merge branch 'main' into ap/scalar
2 parents 944dca8 + f2c0e8a commit 500d12f

File tree

2 files changed

+152
-44
lines changed

2 files changed

+152
-44
lines changed

src/TracedRArray.jl

Lines changed: 85 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -464,36 +464,6 @@ function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
464464
return A
465465
end
466466

467-
function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D}
468-
@assert D isa Integer "Support for non-integer dimensions is not implemented yet."
469-
470-
# MLIR expects the dimension `D` to be ≤ the rank of the input tensors
471-
A = maybe_expand_dims(A, dims)
472-
Bs = maybe_expand_dims.(Bs, (dims,))
473-
474-
catdims = Base.dims2cat(dims)
475-
shape = Base.cat_size_shape(catdims, A, Bs...)
476-
RT = Base.promote_eltype(A, Bs...)
477-
Res = TracedRArray{RT,length(shape)}(
478-
(),
479-
MLIR.IR.result(
480-
MLIR.Dialects.stablehlo.concatenate(
481-
[A.mlir_data, [B.mlir_data for B in Bs]...];
482-
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
483-
dimension=D - 1, # stablehlo expects this to be zero-indexed
484-
),
485-
1,
486-
),
487-
shape,
488-
)
489-
return Res
490-
end
491-
492-
function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D}
493-
D N && return x
494-
return reshape(x, ntuple(i -> i N ? size(x, i) : 1, Val(D)))
495-
end
496-
497467
struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
498468

499469
AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}()
@@ -648,3 +618,88 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
648618
dest.mlir_data = res.mlir_data
649619
return dest
650620
end
621+
622+
dispatch_val(x) = x
623+
dispatch_val(::Val{D}) where {D} = D
624+
625+
@inline function Base._typed_vcat(
626+
::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray}
627+
) where {T}
628+
return Base._cat_t(Val(1), T, X...)
629+
end
630+
@inline function Base._typed_hcat(
631+
::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray}
632+
) where {T}
633+
return Base._cat_t(Val(2), T, X...)
634+
end
635+
636+
# `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant
637+
# generic implementation uses `typed_hcat` and `typed_vcat` which is alright
638+
@inline function Base.typed_hvcat(
639+
::Type{T}, rows::Tuple{Vararg{Int}}, as::TracedRArray...
640+
) where {T}
641+
return invoke(
642+
Base.typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as...
643+
)
644+
end
645+
646+
function Base._typed_hvncat(
647+
T::Type, dims::NTuple{N,Int}, row_first::Bool, as::TracedRArray...
648+
) where {N}
649+
As = if row_first
650+
perm = [2, 1, 3:N...]
651+
dims = [dims[2], dims[1], dims[3:end]...]
652+
permutedims(reshape(collect(as), dims...), perm)
653+
else
654+
reshape(collect(as), dims)
655+
end
656+
657+
for d in 1:N
658+
Bs = Array{Any,N - d}(undef, size(As)[2:end]...)
659+
660+
for (i, col) in
661+
zip(eachindex(Bs), eachslice(As; dims=Tuple(2:ndims(As)), drop=true))
662+
# TODO row_first affects the flattening?
663+
Bs[i] = Base._cat_t(d, T, col...)
664+
end
665+
666+
As = Bs
667+
end
668+
669+
return only(As)
670+
end
671+
672+
function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T}
673+
dims = dispatch_val(dims)
674+
@assert dims isa Integer "Support for non-integer dimensions is not implemented yet."
675+
676+
# MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
677+
X = maybe_expand_dims.(X, (dims,))
678+
679+
catdims = Base.dims2cat(dims)
680+
shape = Base.cat_size_shape(catdims, X...)
681+
RT = Base.promote_eltype(T, X...)
682+
683+
# convert to the target eltype
684+
X = map(Base.Fix1(promote_to, TracedRArray{RT,length(shape)}), X)
685+
686+
return TracedRArray{RT,length(shape)}(
687+
(),
688+
MLIR.IR.result(
689+
# TODO maybe we should do some conversion?
690+
MLIR.Dialects.stablehlo.concatenate(
691+
collect(get_mlir_data.(X));
692+
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
693+
dimension=dims - 1, # stablehlo expects this to be zero-indexed
694+
),
695+
1,
696+
),
697+
shape,
698+
)
699+
end
700+
701+
function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N}
702+
dims = dispatch_val(dims)
703+
dims N && return x
704+
return reshape(x, ntuple(i -> i N ? size(x, i) : 1, dims))
705+
end

test/basic.jl

Lines changed: 67 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -218,20 +218,73 @@ end
218218
end
219219

220220
@testset "concatenation" begin
221-
x = ones(2, 4, 3)
222-
x_concrete = Reactant.to_rarray(x)
223-
224-
cat1(x) = vcat(x, x, x)
225-
cat2(x) = hcat(x, x, x)
226-
cat3(x) = cat(x, x, x; dims=Val(3))
227-
228-
cat1_compiled = @compile cat1(x_concrete)
229-
cat2_compiled = @compile cat2(x_concrete)
230-
cat3_compiled = @compile cat3(x_concrete)
231-
232-
@test cat1(x) cat1_compiled(x_concrete)
233-
@test cat2(x) cat2_compiled(x_concrete)
234-
@test cat3(x) cat3_compiled(x_concrete)
221+
@testset "$(ndims(x))-dim" for x in [
222+
fill(true),
223+
[true, false],
224+
[true false],
225+
[true true; true false],
226+
[
227+
true true true true; true true true false;;;
228+
true true false true; true true false false;;;
229+
true false true true; true false true false
230+
],
231+
]
232+
x_concrete = Reactant.to_rarray(x)
233+
234+
# NOTE [,,,] is a call to `vect`, not `*cat`
235+
# f = Reactant.compile((x_concrete,)) do x
236+
# return [x, x, x]
237+
# end
238+
# @test f(x_concrete) ≈ ones(3)
239+
240+
# vcat
241+
test_vcat(x) = [x; x; x]
242+
f = @compile test_vcat(x_concrete)
243+
@test f(x_concrete) == test_vcat(x)
244+
@test eltype(f(x_concrete)) === Bool
245+
246+
# hcat
247+
test_hcat(x) = [x x x]
248+
f = @compile test_hcat(x_concrete)
249+
@test f(x_concrete) == test_hcat(x)
250+
@test eltype(f(x_concrete)) === Bool
251+
252+
# hvcat
253+
test_hvcat(x) = [x x x; x x x]
254+
f = @compile test_hvcat(x_concrete)
255+
@test f(x_concrete) == test_hvcat(x)
256+
@test eltype(f(x_concrete)) === Bool
257+
258+
# hvncat
259+
test_hvncat(x) = [x x x; x x x;;; x x x; x x x]
260+
f = @compile test_hvncat(x_concrete)
261+
@test f(x_concrete) == test_hvncat(x)
262+
@test eltype(f(x_concrete)) === Bool
263+
264+
# typed_vcat
265+
test_typed_vcat(x) = Int[x; x; x]
266+
f = @compile test_typed_vcat(x_concrete)
267+
@test f(x_concrete) == test_typed_vcat(x)
268+
@test eltype(f(x_concrete)) === Int
269+
270+
# typed_hcat
271+
test_typed_hcat(x) = Int[x x x]
272+
f = @compile test_typed_hcat(x_concrete)
273+
@test f(x_concrete) == test_typed_hcat(x)
274+
@test eltype(f(x_concrete)) === Int
275+
276+
# typed_hvcat
277+
test_typed_hvcat(x) = Int[x x x; x x x]
278+
f = @compile test_typed_hvcat(x_concrete)
279+
@test f(x_concrete) == test_typed_hvcat(x)
280+
@test eltype(f(x_concrete)) === Int
281+
282+
# typed_hvncat
283+
test_typed_hvncat(x) = Int[x x x; x x x;;; x x x; x x x]
284+
f = @compile test_typed_hvncat(x_concrete)
285+
@test f(x_concrete) == test_typed_hvncat(x)
286+
@test eltype(f(x_concrete)) === Int
287+
end
235288
end
236289

237290
function update_on_copy(x)

0 commit comments

Comments
 (0)