
Commit 2061166

fix: init for julia and shlo are semantically different (#1532)
* fix: init for julia and shlo are semantically different
* fix: return type
* fix: throw error for TPU
1 parent: 60798b3
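
Background note (not from the commit itself, a sketch assuming plain Base Arrays): Base's reductions fold a user-supplied `init` into the result once (once per reduced slice for dims-reductions), whereas stablehlo.reduce expects its initial value to be an identity element of the combiner, which the backend may fold in any number of times.

# Julia-side behavior: `init` contributes once per reduction/slice.
mapreduce(identity, +, [1, 2, 3]; init = 10)             # 16
mapreduce(identity, +, [1 2; 3 4]; dims = 1, init = 10)  # 1×2 Matrix: [14 16]

Handing a non-identity value such as 10 directly to stablehlo.reduce as its init value can therefore change the result, which is why the diff below keeps the HLO init at the neutral element and applies the user's `init` exactly once afterwards.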

File tree: 1 file changed (+27, -75 lines)

src/TracedRArray.jl

Lines changed: 27 additions & 75 deletions
@@ -539,6 +539,9 @@ __default_init(::Type{T}, ::typeof(Base.max)) where {T} = typemin(T)
function __default_init(::Type{T}, op::F) where {T,F}
    return Base.reduce_empty(Base.BottomRF(op), T)
end
+function __default_init(T::Type{<:Reactant.ReactantFloat8}, op::F) where {F}
+    return T(__default_init(Float16, op))
+end

function overloaded_mapreduce(
    @nospecialize(f), @nospecialize(op), @nospecialize(A); dims=:, init=nothing
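
Aside on the `__default_init` method added above: the FP8 element types presumably cannot use the generic `Base.reduce_empty` path, so the neutral element is computed in `Float16` and converted back. The same pattern with stock Julia types (Float16 standing in for an FP8 type, purely illustrative):

# Neutral element derived via a wider float, then narrowed back.
neutral(T, op) = Base.reduce_empty(Base.BottomRF(op), T)
Float16(neutral(Float32, +))   # Float16(0.0)
Float16(neutral(Float32, *))   # Float16(1.0)
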
@@ -562,89 +565,33 @@ function overloaded_mapreduce(
) where {T,N}
    A = materialize_traced_array(A)

-    if dims isa Int
-        dims = [dims]
-    end
-
-    op_in_T = Core.Compiler.return_type(f, Tuple{T})
-
-    if init === nothing
-        init = __default_init(op_in_T, op)
-
-        if typeof(init) != op_in_T
-            op_in_T = typeof(init)
-            A = typeof(init).(A)
-        end
-    end
-
-    init = [TracedUtils.broadcast_to_size(init, ()).mlir_data]
-
-    inp = [broadcast(f, A).mlir_data]
+    original_dims = dims
+    dims isa Int && (dims = Int64[dims])
+    dims isa Colon && (dims = collect(Int64, 1:N))
+    dims isa AbstractVector{<:Integer} || (dims = collect(Int64, dims))

-    rdims = Int64[]
-
-    if dims == (:)
-        for i in 0:(N - 1)
-            push!(rdims, i)
-        end
-    else
-        for i in dims
-            push!(rdims, i - 1)
-        end
+    op_in_T = unwrapped_eltype(Core.Compiler.return_type(f, Tuple{T}))
+    reduce_init = __default_init(op_in_T, op)
+    if unwrapped_eltype(typeof(reduce_init)) != op_in_T
+        op_in_T = typeof(reduce_init)
+        A = typeof(reduce_init).(A)
    end
+    reduce_init = TracedUtils.promote_to(TracedRNumber{op_in_T}, reduce_init)

-    in_tys = [
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(inp[1]))),
-        MLIR.IR.TensorType(Int64[], eltype(MLIR.IR.type(init[1]))),
-    ]
+    reduce_input = materialize_traced_array(broadcast(f, A))

-    fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location(), MLIR.IR.Location()])
+    res = Ops.reduce(reduce_input, reduce_init, dims, op)

-    args = (
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 1)),
-        TracedRNumber{Reactant.unwrapped_eltype(op_in_T)}((), MLIR.IR.argument(fnbody, 2)),
-    )
-
-    resty = MLIR.IR.block!(fnbody) do
-        tmp = TracedUtils.broadcast_to_size(op(args...), ())
-        Ops.return_(tmp)
-        return eltype(MLIR.IR.type(tmp.mlir_data))
-    end
+    init !== nothing && (res = op.(res, init))

-    toonedims = Int[]
-    outdims = Int[]
-    for i in 1:N
-        tmp = if in(i - 1, rdims)
-            1
-        else
-            sz = size(A, i)
-            push!(outdims, sz)
-            sz
-        end
-        push!(toonedims, tmp)
+    if original_dims isa Colon
+        @assert size(res) == () "expected size of result to be (), got $(size(res))"
+        return TracedRNumber{unwrapped_eltype(res)}((), res.mlir_data)
    end
-
-    TT = MLIR.IR.Type[MLIR.IR.TensorType(outdims, resty)]
-
-    body = MLIR.IR.Region()
-    push!(body, fnbody)
-    red = MLIR.Dialects.stablehlo.reduce(
-        inp, init; result_0=TT, dimensions=MLIR.IR.DenseArrayAttribute(rdims), body
-    )
-
-    red = MLIR.IR.result(red, 1)
-    redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))
-
-    if dims != (:)
-        red = Ops.reshape(TracedRArray(red), toonedims...)
-    else
-        if length(outdims) == 0
-            red = TracedRNumber{redT}((), red)
-        else
-            red = TracedRArray{redT,length(outdims)}((), red, (outdims...,))
-        end
+    if res isa TracedRNumber
+        res = TracedRArray{unwrapped_eltype(res),0}((), res.mlir_data, ())
    end
-    return red
+    return Ops.reshape(res, [ifelse(i in dims, 1, size(A, i)) for i in 1:N])
end

function Base.mapreducedim!(
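
Usage sketch of the rewritten `overloaded_mapreduce` (illustrative only, not from this diff): the HLO reduction now goes through `Ops.reduce` with a neutral init, the user's `init` is folded in once via `op.(res, init)`, and dims-reductions keep singleton reduced dimensions as in Base.

using Reactant

# Reduce each column, folding `init` in once per column.
reduce_cols(x) = mapreduce(abs2, +, x; dims=1, init=10.0f0)

x = Reactant.to_rarray(Float32[1 2; 3 4])
Reactant.@jit reduce_cols(x)  # expected to match reduce_cols(Float32[1 2; 3 4]) == Float32[20.0 30.0]
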
@@ -1343,6 +1290,11 @@ function scan_impl!(
            op_in_T = typeof(init)
            input = typeof(init).(input)
        end
+    else
+        # TODO: fix this for TPUs
+        if contains(string(first(Reactant.devices())), "TPU")
+            throw(AssertionError("Currently, `init` is not supported on TPUs."))
+        end
    end
    init = something(init) # unwrap Some
    init = TracedUtils.promote_to(TracedRNumber{unwrapped_eltype(init)}, init)
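
The TPU guard above only fires when a caller passes an explicit `init` to the scan path (which, as far as I can tell, is what backs `Base.accumulate` for traced arrays). A hypothetical call it now rejects on TPU backends:

# On a TPU device this is expected to throw
# AssertionError("Currently, `init` is not supported on TPUs.");
# on CPU/GPU backends it behaves as before.
prefix_sum(x) = accumulate(+, x; init=1.0f0)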
