diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 9e46f567e4..2780179098 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -761,32 +761,87 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted) return dest end -function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D} - @assert D isa Integer "Support for non-integer dimensions is not implemented yet." +dispatch_val(x) = x +dispatch_val(::Val{D}) where {D} = D - # MLIR expects the dimension `D` to be ≤ the rank of the input tensors - A = maybe_expand_dims(A, dims) - Bs = maybe_expand_dims.(Bs, (dims,)) +@inline function Base._typed_vcat( + ::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray} +) where {T} + return Base._cat_t(Val(1), T, X...) +end +@inline function Base._typed_hcat( + ::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray} +) where {T} + return Base._cat_t(Val(2), T, X...) +end + +# `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant +# generic implementation uses `typed_hcat` and `typed_vcat` which is alright +@inline function Base.typed_hvcat( + ::Type{T}, rows::Tuple{Vararg{Int}}, as::TracedRArray... +) where {T} + return invoke( + Base.typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as... + ) +end + +function Base._typed_hvncat( + T::Type, dims::NTuple{N,Int}, row_first::Bool, as::TracedRArray... +) where {N} + As = if row_first + perm = [2, 1, 3:N...] + dims = [dims[2], dims[1], dims[3:end]...] + permutedims(reshape(collect(as), dims...), perm) + else + reshape(collect(as), dims) + end + + for d in 1:N + Bs = Array{Any,N - d}(undef, size(As)[2:end]...) + + for (i, col) in + zip(eachindex(Bs), eachslice(As; dims=Tuple(2:ndims(As)), drop=true)) + # TODO row_first affects the flattening? + Bs[i] = Base._cat_t(d, T, col...) + end + + As = Bs + end + + return only(As) +end + +function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T} + dims = dispatch_val(dims) + @assert dims isa Integer "Support for non-integer dimensions is not implemented yet." + + # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors + X = maybe_expand_dims.(X, (dims,)) catdims = Base.dims2cat(dims) - shape = Base.cat_size_shape(catdims, A, Bs...) - RT = Base.promote_eltype(A, Bs...) - Res = TracedRArray{RT,length(shape)}( + shape = Base.cat_size_shape(catdims, X...) + RT = Base.promote_eltype(T, X...) + + # convert to the target eltype + X = map(Base.Fix1(promote_to, TracedRArray{RT,length(shape)}), X) + + return TracedRArray{RT,length(shape)}( (), MLIR.IR.result( + # TODO maybe we should do some conversion? MLIR.Dialects.stablehlo.concatenate( - [A.mlir_data, [B.mlir_data for B in Bs]...]; + collect(get_mlir_data.(X)); result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)), - dimension=D - 1, # stablehlo expects this to be zero-indexed + dimension=dims - 1, # stablehlo expects this to be zero-indexed ), 1, ), shape, ) - return Res end -function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D} - D ≤ N && return x - return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, Val(D))) +function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N} + dims = dispatch_val(dims) + dims ≤ N && return x + return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, dims)) end diff --git a/src/Tracing.jl b/src/Tracing.jl index ae4f3b4c66..b3233a6447 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -380,7 +380,7 @@ function make_tracer( if haskey(seen, prev) return seen[prev] end - if mode == ArrayToConcrete && eltype(RT) <: AbstractFloat + if mode == ArrayToConcrete && eltype(RT) <: Union{AbstractFloat,Integer} return seen[prev] = ConcreteRArray(prev) end TT = traced_type(eltype(RT), (), Val(mode)) diff --git a/test/basic.jl b/test/basic.jl index 4467adf8f4..368f22d728 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -210,20 +210,73 @@ end end @testset "concatenation" begin - x = ones(2, 4, 3) - x_concrete = Reactant.to_rarray(x) - - cat1(x) = vcat(x, x, x) - cat2(x) = hcat(x, x, x) - cat3(x) = cat(x, x, x; dims=Val(3)) - - cat1_compiled = @compile cat1(x_concrete) - cat2_compiled = @compile cat2(x_concrete) - cat3_compiled = @compile cat3(x_concrete) - - @test cat1(x) ≈ cat1_compiled(x_concrete) - @test cat2(x) ≈ cat2_compiled(x_concrete) - @test cat3(x) ≈ cat3_compiled(x_concrete) + @testset "$(ndims(x))-dim" for x in [ + fill(true), + [true, false], + [true false], + [true true; true false], + [ + true true true true; true true true false;;; + true true false true; true true false false;;; + true false true true; true false true false + ], + ] + x_concrete = Reactant.to_rarray(x) + + # NOTE [,,,] is a call to `vect`, not `*cat` + # f = Reactant.compile((x_concrete,)) do x + # return [x, x, x] + # end + # @test f(x_concrete) ≈ ones(3) + + # vcat + test_vcat(x) = [x; x; x] + f = @compile test_vcat(x_concrete) + @test f(x_concrete) == test_vcat(x) + @test eltype(f(x_concrete)) === Bool + + # hcat + test_hcat(x) = [x x x] + f = @compile test_hcat(x_concrete) + @test f(x_concrete) == test_hcat(x) + @test eltype(f(x_concrete)) === Bool + + # hvcat + test_hvcat(x) = [x x x; x x x] + f = @compile test_hvcat(x_concrete) + @test f(x_concrete) == test_hvcat(x) + @test eltype(f(x_concrete)) === Bool + + # hvncat + test_hvncat(x) = [x x x; x x x;;; x x x; x x x] + f = @compile test_hvncat(x_concrete) + @test f(x_concrete) == test_hvncat(x) + @test eltype(f(x_concrete)) === Bool + + # typed_vcat + test_typed_vcat(x) = Int[x; x; x] + f = @compile test_typed_vcat(x_concrete) + @test f(x_concrete) == test_typed_vcat(x) + @test eltype(f(x_concrete)) === Int + + # typed_hcat + test_typed_hcat(x) = Int[x x x] + f = @compile test_typed_hcat(x_concrete) + @test f(x_concrete) == test_typed_hcat(x) + @test eltype(f(x_concrete)) === Int + + # typed_hvcat + test_typed_hvcat(x) = Int[x x x; x x x] + f = @compile test_typed_hvcat(x_concrete) + @test f(x_concrete) == test_typed_hvcat(x) + @test eltype(f(x_concrete)) === Int + + # typed_hvncat + test_typed_hvncat(x) = Int[x x x; x x x;;; x x x; x x x] + f = @compile test_typed_hvncat(x_concrete) + @test f(x_concrete) == test_typed_hvncat(x) + @test eltype(f(x_concrete)) === Int + end end function update_on_copy(x)