42
42
43
43
Base. getindex (a:: AnyTracedRScalar{T} ) where {T} = a
44
44
45
+ Base. zero (:: AnyTracedRScalar{T} ) where {T} = promote_to (TracedRArray{T, 0 }, zero (T))
46
+ Base. one (:: AnyTracedRScalar{T} ) where {T} = promote_to (TracedRArray{T, 0 }, one (T))
47
+
48
+ function Base. convert (:: Type{<:AnyTracedRScalar{T}} , x:: Number ) where {T}
49
+ return promote_to (TracedRArray{T, 0 }, T (x))
50
+ end
51
+
45
52
function Base. getindex (a:: TracedRArray{T,N} , index:: Vararg{Int,N} ) where {T,N}
46
53
@warn (
47
54
""" Performing scalar indexing on task $(current_task ()) .
@@ -514,12 +521,11 @@ for (jlop, hloop, hlocomp, merge) in (
514
521
(:(Base.:(< )), :compare , " LT" , nothing ),
515
522
)
516
523
@eval begin
517
- function elem_apply (
518
- :: typeof ($ jlop),
519
- @nospecialize (lhs:: TracedRArray{T,N} ),
520
- @nospecialize (rhs:: TracedRArray{T,N} )
521
- ) where {T,N}
522
- return TracedRArray {T,N} (
524
+ function $ (jlop)(
525
+ @nospecialize (lhs:: TracedRArray{T,0} ),
526
+ @nospecialize (rhs:: TracedRArray{T,0} )
527
+ ) where {T}
528
+ return TracedRArray {Bool,0} (
523
529
(),
524
530
MLIR. IR. result (
525
531
MLIR. Dialects. stablehlo.$ hloop (
@@ -535,50 +541,26 @@ for (jlop, hloop, hlocomp, merge) in (
535
541
)
536
542
end
537
543
538
- function elem_apply (
539
- fn:: typeof ($ jlop), @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs)
540
- ) where {T,N}
541
- return elem_apply (fn, lhs, promote_to (lhs, rhs))
542
- end
543
-
544
- function elem_apply (
545
- :: typeof ($ jlop), @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,N} )
546
- ) where {T,N}
547
- return elem_apply (fn, promote_to (rhs, lhs), rhs)
548
- end
549
-
550
- function $jlop (
551
- @nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs)
552
- ) where {T,N}
553
- return $ jlop (lhs, promote_to (lhs, rhs))
544
+ function $ (jlop)(
545
+ @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs)
546
+ ) where {T}
547
+ return $ (jlop)(lhs, promote_to (lhs, rhs))
554
548
end
555
549
556
- function $jlop (
557
- @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,N } )
558
- ) where {T,N }
559
- return $ jlop (promote_to (rhs, lhs), rhs)
550
+ function $ ( jlop) (
551
+ @nospecialize (lhs), @nospecialize (rhs:: TracedRArray{T,0 } )
552
+ ) where {T}
553
+ return $ ( jlop) (promote_to (rhs, lhs), rhs)
560
554
end
561
555
end
562
556
563
- if merge != nothing
557
+ if merge != = nothing
564
558
@eval begin
565
559
function $jlop (
566
560
@nospecialize (lhs:: TracedRArray{T,N} ), @nospecialize (rhs:: TracedRArray{T,N} )
567
561
) where {T,N}
568
- elems = elem_apply ($ jlop, lhs, rhs)
569
- if N == 0
570
- elems
571
- else
572
- $ merge (elems)
573
- end
574
- end
575
- end
576
- else
577
- @eval begin
578
- function $jlop (
579
- @nospecialize (lhs:: TracedRArray{T,0} ), @nospecialize (rhs:: TracedRArray{T,0} )
580
- ) where {T}
581
- return elem_apply ($ jlop, lhs, rhs)
562
+ elems = $ (jlop). (lhs, rhs)
563
+ return N == 0 ? elems : $ (merge)(elems)
582
564
end
583
565
end
584
566
end
@@ -644,7 +626,7 @@ function Base.mapreduce(
644
626
645
627
init = [broadcast_to_size (init, ()). mlir_data]
646
628
647
- inp = [elem_apply (f, A). mlir_data]
629
+ inp = [broadcast (f, A). mlir_data]
648
630
649
631
rdims = if dims == (:)
650
632
Int64[i for i in 0 : (N - 1 )]
@@ -706,7 +688,7 @@ function Base.mapreducedim!(
706
688
A:: Base.AbstractArrayOrBroadcasted ,
707
689
)
708
690
tmp = broadcast_to_size (Base. mapreduce (f, op, A; dims= 1 ), (1 , size (R)[2 : end ]. .. ))
709
- R. mlir_data = elem_apply (op, R, tmp). mlir_data
691
+ R. mlir_data = broadcast (op, R, tmp). mlir_data
710
692
return R
711
693
end
712
694
0 commit comments