@@ -759,10 +759,14 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
759
759
return dest
760
760
end
761
761
762
- function Base . _cat (dims :: Val{D} , A :: TracedRArray{T,N} , Bs :: TracedRArray... ) where {T,N,D}
763
- @assert D isa Integer " Support for non-integer dimensions is not implemented yet. "
762
+ dispatch_val (x) = x
763
+ dispatch_val ( :: Val{D} ) where {D} = D
764
764
765
- # MLIR expects the dimension `D` to be ≤ the rank of the input tensors
765
+ function Base. _cat (dims, A:: TracedRArray{T,N} , Bs:: TracedRArray... ) where {T,N}
766
+ dims = dispatch_val (dims)
767
+ @assert dims isa Integer " Support for non-integer dimensions is not implemented yet."
768
+
769
+ # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
766
770
A = maybe_expand_dims (A, dims)
767
771
Bs = maybe_expand_dims .(Bs, (dims,))
768
772
@@ -775,7 +779,7 @@ function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) wher
775
779
MLIR. Dialects. stablehlo. concatenate (
776
780
[A. mlir_data, [B. mlir_data for B in Bs]. .. ];
777
781
result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
778
- dimension= D - 1 , # stablehlo expects this to be zero-indexed
782
+ dimension= dims - 1 , # stablehlo expects this to be zero-indexed
779
783
),
780
784
1 ,
781
785
),
0 commit comments