@@ -764,30 +764,30 @@ end
764
764
dispatch_val (x) = x
765
765
dispatch_val (:: Val{D} ) where {D} = D
766
766
767
- function Base. _cat (dims, A :: TracedRArray{T,N } , Bs :: TracedRArray... ) where {T,N }
767
+ function Base. _cat_t (dims, :: Type{T } , X :: TracedRArray... ) where {T}
768
768
dims = dispatch_val (dims)
769
769
@assert dims isa Integer " Support for non-integer dimensions is not implemented yet."
770
770
771
771
# MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
772
- A = maybe_expand_dims (A, dims)
773
- Bs = maybe_expand_dims .(Bs, (dims,))
772
+ X = maybe_expand_dims .(X, (dims,))
774
773
775
774
catdims = Base. dims2cat (dims)
776
- shape = Base. cat_size_shape (catdims, A, Bs... )
777
- RT = Base. promote_eltype (A, Bs... )
778
- Res = TracedRArray {RT,length(shape)} (
775
+ shape = Base. cat_size_shape (catdims, X... )
776
+ RT = Base. promote_eltype (T, X... )
777
+
778
+ return TracedRArray {RT,length(shape)} (
779
779
(),
780
780
MLIR. IR. result (
781
+ # TODO maybe we should do some conversion?
781
782
MLIR. Dialects. stablehlo. concatenate (
782
- [A . mlir_data, [B . mlir_data for B in Bs] . .. ] ;
783
+ get_mlir_data .(X) ;
783
784
result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
784
785
dimension= dims - 1 , # stablehlo expects this to be zero-indexed
785
786
),
786
787
1 ,
787
788
),
788
789
shape,
789
790
)
790
- return Res
791
791
end
792
792
793
793
function maybe_expand_dims (x:: AbstractArray{T,N} , dims) where {T,N}
0 commit comments