Skip to content

Commit 839ff8d

Browse files
authored
refactor: delegate simple opts to mlir/base julia (#1298)
1 parent 2bdedbd commit 839ff8d

File tree

4 files changed

+27
-40
lines changed

4 files changed

+27
-40
lines changed

src/Compiler.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ function optimization_passes(;
728728
"concat_reshape_reduce",
729729
"concat_elementwise",
730730
"reduce_reduce",
731+
"conj_real",
731732
# TODO we want to enable but may cause an infinite compile time
732733
# "concat_to_onedim_dusslice",
733734
]

src/Ops.jl

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -304,7 +304,7 @@ end
304304

305305
@noinline function conj(
306306
x::TracedRArray{T,N}; location=mlir_stacktrace("conj", @__FILE__, @__LINE__)
307-
) where {T<:Complex,N}
307+
) where {T,N}
308308
res = MLIR.IR.result(
309309
chlo.conj(x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location)
310310
)
@@ -313,7 +313,7 @@ end
313313

314314
@noinline function conj(
315315
x::TracedRNumber{T}; location=mlir_stacktrace("conj", @__FILE__, @__LINE__)
316-
) where {T<:Complex}
316+
) where {T}
317317
res = MLIR.IR.result(
318318
chlo.conj(x.mlir_data; result=mlir_type(TracedRArray{T,0}, ()), location)
319319
)
@@ -597,39 +597,47 @@ end
597597
end
598598

599599
@noinline function real(
600-
x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("real", @__FILE__, @__LINE__)
600+
x::TracedRArray{T,N}; location=mlir_stacktrace("real", @__FILE__, @__LINE__)
601601
) where {T,N}
602602
res = MLIR.IR.result(
603-
stablehlo.real(x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location)
603+
stablehlo.real(
604+
x.mlir_data; result=mlir_type(TracedRArray{Base.real(T),N}, size(x)), location
605+
),
604606
)
605-
return TracedRArray{T,N}((), res, size(x))
607+
return TracedRArray{Base.real(T),N}((), res, size(x))
606608
end
607609

608610
@noinline function real(
609-
x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("real", @__FILE__, @__LINE__)
611+
x::TracedRNumber{T}; location=mlir_stacktrace("real", @__FILE__, @__LINE__)
610612
) where {T}
611613
res = MLIR.IR.result(
612-
stablehlo.real(x.mlir_data; result=mlir_type(TracedRArray{T,0}, ()), location)
614+
stablehlo.real(
615+
x.mlir_data; result=mlir_type(TracedRArray{Base.real(T),0}, ()), location
616+
),
613617
)
614-
return TracedRNumber{T}((), res)
618+
return TracedRNumber{Base.real(T)}((), res)
615619
end
616620

617621
@noinline function imag(
618-
x::TracedRArray{Complex{T},N}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__)
622+
x::TracedRArray{T,N}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__)
619623
) where {T,N}
620624
res = MLIR.IR.result(
621-
stablehlo.imag(x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location)
625+
stablehlo.imag(
626+
x.mlir_data; result=mlir_type(TracedRArray{Base.real(T),N}, size(x)), location
627+
),
622628
)
623-
return TracedRArray{T,N}((), res, size(x))
629+
return TracedRArray{Base.real(T),N}((), res, size(x))
624630
end
625631

626632
@noinline function imag(
627-
x::TracedRNumber{Complex{T}}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__)
633+
x::TracedRNumber{T}; location=mlir_stacktrace("imag", @__FILE__, @__LINE__)
628634
) where {T}
629635
res = MLIR.IR.result(
630-
stablehlo.imag(x.mlir_data; result=mlir_type(TracedRArray{T,0}, ()), location)
636+
stablehlo.imag(
637+
x.mlir_data; result=mlir_type(TracedRArray{Base.real(T),0}, ()), location
638+
),
631639
)
632-
return TracedRNumber{T}((), res)
640+
return TracedRNumber{Base.real(T)}((), res)
633641
end
634642

635643
function bitcast_convert(
@@ -679,7 +687,7 @@ end
679687
end
680688
elseif type == "IRFFT"
681689
@assert T <: Complex
682-
Tout = Base.real(T)
690+
Tout = Base.Base.real(T)
683691
rsize = let rsize = collect(Int64, size(x))
684692
rsize[(end - Base.length(length) + 1):end] = length
685693
Tuple(rsize)

src/TracedRArray.jl

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -511,22 +511,6 @@ function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
511511
return Ops.transpose(materialize_traced_array(A), Int64[perm...])
512512
end
513513

514-
Base.conj(A::AnyTracedRArray) = A
515-
Base.conj(A::AnyTracedRArray{<:Complex}) = Ops.conj(materialize_traced_array(A))
516-
517-
Base.conj!(A::AnyTracedRArray) = A
518-
519-
function Base.conj!(A::AnyTracedRArray{<:Complex})
520-
TracedUtils.set_mlir_data!(A, Ops.conj(materialize_traced_array(A)).mlir_data)
521-
return A
522-
end
523-
524-
Base.real(A::AnyTracedRArray) = A
525-
Base.real(A::AnyTracedRArray{<:Complex}) = Ops.real(materialize_traced_array(A))
526-
527-
Base.imag(A::AnyTracedRArray) = zero(A)
528-
Base.imag(A::AnyTracedRArray{<:Complex}) = Ops.imag(materialize_traced_array(A))
529-
530514
TracedUtils.promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs)
531515
function TracedUtils.promote_to(::TracedRArray{T,N}, rhs) where {T,N}
532516
return TracedUtils.promote_to(TracedRArray{T,N}, rhs)

src/TracedRNumber.jl

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,9 @@ for (jlop, hloop) in (
393393
(:(Base.atan), :atan),
394394
(:(Base.atanh), :atanh),
395395
(:(Base.sign), :sign),
396+
(:(Base.conj), :conj),
397+
(:(Base.real), :real),
398+
(:(Base.imag), :imag),
396399
)
397400
@eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = Ops.$(hloop)(lhs)
398401
end
@@ -404,18 +407,9 @@ end
404407

405408
Base.sincospi(x::TracedRNumber{T}) where {T} = Ops.sine(T(π) * x), Ops.cosine(T(π) * x)
406409

407-
Base.conj(x::TracedRNumber) = x
408-
Base.conj(x::TracedRNumber{<:Complex}) = Ops.conj(x)
409-
410-
Base.real(x::TracedRNumber) = x
411-
Base.real(x::TracedRNumber{<:Complex}) = Ops.real(x)
412-
413410
Base.isreal(::TracedRNumber) = false
414411
Base.isreal(::TracedRNumber{<:Real}) = true
415412

416-
Base.imag(x::TracedRNumber) = zero(x)
417-
Base.imag(x::TracedRNumber{<:Complex}) = Ops.imag(x)
418-
419413
Base.iseven(x::TracedRNumber) = iseven(real(x))
420414
function Base.iseven(x::TracedRNumber{<:Real})
421415
return iszero(

0 commit comments

Comments
 (0)