Skip to content

Commit b4c3127

Browse files
committed
Generalize Base._cat implementation on TracedRArray to typed Base._cat_t
1 parent a9fb926 commit b4c3127

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

src/TracedRArray.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -764,30 +764,30 @@ end
764764
dispatch_val(x) = x
765765
dispatch_val(::Val{D}) where {D} = D
766766

767-
function Base._cat(dims, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N}
767+
function Base._cat_t(dims, ::Type{T}, X::TracedRArray...) where {T}
768768
dims = dispatch_val(dims)
769769
@assert dims isa Integer "Support for non-integer dimensions is not implemented yet."
770770

771771
# MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
772-
A = maybe_expand_dims(A, dims)
773-
Bs = maybe_expand_dims.(Bs, (dims,))
772+
X = maybe_expand_dims.(X, (dims,))
774773

775774
catdims = Base.dims2cat(dims)
776-
shape = Base.cat_size_shape(catdims, A, Bs...)
777-
RT = Base.promote_eltype(A, Bs...)
778-
Res = TracedRArray{RT,length(shape)}(
775+
shape = Base.cat_size_shape(catdims, X...)
776+
RT = Base.promote_eltype(T, X...)
777+
778+
return TracedRArray{RT,length(shape)}(
779779
(),
780780
MLIR.IR.result(
781+
# TODO maybe we should do some conversion?
781782
MLIR.Dialects.stablehlo.concatenate(
782-
[A.mlir_data, [B.mlir_data for B in Bs]...];
783+
get_mlir_data.(X);
783784
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
784785
dimension=dims - 1, # stablehlo expects this to be zero-indexed
785786
),
786787
1,
787788
),
788789
shape,
789790
)
790-
return Res
791791
end
792792

793793
function maybe_expand_dims(x::AbstractArray{T,N}, dims) where {T,N}

0 commit comments

Comments
 (0)