Skip to content

Commit 2cdbb83

Browse files
committed
Remove Val constraint on Base._cat signature
1 parent deefd18 commit 2cdbb83

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

src/TracedRArray.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -759,10 +759,14 @@ function _copyto!(dest::TracedRArray, bc::Broadcasted)
759759
return dest
760760
end
761761

762-
function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N,D}
763-
@assert D isa Integer "Support for non-integer dimensions is not implemented yet."
762+
dispatch_val(x) = x
763+
dispatch_val(::Val{D}) where {D} = D
764764

765-
# MLIR expects the dimension `D` to be ≤ the rank of the input tensors
765+
function Base._cat(dims, A::TracedRArray{T,N}, Bs::TracedRArray...) where {T,N}
766+
dims = dispatch_val(dims)
767+
@assert dims isa Integer "Support for non-integer dimensions is not implemented yet."
768+
769+
# MLIR expects the dimension `dims` to be ≤ the rank of the input tensors
766770
A = maybe_expand_dims(A, dims)
767771
Bs = maybe_expand_dims.(Bs, (dims,))
768772

@@ -775,7 +779,7 @@ function Base._cat(dims::Val{D}, A::TracedRArray{T,N}, Bs::TracedRArray...) wher
775779
MLIR.Dialects.stablehlo.concatenate(
776780
[A.mlir_data, [B.mlir_data for B in Bs]...];
777781
result_0=MLIR.IR.TensorType(shape, MLIR.IR.Type(RT)),
778-
dimension=D - 1, # stablehlo expects this to be zero-indexed
782+
dimension=dims - 1, # stablehlo expects this to be zero-indexed
779783
),
780784
1,
781785
),

0 commit comments

Comments
 (0)