Skip to content

Commit 388e6a2

Browse files
committed
refactor: move code a bit
1 parent 3648c0a commit 388e6a2

File tree

1 file changed

+30
-37
lines changed

1 file changed

+30
-37
lines changed

src/TracedRArray.jl

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -455,18 +455,41 @@ function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
455455
return A
456456
end
457457

458+
function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D}
459+
@assert D isa Integer "Support for non-integer dimensions is not implemented yet."
460+
461+
# MLIR expects the dimension `D` to be ≤ the rank of the input tensors
462+
A = maybe_expand_dims(A, dims)
463+
Bs = maybe_expand_dims.(Bs, (dims,))
464+
465+
catdims = Base.dims2cat(dims)
466+
shape = Base.cat_size_shape(catdims, A, Bs...)
467+
RT = Base.promote_eltype(A, Bs...)
468+
Res = TracedRArray{RT,length(shape)}(
469+
(),
470+
MLIR.IR.result(
471+
MLIR.Dialects.stablehlo.concatenate(
472+
[A.mlir_data, [B.mlir_data for B in Bs]...];
473+
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
474+
dimension=D - 1, # stablehlo expects this to be zero-indexed
475+
),
476+
1,
477+
),
478+
shape,
479+
)
480+
return Res
481+
end
482+
483+
function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D}
484+
D N && return x
485+
return reshape(x, ntuple(i -> i N ? size(x, i) : 1, Val(D)))
486+
end
487+
458488
struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
459489

460490
AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}()
461491
AbstractReactantArrayStyle{M}(::Val{N}) where {N,M} = AbstractReactantArrayStyle{N}()
462492

463-
# function Broadcast.materialize(bc::Broadcasted)
464-
# @show bc
465-
# inst = instantiate(bc)
466-
# @show inst
467-
# copy(inst)
468-
# end
469-
470493
function BroadcastStyle(::Type{<:AnyTracedRArray{T,N}}) where {T,N}
471494
return AbstractReactantArrayStyle{N}()
472495
end
@@ -628,33 +651,3 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
628651
dest.mlir_data = res.mlir_data
629652
return dest
630653
end
631-
632-
function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D}
633-
@assert D isa Integer "Support for non-integer dimensions is not implemented yet."
634-
635-
# MLIR expects the dimension `D` to be ≤ the rank of the input tensors
636-
A = maybe_expand_dims(A, dims)
637-
Bs = maybe_expand_dims.(Bs, (dims,))
638-
639-
catdims = Base.dims2cat(dims)
640-
shape = Base.cat_size_shape(catdims, A, Bs...)
641-
RT = Base.promote_eltype(A, Bs...)
642-
Res = TracedRArray{RT,length(shape)}(
643-
(),
644-
MLIR.IR.result(
645-
MLIR.Dialects.stablehlo.concatenate(
646-
[A.mlir_data, [B.mlir_data for B in Bs]...];
647-
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
648-
dimension=D - 1, # stablehlo expects this to be zero-indexed
649-
),
650-
1,
651-
),
652-
shape,
653-
)
654-
return Res
655-
end
656-
657-
function maybe_expand_dims(x::AbstractArray{T,N}, ::Val{D}) where {T,N,D}
658-
D N && return x
659-
return reshape(x, ntuple(i -> i N ? size(x, i) : 1, Val(D)))
660-
end

0 commit comments

Comments
 (0)