Skip to content

Commit 7d40c5f

Browse files
committed
feat: handle cat/hcat/vcat
1 parent 2f661e1 commit 7d40c5f

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

src/TracedRArray.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,3 +758,23 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
758758
dest.mlir_data = res.mlir_data
759759
return dest
760760
end
761+
762+
function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D}
763+
@assert D isa Integer "Support for non-integer dimensions is not implemented yet."
764+
catdims = Base.dims2cat(dims)
765+
shape = Base.cat_size_shape(catdims, A, Bs...)
766+
RT = Base.promote_eltype(A, Bs...)
767+
Res = TracedRArray{RT,length(shape)}(
768+
(),
769+
MLIR.IR.result(
770+
MLIR.Dialects.stablehlo.concatenate(
771+
[A.mlir_data, [B.mlir_data for B in Bs]...];
772+
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
773+
dimension=D - 1, # stablehlo expects this to be zero-indexed
774+
),
775+
1,
776+
),
777+
shape,
778+
)
779+
return Res
780+
end

test/basic.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,20 @@ end
195195
@test var_fn3(x) var_fn3_compiled(x_ca)
196196
@test var_fn4(x) var_fn4_compiled(x_ca)
197197
end
198+
199+
@testset "concatenation" begin
200+
x = ones(2, 4, 3)
201+
x_concrete = Reactant.to_rarray(x)
202+
203+
cat1(x) = vcat(x, x, x)
204+
cat2(x) = hcat(x, x, x)
205+
cat3(x) = cat(x, x, x; dims=Val(3))
206+
207+
cat1_compiled = Reactant.compile(cat1, (x_concrete,))
208+
cat2_compiled = Reactant.compile(cat2, (x_concrete,))
209+
cat3_compiled = Reactant.compile(cat3, (x_concrete,))
210+
211+
@test cat1(x) cat1_compiled(x_concrete)
212+
@test cat2(x) cat2_compiled(x_concrete)
213+
@test cat3(x) cat3_compiled(x_concrete)
214+
end

0 commit comments

Comments
 (0)