|
19 | 19 |
|
20 | 20 | TracedRArray{T,N}(x::TracedRArray{T,N}) where {T,N} = x
|
21 | 21 |
|
22 |
| -mutable struct TracedRNumber{T} <: RNumber{T} |
| 22 | +const ReactantPrimitives = Union{ |
| 23 | + Bool, |
| 24 | + Int8, |
| 25 | + UInt8, |
| 26 | + Int16, |
| 27 | + UInt16, |
| 28 | + Int32, |
| 29 | + UInt32, |
| 30 | + Int64, |
| 31 | + UInt64, |
| 32 | + Float16, |
| 33 | + Float32, |
| 34 | + # BFloat16, |
| 35 | + Float64, |
| 36 | + Complex{Float32}, |
| 37 | + Complex{Float64}, |
| 38 | +} |
| 39 | + |
| 40 | +# `<: ReactantPrimitives` ensures we don't end up with nested `TracedRNumber`s |
| 41 | +mutable struct TracedRNumber{T<:ReactantPrimitives} <: RNumber{T} |
23 | 42 | paths::Tuple
|
24 | 43 | mlir_data::Union{Nothing,MLIR.IR.Value}
|
25 | 44 |
|
@@ -214,14 +233,8 @@ function Base.transpose(A::AnyTracedRVecOrMat)
|
214 | 233 | end
|
215 | 234 | Base.adjoint(A::AnyTracedRVecOrMat{<:Real}) = transpose(A)
|
216 | 235 |
|
217 |
| -function Base.promote_rule( |
218 |
| - ::Type{TracedRArray{T,N}}, ::Type{TracedRArray{S,N}} |
219 |
| -) where {T,S,N} |
220 |
| - return TracedRArray{Base.promote_type(T, S),N} |
221 |
| -end |
222 |
| - |
223 |
| -function Base.promote_rule(::Type{T}, ::Type{TracedRArray{S,N}}) where {T,S,N} |
224 |
| - return TracedRArray{Base.promote_type(T, S),N} |
| 236 | +function Base.promote_rule(::Type{TracedRNumber{T}}, ::Type{TracedRNumber{S}}) where {T,S} |
| 237 | + return TracedRNumber{Base.promote_type(T, S)} |
225 | 238 | end
|
226 | 239 |
|
227 | 240 | function Base.promote_rule(::Type{T}, ::Type{TracedRNumber{S}}) where {T,S}
|
@@ -326,8 +339,6 @@ function Base.ifelse(
|
326 | 339 | )
|
327 | 340 | end
|
328 | 341 |
|
329 |
| -Base.abs2(x::Reactant.TracedRNumber{T}) where {T} = x * conj(x) |
330 |
| - |
331 | 342 | function Base.literal_pow(
|
332 | 343 | ::Base.RefValue{typeof(^)}, x::TracedRNumber{T}, ::Base.RefValue{Val{P}}
|
333 | 344 | ) where {T,P}
|
|
355 | 366 |
|
356 | 367 | struct TypeCast{T<:Number} <: Function end
|
357 | 368 |
|
358 |
| -function (::TypeCast{T})(x::TracedRArray{T2,0}) where {T,T2} |
359 |
| - return promote_to(TracedRArray{T,0}, x) |
| 369 | +function (::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} |
| 370 | + return promote_to(TracedRNumber{T}, x) |
360 | 371 | end
|
361 | 372 |
|
362 | 373 | elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x
|
@@ -556,8 +567,7 @@ function Base.mapreduce(
|
556 | 567 | fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys])
|
557 | 568 |
|
558 | 569 | args = (
|
559 |
| - TracedRNumber{T}((), MLIR.IR.argument(fnbody, i), ()) for |
560 |
| - (i, ty) in enumerate(in_tys) |
| 570 | + TracedRNumber{T}((), MLIR.IR.argument(fnbody, i)) for (i, ty) in enumerate(in_tys) |
561 | 571 | )
|
562 | 572 |
|
563 | 573 | res = MLIR.IR.block!(fnbody) do
|
@@ -708,6 +718,25 @@ function broadcast_to_size(arg::T, rsize) where {T<:Number}
|
708 | 718 | )
|
709 | 719 | end
|
710 | 720 |
|
| 721 | +function broadcast_to_size(arg::TracedRNumber, rsize) |
| 722 | + rsize == () && return arg |
| 723 | + mlirty = MLIR.IR.type(arg.mlir_data) |
| 724 | + return TracedRArray{eltype(arg),length(rsize)}( |
| 725 | + (), |
| 726 | + MLIR.IR.result( |
| 727 | + MLIR.Dialects.stablehlo.broadcast_in_dim( |
| 728 | + arg.mlir_data; |
| 729 | + result_0=MLIR.IR.TensorType([t for t in rsize], eltype(mlirty)), |
| 730 | + broadcast_dimensions=MLIR.IR.DenseArrayAttribute([ |
| 731 | + Int64(i - 1) for i in rsize |
| 732 | + ]), |
| 733 | + ), |
| 734 | + 1, |
| 735 | + ), |
| 736 | + rsize, |
| 737 | + ) |
| 738 | +end |
| 739 | + |
711 | 740 | function broadcast_to_size(arg::AnyTracedRArray, rsize)
|
712 | 741 | arg = materialize_traced_array(arg)
|
713 | 742 | size(arg) == rsize && return arg
|
|
0 commit comments