Skip to content

Commit 72430cd

Browse files
committed
clamp & sign
1 parent cf9fe18 commit 72430cd

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

src/device/intrinsics/math.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,14 @@ using Base.Math: throw_complex_domainerror
88
# - add support for vector types
99
# - consider emitting LLVM intrinsics and lowering those in the back-end
1010

11+
### Common Intrinsics
12+
# Float32 clamp lowered to the `air.fast_clamp.f32` intrinsic. Defined as a standalone
# `@device_function` rather than as an override of a `Base` method.
@device_function function clamp_fast(x::Float32, minval::Float32, maxval::Float32)
    return ccall("extern air.fast_clamp.f32", llvmcall, Cfloat,
                 (Cfloat, Cfloat, Cfloat), x, minval, maxval)
end
13+
# Device override of `Base.clamp` for Float32 operands, lowered to the `air.clamp.f32`
# intrinsic.
# NOTE(review): the intrinsic's handling of NaN inputs and of `minval > maxval` may differ
# from `Base.clamp` -- confirm against the Metal Shading Language specification.
@device_override function Base.clamp(x::Float32, minval::Float32, maxval::Float32)
    return ccall("extern air.clamp.f32", llvmcall, Cfloat,
                 (Cfloat, Cfloat, Cfloat), x, minval, maxval)
end
14+
# Device override of `Base.clamp` for Float16 operands, lowered to the `air.clamp.f16`
# intrinsic.
# NOTE(review): the intrinsic's handling of NaN inputs and of `minval > maxval` may differ
# from `Base.clamp` -- confirm against the Metal Shading Language specification.
@device_override function Base.clamp(x::Float16, minval::Float16, maxval::Float16)
    return ccall("extern air.clamp.f16", llvmcall, Float16,
                 (Float16, Float16, Float16), x, minval, maxval)
end
15+
16+
# Device override of `Base.sign` for Float32, lowered to the `air.sign.f32` intrinsic.
# The Metal intrinsic returns 0.0 for a NaN input, whereas `Base.sign(NaN32)` is NaN.
# NaN is therefore passed through explicitly so that host and device results agree and
# NaNs are not silently dropped on the GPU.
@device_override function Base.sign(x::Float32)
    s = ccall("extern air.sign.f32", llvmcall, Cfloat, (Cfloat,), x)
    # `ifelse` keeps the selection branch-free.
    return ifelse(isnan(x), x, s)
end
17+
# Device override of `Base.sign` for Float16, lowered to the `air.sign.f16` intrinsic.
# The Metal intrinsic returns 0.0 for a NaN input, whereas `Base.sign(NaN16)` is NaN.
# NaN is therefore passed through explicitly so that host and device results agree and
# NaNs are not silently dropped on the GPU.
@device_override function Base.sign(x::Float16)
    s = ccall("extern air.sign.f16", llvmcall, Float16, (Float16,), x)
    # `ifelse` keeps the selection branch-free.
    return ifelse(isnan(x), x, s)
end
18+
1119
### Floating Point Intrinsics
1220

1321
## Metal only supports single and half-precision floating-point types (and their vector counterparts)

test/device/intrinsics.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ end
107107
MATH_INTR_FUNCS_1_ARG = [
108108
# Common functions
109109
# saturate, # T saturate(T x) Clamp between 0.0 and 1.0
110-
# sign, # T sign(T x) returns 0.0 if x is NaN. Not tested because intrinsic not yet defined
110+
sign, # T sign(T x) returns 0.0 if x is NaN
111111

112112
# float math
113113
acos, # T acos(T x)
@@ -166,7 +166,6 @@ MATH_INTR_FUNCS_2_ARG = [
166166

167167
MATH_INTR_FUNCS_3_ARG = [
168168
# Common functions
169-
# clamp, # T clamp(T x, T minval, T maxval). Not tested because intrinsic not yet defined
170169
# mix, # T mix(T x, T y, T a) # x+(y-x)*a
171170
# smoothstep, # T smoothstep(T edge0, T edge1, T x)
172171
fma, # T fma(T a, T b, T c)
@@ -268,6 +267,27 @@ end
268267
end
269268
end
270269

270+
# Checks the device `clamp` against the CPU broadcast `clamp` on N random values of type
# `T` (`T` comes from the enclosing scope, which is not visible in this hunk).
let # clamp
271+
N = 4
272+
# NOTE(review): the local name `in` shadows `Base.in` inside this block.
in = randn(T, N)
273+
minval = fill(T(-0.6), N)
274+
maxval = fill(T(0.6), N)
275+
276+
mtlin = MtlArray(in)
277+
mtlminval = MtlArray(minval)
278+
mtlmaxval = MtlArray(maxval)
279+
280+
# Zero-filled output buffer with the same type and shape as the input.
mtlout = fill!(similar(mtlin), 0)
281+
282+
# Each thread clamps the single element at its own grid position.
function kernel(res, x, y, z)
283+
idx = thread_position_in_grid_1d()
284+
res[idx] = clamp(x[idx], y[idx], z[idx])
285+
return nothing
286+
end
287+
# Launch one thread per element and wait for the kernel to finish.
Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
288+
# Compared with `==` against the host result.
@test Array(mtlout) == clamp.(in, minval, maxval)
289+
end
290+
271291
let #pow
272292
N = 4
273293
arr1 = rand(T, N)

0 commit comments

Comments
 (0)