@@ -188,4 +188,41 @@ Metal.@device_override function SpecialFunctions.erfc(x::Float32)
188188 end
189189end
190190
191+ #
192+ # Approximation to the error function.
193+ # Based on code from:
194+ # https://stackoverflow.com/questions/35148198/efficient-faithfully-rounded-implementation-of-error-function-erff#answer-35148199
195+ #
196+
197+ Metal. @device_override function SpecialFunctions. erfinv (a:: Float32 )
198+ t = fma (a, 0.0f0 - a, 1.0f0 )
199+ t = log (t)
200+
201+ if abs (t) > 6.125f0
202+ p = 3.03697567f-10
203+ p = fma (p, t, 2.93243101f-8 )
204+ p = fma (p, t, 1.22150334f-6 )
205+ p = fma (p, t, 2.84108955f-5 )
206+ p = fma (p, t, 3.93552968f-4 )
207+ p = fma (p, t, 3.02698812f-3 )
208+ p = fma (p, t, 4.83185798f-3 )
209+ p = fma (p, t, - 2.64646143f-1 )
210+ p = fma (p, t, 8.40016484f-1 )
211+ return a * p
212+ else
213+ p = 5.43877832f-9
214+ p = fma (p, t, 1.43285448f-7 )
215+ p = fma (p, t, 1.22774793f-6 )
216+ p = fma (p, t, 1.12963626f-7 )
217+ p = fma (p, t, - 5.61530760f-5 )
218+ p = fma (p, t, - 1.47697632f-4 )
219+ p = fma (p, t, 2.31468678f-3 )
220+ p = fma (p, t, 1.15392581f-2 )
221+ p = fma (p, t, - 2.32015476f-1 )
222+ p = fma (p, t, 8.86226892f-1 )
223+ return a * p
224+ end
225+ end
226+
227+
191228end # module
0 commit comments