Skip to content

Commit d8a1c23

Browse files
committed
add new overdub for unsafe_getindex to avoid allocating error message
1 parent ace8891 commit d8a1c23

File tree

2 files changed

+31
-19
lines changed

2 files changed

+31
-19
lines changed

lib/CUDAKernels/src/CUDAKernels.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -309,23 +309,6 @@ end
309309
# Under `CUDACtx`, calls to `SpecialFunctions.erf` on Float32/Float64 are
# redirected to `CUDA.erf`.
@inline function Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erf), x::Union{Float32, Float64})
    return CUDA.erf(x)
end
310310
# Under `CUDACtx`, calls to `SpecialFunctions.erfc` on Float32/Float64 are
# redirected to `CUDA.erfc`.
@inline function Cassette.overdub(::CUDACtx, ::typeof(SpecialFunctions.erfc), x::Union{Float32, Float64})
    return CUDA.erfc(x)
end
311311

312-
@inline function Cassette.overdub(::CUDACtx, ::typeof(exponent), x::Union{Float32, Float64})
313-
T = typeof(x)
314-
xs = reinterpret(Unsigned, x) & ~Base.sign_mask(T)
315-
if xs >= Base.exponent_mask(T)
316-
throw(DomainError(x, "Cannot be Nan of Inf."))
317-
end
318-
k = Int(xs >> Base.significand_bits(T))
319-
if k == 0 # x is subnormal
320-
if xs == 0
321-
throw(DomainError(x, "Cannot be subnormal converted to 0."))
322-
end
323-
m = Base.leading_zeros(xs) - Base.exponent_bits(T)
324-
k = 1 - m
325-
end
326-
return k - Base.exponent_bias(T)
327-
end
328-
329312
@static if Base.isbindingresolved(CUDA, :emit_shmem) && Base.isdefined(CUDA, :emit_shmem)
330313
const emit_shmem = CUDA.emit_shmem
331314
else

src/compiler.jl

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ function generate_overdubs(mod, Ctx)
4646
# Subtraction of two same-typed Float32/Float64 values under this context is
# forwarded to `sub_float_contract`.
@inline function Cassette.overdub(::$Ctx, ::typeof(-), a::T, b::T) where T<:Union{Float32, Float64}
    return sub_float_contract(a, b)
end
4747
# Multiplication of two same-typed Float32/Float64 values under this context is
# forwarded to `mul_float_contract`.
@inline function Cassette.overdub(::$Ctx, ::typeof(*), a::T, b::T) where T<:Union{Float32, Float64}
    return mul_float_contract(a, b)
end
4848

49+
# `Base.literal_pow` is invoked directly with the original arguments instead of
# being traced through the context.
# NOTE(review): the `f::F where F` annotation presumably forces specialization
# on the function argument — confirm.
@inline function Cassette.overdub(::$Ctx, ::typeof(Base.literal_pow), f::F, x, p) where F
    return Base.literal_pow(f, x, p)
end
50+
4951
function Cassette.overdub(::$Ctx, ::typeof(:), start::T, step::T, stop::T) where T<:Union{Float16,Float32,Float64}
5052
lf = (stop-start)/step
5153
if lf < 0
@@ -61,15 +63,42 @@ function generate_overdubs(mod, Ctx)
6163
Base.steprangelen_hp(T, start, step, 0, len, 1)
6264
end
6365

64-
@inline Cassette.overdub(::$Ctx, ::typeof(Base.literal_pow), f::F, x, p) where F = Base.literal_pow(f, x, p)
65-
6666
if VERSION >= v"1.5"
6767
# Replacement for `Base.Checked.throw_overflowerr_binaryop`: throws an
# `OverflowError` with a fixed message; `op`, `x` and `y` are not used.
@inline Cassette.overdub(::$Ctx, ::typeof(Base.Checked.throw_overflowerr_binaryop), op, x, y) =
    throw(OverflowError("checked arithmetic: cannot compute"))
70+
7071
# Replacement for `Base.Checked.throw_overflowerr_negation`: throws an
# `OverflowError` with a fixed message; `x` is not used.
@inline Cassette.overdub(::$Ctx, ::typeof(Base.Checked.throw_overflowerr_negation), x) =
    throw(OverflowError("checked arithmetic: cannot compute -x"))
74+
75+
# Overdub of `exponent` for Float32/Float64, computed directly from the bit
# pattern of `x`.
#
# Returns `k - Base.exponent_bias(T)`, where `k` is the exponent field of `x`
# (adjusted by the leading-zero count when that field is zero).
# Throws `DomainError` when the magnitude bits are >= the exponent mask, and
# when all magnitude bits are zero.
@inline function Cassette.overdub(::$Ctx, ::typeof(exponent), x::Union{Float32, Float64})
    T = typeof(x)
    # Mask off the sign bit so the checks below only look at the magnitude.
    xs = reinterpret(Unsigned, x) & ~Base.sign_mask(T)
    if xs >= Base.exponent_mask(T)
        # Message fixed: previously read "Cannot be Nan of Inf.".
        throw(DomainError(x, "Cannot be NaN or Inf."))
    end
    k = Int(xs >> Base.significand_bits(T))
    if k == 0 # x is subnormal
        if xs == 0
            # All magnitude bits are zero.
            throw(DomainError(x, "Cannot be subnormal converted to 0."))
        end
        # Exponent field is zero: derive the exponent from the leading zeros.
        m = Base.leading_zeros(xs) - Base.exponent_bits(T)
        k = 1 - m
    end
    return k - Base.exponent_bias(T)
end
91+
end
92+
93+
if VERSION >= v"1.6"
94+
# Overdub of `Base._unsafe_getindex`: builds the destination array with
# `similar`, checks its size against the index shape, and fills it via
# `Base._unsafe_getindex!`.
#
# On a size mismatch it throws a `DimensionMismatch` with a constant message.
#
# `index_shape`, `unsafe_length` and `_unsafe_getindex!` are internal,
# non-exported Base functions, so they are qualified with `Base.` to resolve
# from whichever module these overdubs are generated in.
@inline function Cassette.overdub(::$Ctx, ::typeof(Base._unsafe_getindex),
        ::IndexStyle, A::AbstractArray, I::Vararg{Union{Real, AbstractArray}, N}) where N
    shape = Base.index_shape(I...)
    dest = similar(A, shape)
    map(Base.unsafe_length, axes(dest)) == map(Base.unsafe_length, shape) || throw(DimensionMismatch("output array is the wrong size"))
    Base._unsafe_getindex!(dest, A, I...)
    return dest
end
73102
end
74103
end
75104
end

0 commit comments

Comments
 (0)