@@ -539,6 +539,9 @@ __default_init(::Type{T}, ::typeof(Base.max)) where {T} = typemin(T)
539
539
function __default_init (:: Type{T} , op:: F ) where {T,F}
540
540
return Base. reduce_empty (Base. BottomRF (op), T)
541
541
end
542
+ function __default_init (T:: Type{<:Reactant.ReactantFloat8} , op:: F ) where {F}
543
+ return T (__default_init (Float16, op))
544
+ end
542
545
543
546
function overloaded_mapreduce (
544
547
@nospecialize (f), @nospecialize (op), @nospecialize (A); dims= :, init= nothing
@@ -562,89 +565,33 @@ function overloaded_mapreduce(
562
565
) where {T,N}
563
566
A = materialize_traced_array (A)
564
567
565
- if dims isa Int
566
- dims = [dims]
567
- end
568
-
569
- op_in_T = Core. Compiler. return_type (f, Tuple{T})
570
-
571
- if init === nothing
572
- init = __default_init (op_in_T, op)
573
-
574
- if typeof (init) != op_in_T
575
- op_in_T = typeof (init)
576
- A = typeof (init).(A)
577
- end
578
- end
579
-
580
- init = [TracedUtils. broadcast_to_size (init, ()). mlir_data]
581
-
582
- inp = [broadcast (f, A). mlir_data]
568
+ original_dims = dims
569
+ dims isa Int && (dims = Int64[dims])
570
+ dims isa Colon && (dims = collect (Int64, 1 : N))
571
+ dims isa AbstractVector{<: Integer } || (dims = collect (Int64, dims))
583
572
584
- rdims = Int64[]
585
-
586
- if dims == (:)
587
- for i in 0 : (N - 1 )
588
- push! (rdims, i)
589
- end
590
- else
591
- for i in dims
592
- push! (rdims, i - 1 )
593
- end
573
+ op_in_T = unwrapped_eltype (Core. Compiler. return_type (f, Tuple{T}))
574
+ reduce_init = __default_init (op_in_T, op)
575
+ if unwrapped_eltype (typeof (reduce_init)) != op_in_T
576
+ op_in_T = typeof (reduce_init)
577
+ A = typeof (reduce_init).(A)
594
578
end
579
+ reduce_init = TracedUtils. promote_to (TracedRNumber{op_in_T}, reduce_init)
595
580
596
- in_tys = [
597
- MLIR. IR. TensorType (Int64[], eltype (MLIR. IR. type (inp[1 ]))),
598
- MLIR. IR. TensorType (Int64[], eltype (MLIR. IR. type (init[1 ]))),
599
- ]
581
+ reduce_input = materialize_traced_array (broadcast (f, A))
600
582
601
- fnbody = MLIR . IR . Block (in_tys, [MLIR . IR . Location (), MLIR . IR . Location ()] )
583
+ res = Ops . reduce (reduce_input, reduce_init, dims, op )
602
584
603
- args = (
604
- TracedRNumber {Reactant.unwrapped_eltype(op_in_T)} ((), MLIR. IR. argument (fnbody, 1 )),
605
- TracedRNumber {Reactant.unwrapped_eltype(op_in_T)} ((), MLIR. IR. argument (fnbody, 2 )),
606
- )
607
-
608
- resty = MLIR. IR. block! (fnbody) do
609
- tmp = TracedUtils. broadcast_to_size (op (args... ), ())
610
- Ops. return_ (tmp)
611
- return eltype (MLIR. IR. type (tmp. mlir_data))
612
- end
585
+ init != = nothing && (res = op .(res, init))
613
586
614
- toonedims = Int[]
615
- outdims = Int[]
616
- for i in 1 : N
617
- tmp = if in (i - 1 , rdims)
618
- 1
619
- else
620
- sz = size (A, i)
621
- push! (outdims, sz)
622
- sz
623
- end
624
- push! (toonedims, tmp)
587
+ if original_dims isa Colon
588
+ @assert size (res) == () " expected size of result to be (), got $(size (res)) "
589
+ return TracedRNumber {unwrapped_eltype(res)} ((), res. mlir_data)
625
590
end
626
-
627
- TT = MLIR. IR. Type[MLIR. IR. TensorType (outdims, resty)]
628
-
629
- body = MLIR. IR. Region ()
630
- push! (body, fnbody)
631
- red = MLIR. Dialects. stablehlo. reduce (
632
- inp, init; result_0= TT, dimensions= MLIR. IR. DenseArrayAttribute (rdims), body
633
- )
634
-
635
- red = MLIR. IR. result (red, 1 )
636
- redT = eltype (MLIR. IR. julia_type (MLIR. IR. type (red)))
637
-
638
- if dims != (:)
639
- red = Ops. reshape (TracedRArray (red), toonedims... )
640
- else
641
- if length (outdims) == 0
642
- red = TracedRNumber {redT} ((), red)
643
- else
644
- red = TracedRArray {redT,length(outdims)} ((), red, (outdims... ,))
645
- end
591
+ if res isa TracedRNumber
592
+ res = TracedRArray {unwrapped_eltype(res),0} ((), res. mlir_data, ())
646
593
end
647
- return red
594
+ return Ops . reshape (res, [ ifelse (i in dims, 1 , size (A, i)) for i in 1 : N])
648
595
end
649
596
650
597
function Base. mapreducedim! (
@@ -1343,6 +1290,11 @@ function scan_impl!(
1343
1290
op_in_T = typeof (init)
1344
1291
input = typeof (init).(input)
1345
1292
end
1293
+ else
1294
+ # TODO : fix this for TPUs
1295
+ if contains (string (first (Reactant. devices ())), " TPU" )
1296
+ throw (AssertionError (" Currently, `init` is not supported on TPUs." ))
1297
+ end
1346
1298
end
1347
1299
init = something (init) # unwrap Some
1348
1300
init = TracedUtils. promote_to (TracedRNumber{unwrapped_eltype (init)}, init)
0 commit comments