@@ -464,36 +464,6 @@ function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
464
464
return A
465
465
end
466
466
467
- function Base. _cat (dims:: Val{D} , A:: TracedRArray{T,N} , Bs:: TracedRArray... ) where {T,N,D}
468
- @assert D isa Integer " Support for non-integer dimensions is not implemented yet."
469
-
470
- # MLIR expects the dimension `D` to be ≤ the rank of the input tensors
471
- A = maybe_expand_dims (A, dims)
472
- Bs = maybe_expand_dims .(Bs, (dims,))
473
-
474
- catdims = Base. dims2cat (dims)
475
- shape = Base. cat_size_shape (catdims, A, Bs... )
476
- RT = Base. promote_eltype (A, Bs... )
477
- Res = TracedRArray {RT,length(shape)} (
478
- (),
479
- MLIR. IR. result (
480
- MLIR. Dialects. stablehlo. concatenate (
481
- [A. mlir_data, [B. mlir_data for B in Bs]. .. ];
482
- result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
483
- dimension= D - 1 , # stablehlo expects this to be zero-indexed
484
- ),
485
- 1 ,
486
- ),
487
- shape,
488
- )
489
- return Res
490
- end
491
-
492
- function maybe_expand_dims (x:: AbstractArray{T,N} , :: Val{D} ) where {T,N,D}
493
- D ≤ N && return x
494
- return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , Val (D)))
495
- end
496
-
497
467
struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
498
468
499
469
AbstractReactantArrayStyle (:: Val{N} ) where {N} = AbstractReactantArrayStyle {N} ()
@@ -648,3 +618,88 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
648
618
dest. mlir_data = res. mlir_data
649
619
return dest
650
620
end
621
+
622
+ dispatch_val (x) = x
623
+ dispatch_val (:: Val{D} ) where {D} = D
624
+
625
+ @inline function Base. _typed_vcat (
626
+ :: Type{T} , X:: Base.AbstractVecOrTuple{<:TracedRArray}
627
+ ) where {T}
628
+ return Base. _cat_t (Val (1 ), T, X... )
629
+ end
630
+ @inline function Base. _typed_hcat (
631
+ :: Type{T} , X:: Base.AbstractVecOrTuple{<:TracedRArray}
632
+ ) where {T}
633
+ return Base. _cat_t (Val (2 ), T, X... )
634
+ end
635
+
636
+ # `Base.typed_hvcat` is overloaded for `AbstractVecOrMat` using `setindex!` that breaks Reactant
637
+ # generic implementation uses `typed_hcat` and `typed_vcat` which is alright
638
+ @inline function Base. typed_hvcat (
639
+ :: Type{T} , rows:: Tuple{Vararg{Int}} , as:: TracedRArray...
640
+ ) where {T}
641
+ return invoke (
642
+ Base. typed_hvcat, Tuple{Type{T},Tuple{Vararg{Int}},Vararg{Any}}, T, rows, as...
643
+ )
644
+ end
645
+
646
+ function Base. _typed_hvncat (
647
+ T:: Type , dims:: NTuple{N,Int} , row_first:: Bool , as:: TracedRArray...
648
+ ) where {N}
649
+ As = if row_first
650
+ perm = [2 , 1 , 3 : N... ]
651
+ dims = [dims[2 ], dims[1 ], dims[3 : end ]. .. ]
652
+ permutedims (reshape (collect (as), dims... ), perm)
653
+ else
654
+ reshape (collect (as), dims)
655
+ end
656
+
657
+ for d in 1 : N
658
+ Bs = Array {Any,N - d} (undef, size (As)[2 : end ]. .. )
659
+
660
+ for (i, col) in
661
+ zip (eachindex (Bs), eachslice (As; dims= Tuple (2 : ndims (As)), drop= true ))
662
+ # TODO row_first affects the flattening?
663
+ Bs[i] = Base. _cat_t (d, T, col... )
664
+ end
665
+
666
+ As = Bs
667
+ end
668
+
669
+ return only (As)
670
+ end
671
+
672
+ function Base. _cat_t (dims, :: Type{T} , X:: TracedRArray... ) where {T}
673
+ dims = dispatch_val (dims)
674
+ @assert dims isa Integer " Support for non-integer dimensions is not implemented yet."
675
+
676
+ # MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
677
+ X = maybe_expand_dims .(X, (dims,))
678
+
679
+ catdims = Base. dims2cat (dims)
680
+ shape = Base. cat_size_shape (catdims, X... )
681
+ RT = Base. promote_eltype (T, X... )
682
+
683
+ # convert to the target eltype
684
+ X = map (Base. Fix1 (promote_to, TracedRArray{RT,length (shape)}), X)
685
+
686
+ return TracedRArray {RT,length(shape)} (
687
+ (),
688
+ MLIR. IR. result (
689
+ # TODO maybe we should do some conversion?
690
+ MLIR. Dialects. stablehlo. concatenate (
691
+ collect (get_mlir_data .(X));
692
+ result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
693
+ dimension= dims - 1 , # stablehlo expects this to be zero-indexed
694
+ ),
695
+ 1 ,
696
+ ),
697
+ shape,
698
+ )
699
+ end
700
+
701
+ function maybe_expand_dims (x:: AbstractArray{T,N} , dims) where {T,N}
702
+ dims = dispatch_val (dims)
703
+ dims ≤ N && return x
704
+ return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , dims))
705
+ end
0 commit comments