Skip to content

Commit b1eb6aa

Browse files
Support pow with Int exponent (#557)
1 parent c5b425d commit b1eb6aa

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

src/device/intrinsics/math.jl

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -230,6 +230,24 @@ end
230230
# Device-side `x^y` for two Float32 values: calls the `air.pow.f32` intrinsic.
@device_override function Base.:(^)(x::Float32, y::Float32)
    return ccall("extern air.pow.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
end
231231
# Device-side `x^y` for two Float16 values: calls the `air.pow.f16` intrinsic.
@device_override function Base.:(^)(x::Float16, y::Float16)
    return ccall("extern air.pow.f16", llvmcall, Float16, (Float16, Float16), x, y)
end
232232

233+
# Avoid use of Float64 in `pow`
234+
# Float32 raised to an Integer power.
# Exponents -1, 0, 1, 2 and 3 are expanded into `inv`, `one` or plain
# multiplications; any other exponent is converted to Float32 and handed to
# the Float32/Float32 `^` method.
@device_override @inline function Base.:(^)(x::Float32, y::Integer)
    if y == -1
        return inv(x)
    elseif y == 0
        return one(x)
    elseif y == 1
        return x
    elseif y == 2
        return x * x
    elseif y == 3
        return x * x * x
    end
    return x^Float32(y)
end
242+
# Float16 raised to an Integer power.
# The exponents -1 through 3 are answered directly (checked in that order);
# everything else goes through the Float16/Float16 `^` method after converting
# the exponent to Float16.
@device_override @inline function Base.:(^)(x::Float16, y::Integer)
    return y == -1 ? inv(x) :
           y == 0 ? one(x) :
           y == 1 ? x :
           y == 2 ? x * x :
           y == 3 ? x * x * x :
           x^Float16(y)
end
250+
233251
# Float32 wrapper around the `air.fast_powr.f32` intrinsic.
# NOTE(review): presumably the fast-math variant of `powr` — confirm against
# the Metal Shading Language specification.
@device_function function powr_fast(x::Float32, y::Float32)
    return ccall("extern air.fast_powr.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
end
234252
# Float32 wrapper around the `air.powr.f32` intrinsic.
@device_function function powr(x::Float32, y::Float32)
    return ccall("extern air.powr.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
end
235253
# Float16 wrapper around the `air.powr.f16` intrinsic.
@device_function function powr(x::Float16, y::Float16)
    return ccall("extern air.powr.f16", llvmcall, Float16, (Float16, Float16), x, y)
end

test/device/intrinsics.jl

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -300,6 +300,24 @@ end
300300
@test Array(mtlout) arr1 .^ arr2
301301
end
302302

303+
let # pow with Integer exponent (Issue 552)
    N = 4
    # Exponents -1, 0, 1, 2 and 3 are always present, followed by N random
    # exponents drawn from -10:10.
    arr2 = [-1, 0, 1, 2, 3, rand(-10:10, N)...]
    # One random base of element type `T` per exponent.
    arr1 = rand(T, length(arr2))
    mtlarr1 = MtlArray(arr1)
    mtlarr2 = MtlArray(arr2)

    # Output buffer with the same shape and element type as the bases.
    mtlout = fill!(similar(mtlarr1), 0)

    # Each thread raises one base to its matching integer exponent.
    function kernel(res, x, y)
        idx = thread_position_in_grid_1d()
        res[idx] = x[idx]^y[idx]
        return nothing
    end
    Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr1, mtlarr2)
    # Compare the device result with the host broadcast. `≈` is required:
    # without a comparison operator `@test` receives two separate expressions.
    @test Array(mtlout) ≈ arr1 .^ arr2
end
320+
303321
let #powr
304322
N = 4
305323
arr1 = rand(T, N)

0 commit comments

Comments
 (0)