@@ -455,18 +455,41 @@ function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
455
455
return A
456
456
end
457
457
458
+ function Base. _cat (dims:: Val{D} , A:: TracedRArray{T,N} , Bs:: TracedRArray... ) where {T,N,D}
459
+ @assert D isa Integer " Support for non-integer dimensions is not implemented yet."
460
+
461
+ # MLIR expects the dimension `D` to be ≤ the rank of the input tensors
462
+ A = maybe_expand_dims (A, dims)
463
+ Bs = maybe_expand_dims .(Bs, (dims,))
464
+
465
+ catdims = Base. dims2cat (dims)
466
+ shape = Base. cat_size_shape (catdims, A, Bs... )
467
+ RT = Base. promote_eltype (A, Bs... )
468
+ Res = TracedRArray {RT,length(shape)} (
469
+ (),
470
+ MLIR. IR. result (
471
+ MLIR. Dialects. stablehlo. concatenate (
472
+ [A. mlir_data, [B. mlir_data for B in Bs]. .. ];
473
+ result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
474
+ dimension= D - 1 , # stablehlo expects this to be zero-indexed
475
+ ),
476
+ 1 ,
477
+ ),
478
+ shape,
479
+ )
480
+ return Res
481
+ end
482
+
483
+ function maybe_expand_dims (x:: AbstractArray{T,N} , :: Val{D} ) where {T,N,D}
484
+ D ≤ N && return x
485
+ return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , Val (D)))
486
+ end
487
+
458
488
struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
459
489
460
490
AbstractReactantArrayStyle (:: Val{N} ) where {N} = AbstractReactantArrayStyle {N} ()
461
491
AbstractReactantArrayStyle {M} (:: Val{N} ) where {N,M} = AbstractReactantArrayStyle {N} ()
462
492
463
- # function Broadcast.materialize(bc::Broadcasted)
464
- # @show bc
465
- # inst = instantiate(bc)
466
- # @show inst
467
- # copy(inst)
468
- # end
469
-
470
493
function BroadcastStyle (:: Type{<:AnyTracedRArray{T,N}} ) where {T,N}
471
494
return AbstractReactantArrayStyle {N} ()
472
495
end
@@ -628,33 +651,3 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
628
651
dest. mlir_data = res. mlir_data
629
652
return dest
630
653
end
631
-
632
- function Base. _cat (dims:: Val{D} , A:: TracedRArray{T,N} , Bs:: TracedRArray... ) where {T,N,D}
633
- @assert D isa Integer " Support for non-integer dimensions is not implemented yet."
634
-
635
- # MLIR expects the dimension `D` to be ≤ the rank of the input tensors
636
- A = maybe_expand_dims (A, dims)
637
- Bs = maybe_expand_dims .(Bs, (dims,))
638
-
639
- catdims = Base. dims2cat (dims)
640
- shape = Base. cat_size_shape (catdims, A, Bs... )
641
- RT = Base. promote_eltype (A, Bs... )
642
- Res = TracedRArray {RT,length(shape)} (
643
- (),
644
- MLIR. IR. result (
645
- MLIR. Dialects. stablehlo. concatenate (
646
- [A. mlir_data, [B. mlir_data for B in Bs]. .. ];
647
- result_0= MLIR. IR. TensorType (shape, MLIR. IR. Type (RT)),
648
- dimension= D - 1 , # stablehlo expects this to be zero-indexed
649
- ),
650
- 1 ,
651
- ),
652
- shape,
653
- )
654
- return Res
655
- end
656
-
657
- function maybe_expand_dims (x:: AbstractArray{T,N} , :: Val{D} ) where {T,N,D}
658
- D ≤ N && return x
659
- return reshape (x, ntuple (i -> i ≤ N ? size (x, i) : 1 , Val (D)))
660
- end
0 commit comments