Skip to content

Commit ab25f7b

Browse files
Fix erf and a few other improvements (#582)
* Fix `erf` * [NFC] Semicolon cleanup * Test proper type
1 parent a5b56dc commit ab25f7b

File tree

3 files changed

+24
-19
lines changed

3 files changed

+24
-19
lines changed

ext/SpecialFunctionsExt.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,14 @@ Metal.@device_override function SpecialFunctions.erf(x::Float32)
7171
if ix < 0x04000000 # |x|<0x1p-119
7272
return (8 * x + efx8 * x) / 8 # avoid spurious underflow
7373
end
74-
return x + efx*x;
74+
return x + efx*x
7575
end
7676
end
77+
z = x * x
78+
r = pp0 + z * (pp1 + z * pp2)
79+
s = 1.0f0 + z * (qq1 + z * (qq2 + z * qq3))
80+
y = r / s
81+
return x + x*y
7782
end
7883

7984
if ix < 0x3fa00000 # 0.84375 <= |x| < 1.25
@@ -152,15 +157,15 @@ Metal.@device_override function SpecialFunctions.erfc(x::Float32)
152157
Q = 1.0f0 + s * (qa1 + s * (qa2 + s * (qa3 + s * qa4)))
153158
if hx >= 0
154159
z = 1.0f0 - erx
155-
return z - P / Q;
160+
return z - P / Q
156161
else
157162
z = erx + P / Q
158163
return 1.0f0 + z
159164
end
160165
end
161166

162167
if ix < 0x41300000 # |x|<28
163-
x = abs(x);
168+
x = abs(x)
164169
s = 1.0f0 / (x * x)
165170
if ix < 0x4036DB6D # |x| < 1/.35 ~ 2.857143
166171
R = ra0 + s * (ra1 + s * (ra2 + s * ra3))

test/device/intrinsics/math.jl

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -237,38 +237,38 @@ end
237237
end
238238

239239
let # log1p
240-
arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
240+
arr = T.(collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)))
241241
buffer = MtlArray(arr)
242-
vec = Array(log1p.(buffer))
243-
@test vec ≈ log1p.(arr)
242+
cpures = log1p.(arr)
243+
@test Array(log1p.(buffer)) ≈ log1p.(arr)
244244
end
245245

246246
let # erf
247-
arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
247+
arr = T[-1.0, -0.5, 0.0, 1.0e-3, 1.0, 2.0, 5.5]
248248
buffer = MtlArray(arr)
249-
vec = Array(SpecialFunctions.erf.(buffer))
250-
@test vec ≈ SpecialFunctions.erf.(arr)
249+
cpures = SpecialFunctions.erf.(arr)
250+
@test Array(SpecialFunctions.erf.(buffer)) ≈ cpures broken = (T == Float16)
251251
end
252252

253253
let # erfc
254-
arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
254+
arr = T.(collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)))
255255
buffer = MtlArray(arr)
256-
vec = Array(SpecialFunctions.erfc.(buffer))
257-
@test vec ≈ SpecialFunctions.erfc.(arr)
256+
cpures = SpecialFunctions.erfc.(arr)
257+
@test Array(SpecialFunctions.erfc.(buffer)) ≈ cpures broken = (T == Float16)
258258
end
259259

260260
let # erfinv
261-
arr = collect(LinRange(-1.0f0, 1.0f0, 20))
261+
arr = T.(collect(LinRange(-1.0f0, 1.0f0, 20)))
262262
buffer = MtlArray(arr)
263-
vec = Array(SpecialFunctions.erfinv.(buffer))
264-
@test vec ≈ SpecialFunctions.erfinv.(arr)
263+
cpures = SpecialFunctions.erfinv.(arr)
264+
@test Array(SpecialFunctions.erfinv.(buffer)) ≈ cpures
265265
end
266266

267267
let # expm1
268-
arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
268+
arr = T.(collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)))
269269
buffer = MtlArray(arr)
270-
vec = Array(expm1.(buffer))
271-
@test vec ≈ expm1.(arr)
270+
cpures = expm1.(arr)
271+
@test Array(expm1.(buffer)) ≈ cpures
272272
end
273273

274274

test/mps/ndarray.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
#
22
# matrix descriptor
33
#
4-
using Metal,Test;
4+
using Metal
55
using .MPS: MPSNDArrayDescriptor, MPSDataType, lengthOfDimension, descriptor, resourceSize
66
@static if Metal.macos_version() >= v"15"
77
using .MPS: userBuffer

0 commit comments

Comments
 (0)