|
# Device-side Float32 power: both operands are handed unchanged to the
# `air.pow.f32` intrinsic, and its Float32 (`Cfloat`) result is returned.
@device_override function Base.:(^)(x::Float32, y::Float32)
    return ccall("extern air.pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
end
# Device-side Float16 power: both operands are handed unchanged to the
# `air.pow.f16` intrinsic, and its Float16 result is returned.
@device_override function Base.:(^)(x::Float16, y::Float16)
    return ccall("extern air.pow.f16", llvmcall, Float16, (Float16, Float16), x, y)
end
232 | 232 |
|
# Integer powers of Float32 values, kept entirely in Float32 so that `pow`
# never has to go through Float64.
#
# Exponents -1, 0, 1, 2 and 3 are expanded by hand. Every other exponent is
# converted to Float32 and dispatched to the `x^y` method for two Float32
# operands. Float32 represents integers exactly only up to 2^24 in magnitude,
# so larger exponents are rounded by that conversion.
@device_override @inline function Base.:(^)(x::Float32, y::Integer)
    if y == -1
        return inv(x)
    elseif y == 0
        return one(x)
    elseif y == 1
        return x
    elseif y == 2
        return x * x
    elseif y == 3
        return x * x * x
    end
    return x^Float32(y)
end
# Integer powers of Float16 values, computed without going through Float64.
#
# Exponents -1, 0, 1, 2 and 3 are expanded by hand and stay in Float16.
# Every other exponent takes the general path at the bottom. The result is
# always a Float16.
@device_override @inline function Base.:(^)(x::Float16, y::Integer)
    y == -1 && return inv(x)
    y == 0 && return one(x)
    y == 1 && return x
    y == 2 && return x * x
    y == 3 && return x * x * x
    # Do not convert the exponent to Float16: that type represents integers
    # exactly only up to 2048 and overflows to Inf past 65504, so a large
    # exponent would be rounded (losing its parity, and with it the sign of the
    # result for negative bases) or turned into Inf. Float32 keeps exponents
    # exact up to 2^24 in magnitude, so evaluate there and round the result
    # back to Float16.
    Float16(Float32(x)^Float32(y))
end
| 250 | + |
# Float32 `powr_fast`: passes both operands to the `air.fast_powr.f32` intrinsic.
# NOTE(review): Metal's `powr` family presumably expects a non-negative base —
# confirm against the Metal Shading Language specification.
@device_function powr_fast(x::Float32, y::Float32) =
    ccall("extern air.fast_powr.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
# Float32 `powr`: passes both operands to the `air.powr.f32` intrinsic.
# NOTE(review): Metal's `powr` presumably expects a non-negative base —
# confirm against the Metal Shading Language specification.
@device_function powr(x::Float32, y::Float32) =
    ccall("extern air.powr.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
# Float16 `powr`: passes both operands to the `air.powr.f16` intrinsic.
# NOTE(review): Metal's `powr` presumably expects a non-negative base —
# confirm against the Metal Shading Language specification.
@device_function powr(x::Float16, y::Float16) =
    ccall("extern air.powr.f16", llvmcall, Float16, (Float16, Float16), x, y)
|
0 commit comments