Skip to content

Commit c9ca274

Browse files
committed
fix: broadcasting of comparison
1 parent d85aca5 commit c9ca274

File tree

1 file changed

+25
-43
lines changed

1 file changed

+25
-43
lines changed

src/TracedRArray.jl

Lines changed: 25 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,13 @@ end
4242

4343
Base.getindex(a::AnyTracedRScalar{T}) where {T} = a
4444

45+
Base.zero(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T, 0}, zero(T))
46+
Base.one(::AnyTracedRScalar{T}) where {T} = promote_to(TracedRArray{T, 0}, one(T))
47+
48+
function Base.convert(::Type{<:AnyTracedRScalar{T}}, x::Number) where {T}
49+
return promote_to(TracedRArray{T, 0}, T(x))
50+
end
51+
4552
function Base.getindex(a::TracedRArray{T,N}, index::Vararg{Int,N}) where {T,N}
4653
@warn(
4754
"""Performing scalar indexing on task $(current_task()).
@@ -514,12 +521,11 @@ for (jlop, hloop, hlocomp, merge) in (
514521
(:(Base.:(<)), :compare, "LT", nothing),
515522
)
516523
@eval begin
517-
function elem_apply(
518-
::typeof($jlop),
519-
@nospecialize(lhs::TracedRArray{T,N}),
520-
@nospecialize(rhs::TracedRArray{T,N})
521-
) where {T,N}
522-
return TracedRArray{T,N}(
524+
function $(jlop)(
525+
@nospecialize(lhs::TracedRArray{T,0}),
526+
@nospecialize(rhs::TracedRArray{T,0})
527+
) where {T}
528+
return TracedRArray{Bool,0}(
523529
(),
524530
MLIR.IR.result(
525531
MLIR.Dialects.stablehlo.$hloop(
@@ -535,50 +541,26 @@ for (jlop, hloop, hlocomp, merge) in (
535541
)
536542
end
537543

538-
function elem_apply(
539-
fn::typeof($jlop), @nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs)
540-
) where {T,N}
541-
return elem_apply(fn, lhs, promote_to(lhs, rhs))
542-
end
543-
544-
function elem_apply(
545-
::typeof($jlop), @nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,N})
546-
) where {T,N}
547-
return elem_apply(fn, promote_to(rhs, lhs), rhs)
548-
end
549-
550-
function $jlop(
551-
@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs)
552-
) where {T,N}
553-
return $jlop(lhs, promote_to(lhs, rhs))
544+
function $(jlop)(
545+
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs)
546+
) where {T}
547+
return $(jlop)(lhs, promote_to(lhs, rhs))
554548
end
555549

556-
function $jlop(
557-
@nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,N})
558-
) where {T,N}
559-
return $jlop(promote_to(rhs, lhs), rhs)
550+
function $(jlop)(
551+
@nospecialize(lhs), @nospecialize(rhs::TracedRArray{T,0})
552+
) where {T}
553+
return $(jlop)(promote_to(rhs, lhs), rhs)
560554
end
561555
end
562556

563-
if merge != nothing
557+
if merge !== nothing
564558
@eval begin
565559
function $jlop(
566560
@nospecialize(lhs::TracedRArray{T,N}), @nospecialize(rhs::TracedRArray{T,N})
567561
) where {T,N}
568-
elems = elem_apply($jlop, lhs, rhs)
569-
if N == 0
570-
elems
571-
else
572-
$merge(elems)
573-
end
574-
end
575-
end
576-
else
577-
@eval begin
578-
function $jlop(
579-
@nospecialize(lhs::TracedRArray{T,0}), @nospecialize(rhs::TracedRArray{T,0})
580-
) where {T}
581-
return elem_apply($jlop, lhs, rhs)
562+
elems = $(jlop).(lhs, rhs)
563+
return N == 0 ? elems : $(merge)(elems)
582564
end
583565
end
584566
end
@@ -644,7 +626,7 @@ function Base.mapreduce(
644626

645627
init = [broadcast_to_size(init, ()).mlir_data]
646628

647-
inp = [elem_apply(f, A).mlir_data]
629+
inp = [broadcast(f, A).mlir_data]
648630

649631
rdims = if dims == (:)
650632
Int64[i for i in 0:(N - 1)]
@@ -706,7 +688,7 @@ function Base.mapreducedim!(
706688
A::Base.AbstractArrayOrBroadcasted,
707689
)
708690
tmp = broadcast_to_size(Base.mapreduce(f, op, A; dims=1), (1, size(R)[2:end]...))
709-
R.mlir_data = elem_apply(op, R, tmp).mlir_data
691+
R.mlir_data = broadcast(op, R, tmp).mlir_data
710692
return R
711693
end
712694

0 commit comments

Comments
 (0)