Skip to content

Generalize Base._cat to non-Val, typed Base._cat_t and implement typed_hcat, typed_vcat, typed_hvcat, typed_hvncat #163

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 69 additions & 14 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -761,32 +761,87 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
return dest
end

function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D}
@assert D isa Integer "Support for non-integer dimensions is not implemented yet."
dispatch_val(x) = x
dispatch_val(::Val{D}) where {D} = D

# MLIR expects the dimension `D` to be ≤ the rank of the input tensors
A = maybe_expand_dims(A, dims)
Bs = maybe_expand_dims.(Bs, (dims,))
@inline function Base._typed_vcat(
::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray}
) where {T}
return Base._cat_t(Val(1), T, X...)
end
@inline function Base._typed_hcat(
::Type{T}, X::Base.AbstractVecOrTuple{<:TracedRArray}
) where {T}
return Base._cat_t(Val(2), T, X...)
end

# `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant
# generic implementation uses `typed_hcat` and `typed_vcat` which is alright
@inline function Base.typed_hvcat(
::Type{T}, rows::Tuple{Vararg{Int}}, as::TracedRArray...
) where {T}
return invoke(
Base.typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as...
)
end

function Base._typed_hvncat(
T::Type, dims::NTuple{N,Int}, row_first::Bool, as::TracedRArray...
) where {N}
As = if row_first
perm = [2, 1, 3:N...]
dims = [dims[2], dims[1], dims[3:end]...]
permutedims(reshape(collect(as), dims...), perm)
else
reshape(collect(as), dims)
end

for d in 1:N
Bs = Array{Any,N - d}(undef, size(As)[2:end]...)

for (i, col) in
zip(eachindex(Bs), eachslice(As; dims=Tuple(2:ndims(As)), drop=true))
# TODO row_first affects the flattening?
Bs[i] = Base._cat_t(d, T, col...)
end

As = Bs
end

return only(As)
end

function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T}
dims = dispatch_val(dims)
@assert dims isa Integer "Support for non-integer dimensions is not implemented yet."

# MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
X = maybe_expand_dims.(X, (dims,))

catdims = Base.dims2cat(dims)
shape = Base.cat_size_shape(catdims, A, Bs...)
RT = Base.promote_eltype(A, Bs...)
Res = TracedRArray{RT,length(shape)}(
shape = Base.cat_size_shape(catdims, X...)
RT = Base.promote_eltype(T, X...)

# convert to the target eltype
X = map(Base.Fix1(promote_to, TracedRArray{RT,length(shape)}), X)

return TracedRArray{RT,length(shape)}(
(),
MLIR.IR.result(
# TODO maybe we should do some conversion?
MLIR.Dialects.stablehlo.concatenate(
[A.mlir_data, [B.mlir_data for B in Bs]...];
collect(get_mlir_data.(X));
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
dimension=D - 1, # stablehlo expects this to be zero-indexed
dimension=dims - 1, # stablehlo expects this to be zero-indexed
),
1,
),
shape,
)
return Res
end

function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D}
D ≤ N && return x
return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, Val(D)))
function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N}
dims = dispatch_val(dims)
dims ≤ N && return x
return reshape(x, ntuple(i -> i ≤ N ? size(x, i) : 1, dims))
end
2 changes: 1 addition & 1 deletion src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ function make_tracer(
if haskey(seen, prev)
return seen[prev]
end
if mode == ArrayToConcrete && eltype(RT) <: AbstractFloat
if mode == ArrayToConcrete && eltype(RT) <: Union{AbstractFloat,Integer}
return seen[prev] = ConcreteRArray(prev)
end
TT = traced_type(eltype(RT), (), Val(mode))
Expand Down
81 changes: 67 additions & 14 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -210,20 +210,73 @@ end
end

@testset "concatenation" begin
x = ones(2, 4, 3)
x_concrete = Reactant.to_rarray(x)

cat1(x) = vcat(x, x, x)
cat2(x) = hcat(x, x, x)
cat3(x) = cat(x, x, x; dims=Val(3))

cat1_compiled = @compile cat1(x_concrete)
cat2_compiled = @compile cat2(x_concrete)
cat3_compiled = @compile cat3(x_concrete)

@test cat1(x) ≈ cat1_compiled(x_concrete)
@test cat2(x) ≈ cat2_compiled(x_concrete)
@test cat3(x) ≈ cat3_compiled(x_concrete)
@testset "$(ndims(x))-dim" for x in [
fill(true),
[true, false],
[true false],
[true true; true false],
[
true true true true; true true true false;;;
true true false true; true true false false;;;
true false true true; true false true false
],
]
x_concrete = Reactant.to_rarray(x)

# NOTE [,,,] is a call to `vect`, not `*cat`
# f = Reactant.compile((x_concrete,)) do x
# return [x, x, x]
# end
# @test f(x_concrete) ≈ ones(3)

# vcat
test_vcat(x) = [x; x; x]
f = @compile test_vcat(x_concrete)
@test f(x_concrete) == test_vcat(x)
@test eltype(f(x_concrete)) === Bool

# hcat
test_hcat(x) = [x x x]
f = @compile test_hcat(x_concrete)
@test f(x_concrete) == test_hcat(x)
@test eltype(f(x_concrete)) === Bool

# hvcat
test_hvcat(x) = [x x x; x x x]
f = @compile test_hvcat(x_concrete)
@test f(x_concrete) == test_hvcat(x)
@test eltype(f(x_concrete)) === Bool

# hvncat
test_hvncat(x) = [x x x; x x x;;; x x x; x x x]
f = @compile test_hvncat(x_concrete)
@test f(x_concrete) == test_hvncat(x)
@test eltype(f(x_concrete)) === Bool

# typed_vcat
test_typed_vcat(x) = Int[x; x; x]
f = @compile test_typed_vcat(x_concrete)
@test f(x_concrete) == test_typed_vcat(x)
@test eltype(f(x_concrete)) === Int

# typed_hcat
test_typed_hcat(x) = Int[x x x]
f = @compile test_typed_hcat(x_concrete)
@test f(x_concrete) == test_typed_hcat(x)
@test eltype(f(x_concrete)) === Int

# typed_hvcat
test_typed_hvcat(x) = Int[x x x; x x x]
f = @compile test_typed_hvcat(x_concrete)
@test f(x_concrete) == test_typed_hvcat(x)
@test eltype(f(x_concrete)) === Int

# typed_hvncat
test_typed_hvncat(x) = Int[x x x; x x x;;; x x x; x x x]
f = @compile test_typed_hvncat(x_concrete)
@test f(x_concrete) == test_typed_hvncat(x)
@test eltype(f(x_concrete)) === Int
end
end

function update_on_copy(x)
Expand Down
Loading