Skip to content

Commit e260219

Browse files
authored
override ifelse (#1577)
* override ifelse * fix
1 parent e7b0a5a commit e260219

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

ext/ReactantCUDAExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@ for jlop in (
8585
end
8686
end
8787

88+
@inline Base.ifelse(cond::Bool, a, b::CuTracedRNumber) = Base.ifelse(cond, a, b[])
89+
@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b) = Base.ifelse(cond, a[], b)
90+
@inline Base.ifelse(cond::Bool, a::CuTracedRNumber, b::CuTracedRNumber) = Base.ifelse(cond, a[], b[])
91+
@inline Base.ifelse(cond::CuTracedRNumber, a, b) = Base.ifelse(cond[], a, b)
92+
@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b) = Base.ifelse(cond[], a[], b)
93+
@inline Base.ifelse(cond::CuTracedRNumber, a, b::CuTracedRNumber) = Base.ifelse(cond[], a, b[])
94+
@inline Base.ifelse(cond::CuTracedRNumber, a::CuTracedRNumber, b::CuTracedRNumber) = Base.ifelse(cond[], a[], b[])
95+
8896
Base.@constprop :aggressive @inline Base.:^(
8997
a::CuTracedRNumber{T,A}, b::Integer
9098
) where {T,A} = ^(a[], b)

0 commit comments

Comments
 (0)