Skip to content

Commit 3648c0a

Browse files
committed
feat: handle broadcast_preserving_zero_d in a generic fashion
1 parent 40af781 commit 3648c0a

File tree

3 files changed

+36
-7
lines changed

3 files changed

+36
-7
lines changed

src/ConcreteRArray.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,6 @@ function Base.convert(::Type{T}, x::ConcreteRArray{T,0}) where {T}
7474
return to_float(x)
7575
end
7676

77-
function Base.promote_rule(::Type{<:RArray{T1,0}}, ::Type{T2}) where {T1,T2}
78-
return Base.promote_rule(T1, T2)
79-
end
80-
8177
for jlop in (:(Base.isless), :(Base.:+), :(Base.:-), :(Base.:*), :(Base.:/), :(Base.:^))
8278
@eval begin
8379
function $jlop(x::ConcreteRArray{T,0}, y::ConcreteRArray{U,0}) where {T,U}

src/TracedRArray.jl

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ and require expensive copies and synchronization each time and therefore should
6666
return TracedRNumber{T}((), res2)
6767
end
6868

69+
function Base.getindex(a::TracedRArray{T,0}) where {T}
70+
return TracedRNumber{T}((), a.mlir_data)
71+
end
72+
6973
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
7074
indices = [i isa Colon ? (1:size(a, idx)) : i for (idx, i) in enumerate(indices)]
7175
res = MLIR.IR.result(
@@ -222,7 +226,12 @@ function elem_apply(
222226
end
223227

224228
function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
225-
all(iszero ndims, args) && return f(args...)
229+
if all(iszero ndims, args)
230+
scalar_args = map(args) do arg
231+
return promote_to(TracedRNumber{eltype(arg)}, arg)
232+
end
233+
return f(scalar_args...)
234+
end
226235

227236
fnwrap, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
228237
f, args, (), string(f) * "_broadcast_scalar", false; toscalar=true
@@ -440,6 +449,12 @@ function Base.fill!(A::TracedRArray{T,N}, x) where {T,N}
440449
return A
441450
end
442451

452+
function Base.fill!(A::TracedRArray{T,N}, x::TracedRNumber{T2}) where {T,N,T2}
453+
bcast = broadcast_to_size(promote_to(TracedRNumber{T}, x), size(A))
454+
A.mlir_data = bcast.mlir_data
455+
return A
456+
end
457+
443458
struct AbstractReactantArrayStyle{N} <: Base.Broadcast.AbstractArrayStyle{N} end
444459

445460
AbstractReactantArrayStyle(::Val{N}) where {N} = AbstractReactantArrayStyle{N}()
@@ -458,7 +473,14 @@ end
458473

459474
function Base.similar(
460475
bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{T}, dims
461-
) where {T,N}
476+
) where {T<:ReactantPrimitives,N}
477+
@assert N isa Int
478+
return TracedRArray{T,N}((), nothing, map(length, dims))
479+
end
480+
481+
function Base.similar(
482+
bc::Broadcasted{AbstractReactantArrayStyle{N}}, ::Type{<:TracedRNumber{T}}, dims
483+
) where {T<:ReactantPrimitives,N}
462484
@assert N isa Int
463485
return TracedRArray{T,N}((), nothing, map(length, dims))
464486
end
@@ -536,7 +558,7 @@ function broadcast_to_size(arg::T, rsize) where {T<:Number}
536558
end
537559

538560
function broadcast_to_size(arg::TracedRNumber, rsize)
539-
rsize == () && return arg
561+
length(rsize) == 0 && return arg
540562
mlirty = MLIR.IR.type(arg.mlir_data)
541563
return TracedRArray{eltype(arg),length(rsize)}(
542564
(),

src/TracedRNumber.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,17 @@ function promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
5454
),
5555
)
5656
end
57+
if isa(rhs, TracedRArray{<:Any,0})
58+
return TracedRNumber{T}(
59+
(),
60+
MLIR.IR.result(
61+
MLIR.Dialects.stablehlo.convert(
62+
rhs.mlir_data; result=mlir_type(TracedRNumber{T})
63+
),
64+
1,
65+
),
66+
)
67+
end
5768
if isa(rhs, Number)
5869
attr = fill(MLIR.IR.Attribute(T(rhs)), mlir_type(TracedRNumber{T}))
5970
return TracedRNumber{T}(

0 commit comments

Comments
 (0)