Skip to content

Commit 535b125

Browse files
committed
fix: promote_rule and introduce union over primitive types
1 parent 3d34693 commit 535b125

File tree

1 file changed

+44
-15
lines changed

1 file changed

+44
-15
lines changed

src/TracedRArray.jl

Lines changed: 44 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,26 @@ end
1919

2020
TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x
2121

22-
mutable struct TracedRNumber{T} <: RNumber{T}
22+
const ReactantPrimitives = Union{
23+
Bool,
24+
Int8,
25+
UInt8,
26+
Int16,
27+
UInt16,
28+
Int32,
29+
UInt32,
30+
Int64,
31+
UInt64,
32+
Float16,
33+
Float32,
34+
# BFloat16,
35+
Float64,
36+
Complex{Float32},
37+
Complex{Float64},
38+
}
39+
40+
# `<: ReactantPrimitives` ensures we don't end up with nested `TracedRNumber`s
41+
mutable struct TracedRNumber{T<:ReactantPrimitives} <: RNumber{T}
2342
paths::Tuple
2443
mlir_data::Union{Nothing,MLIR.IR.Value}
2544

@@ -214,14 +233,8 @@ function Base.transpose(A::AnyTracedRVecOrMat)
214233
end
215234
Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A)
216235

217-
function Base.promote_rule(
218-
::Type{TracedRArray{T,N}}, ::Type{TracedRArray{S,N}}
219-
) where {T,S,N}
220-
return TracedRArray{Base.promote_type(T, S),N}
221-
end
222-
223-
function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N}
224-
return TracedRArray{Base.promote_type(T, S),N}
236+
function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S}
237+
return TracedRNumber{Base.promote_type(T, S)}
225238
end
226239

227240
function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
@@ -326,8 +339,6 @@ function Base.ifelse(
326339
)
327340
end
328341

329-
Base.abs2(x::Reactant.TracedRNumber{T}) where {T} = x * conj(x)
330-
331342
function Base.literal_pow(
332343
::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}}
333344
) where {T,P}
@@ -355,8 +366,8 @@ end
355366

356367
struct TypeCast{T<:Number} <: Function end
357368

358-
function (::TypeCast{T})(x::TracedRArray{T2,0}) where {T,T2}
359-
return promote_to(TracedRArray{T,0}, x)
369+
function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2}
370+
return promote_to(TracedRNumber{T}, x)
360371
end
361372

362373
elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x
@@ -556,8 +567,7 @@ function Base.mapreduce(
556567
fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys])
557568

558569
args = (
559-
TracedRNumber{T}((), MLIR.IR.argument(fnbody, i), ()) for
560-
(i, ty) in enumerate(in_tys)
570+
TracedRNumber{T}((), MLIR.IR.argument(fnbody, i)) for (i, ty) in enumerate(in_tys)
561571
)
562572

563573
res = MLIR.IR.block!(fnbody) do
@@ -708,6 +718,25 @@ function broadcast_to_size(arg::T, rsize) where {T<:Number}
708718
)
709719
end
710720

721+
function broadcast_to_size(arg::TracedRNumber, rsize)
722+
rsize == () && return arg
723+
mlirty = MLIR.IR.type(arg.mlir_data)
724+
return TracedRArray{eltype(arg),length(rsize)}(
725+
(),
726+
MLIR.IR.result(
727+
MLIR.Dialects.stablehlo.broadcast_in_dim(
728+
arg.mlir_data;
729+
result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)),
730+
broadcast_dimensions=MLIR.IR.DenseArrayAttribute([
731+
Int64(i - 1) for i in rsize
732+
]),
733+
),
734+
1,
735+
),
736+
rsize,
737+
)
738+
end
739+
711740
function broadcast_to_size(arg::AnyTracedRArray, rsize)
712741
arg = materialize_traced_array(arg)
713742
size(arg) == rsize && return arg

0 commit comments

Comments
 (0)