|
1 | | -using SpecialFunctions |
2 | 1 | using Metal: metal_support |
| 2 | +using Random |
| 3 | +using SpecialFunctions |
3 | 4 |
|
4 | 5 | @testset "arguments" begin |
5 | 6 | @on_device dispatch_quadgroups_per_threadgroup() |
@@ -103,71 +104,219 @@ end |
103 | 104 |
|
104 | 105 | ############################################################################################ |
105 | 106 |
|
| 107 | +MATH_INTR_FUNCS_1_ARG = [ |
| 108 | + abs, |
| 109 | + acos, |
| 110 | + # acosh, # not defined for values < 1, tested separately |
| 111 | + asin, |
| 112 | + asinh, |
| 113 | + atan, |
| 114 | + atanh, |
| 115 | + ceil, |
| 116 | + cos, |
| 117 | + cosh, |
| 118 | + cospi, |
| 119 | + exp, |
| 120 | + exp2, |
| 121 | + exp10, |
| 122 | + floor, |
| 123 | + Metal.fract, |
| 124 | + log, |
| 125 | + log2, |
| 126 | + log10, |
| 127 | + # Metal.rint, # not sure what the behaviour actually is |
| 128 | + round, |
| 129 | + Metal.rsqrt, |
| 130 | + sin, |
| 131 | + sinh, |
| 132 | + sinpi, |
| 133 | + sqrt, |
| 134 | + tan, |
| 135 | + tanh, |
| 136 | + tanpi, |
| 137 | + trunc, |
| 138 | +] |
| 139 | +Metal.rsqrt(x::Float16) = 1 / sqrt(x) |
| 140 | +Metal.rsqrt(x::Float32) = 1 / sqrt(x) |
| 141 | +Metal.fract(x::Float16) = mod(x, 1) |
| 142 | +Metal.fract(x::Float32) = mod(x, 1) |
| 143 | + |
| 144 | +MATH_INTR_FUNCS_2_ARG = [ |
| 145 | + min, |
| 146 | + max, |
| 147 | + pow, # :(^), |
| 148 | + Metal.powr, |
| 149 | + hypot, |
| 150 | +] |
| 151 | + |
| 152 | +MATH_INTR_FUNCS_3_ARG = [ |
| 153 | + fma, |
| 154 | +] |
| 155 | + |
106 | 156 | @testset "math" begin |
107 | | - a = ones(Float32,1) |
108 | | - a .* Float32(3.14) |
109 | | - bufferA = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a) |
110 | | - vecA = unsafe_wrap(Vector{Float32}, pointer(bufferA), 1) |
| 157 | +# 1-arg functions |
| 158 | +@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16) |
| 159 | + cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt] |
| 160 | + rand(T, 4) |
| 161 | + else |
| 162 | + T[0.0, -0.0, rand(T), -rand(T)] |
| 163 | + end |
| 164 | + |
| 165 | + mtlarr = MtlArray(cpuarr) |
111 | 166 |
|
112 | | - function intr_test(arr) |
| 167 | + mtlout = fill!(similar(mtlarr), 0) |
| 168 | + |
| 169 | + function kernel(res, arr) |
113 | 170 | idx = thread_position_in_grid_1d() |
114 | | - arr[idx] = cos(arr[idx]) |
| 171 | + res[idx] = fun(arr[idx]) |
115 | 172 | return nothing |
116 | 173 | end |
117 | | - @metal intr_test(bufferA) |
118 | | - synchronize() |
119 | | - @test vecA ≈ cos.(a) |
| 174 | + Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr) |
| 175 | + @eval @test Array($mtlout) ≈ $fun.($cpuarr) |
| 176 | +end |
| 177 | +# 2-arg functions |
| 178 | +@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG |
| 179 | + N = 4 |
| 180 | + arr1 = randn(T, N) |
| 181 | + arr2 = randn(T, N) |
| 182 | + mtlarr1 = MtlArray(arr1) |
| 183 | + mtlarr2 = MtlArray(arr2) |
| 184 | + |
| 185 | + mtlout = fill!(similar(mtlarr1), 0) |
120 | 186 |
|
121 | | - function intr_test2(arr) |
| 187 | + function kernel(res, x, y) |
122 | 188 | idx = thread_position_in_grid_1d() |
123 | | - arr[idx] = Metal.rsqrt(arr[idx]) |
| 189 | + res[idx] = fun(x[idx], y[idx]) |
124 | 190 | return nothing |
125 | 191 | end |
126 | | - @metal intr_test2(bufferA) |
127 | | - synchronize() |
| 192 | + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) |
| 193 | + @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2) |
| 194 | +end |
| 195 | +# 3-arg functions |
| 196 | +@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG |
| 197 | + N = 4 |
| 198 | + arr1 = randn(T, N) |
| 199 | + arr2 = randn(T, N) |
| 200 | + arr3 = randn(T, N) |
| 201 | + |
| 202 | + mtlarr1 = MtlArray(arr1) |
| 203 | + mtlarr2 = MtlArray(arr2) |
| 204 | + mtlarr3 = MtlArray(arr3) |
128 | 205 |
|
129 | | - bufferB = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a) |
130 | | - vecB = unsafe_wrap(Vector{Float32}, pointer(bufferB), 1) |
| 206 | + mtlout = fill!(similar(mtlarr1), 0) |
131 | 207 |
|
132 | | - function intr_test3(arr_sin, arr_cos) |
| 208 | + function kernel(res, x, y, z) |
133 | 209 | idx = thread_position_in_grid_1d() |
134 | | - s, c = sincos(arr_cos[idx]) |
135 | | - arr_sin[idx] = s |
136 | | - arr_cos[idx] = c |
| 210 | + res[idx] = fun(x[idx], y[idx], z[idx]) |
137 | 211 | return nothing |
138 | 212 | end |
| 213 | + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3) |
| 214 | + @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3) |
| 215 | +end |
| 216 | +end |
139 | 217 |
|
140 | | - @metal intr_test3(bufferA, bufferB) |
141 | | - synchronize() |
142 | | - @test vecA ≈ sin.(a) |
143 | | - @test vecB ≈ cos.(a) |
| 218 | +@testset "unique math" begin |
| 219 | +@testset "$T" for T in (Float32, Float16) |
| 220 | + let # acosh |
| 221 | + arr = T[0, rand(T, 3)...] .+ T(1) |
| 222 | + buffer = MtlArray(arr) |
| 223 | + vec = acosh.(buffer) |
| 224 | + @test Array(vec) ≈ acosh.(arr) |
| 225 | + end |
144 | 226 |
|
145 | | - b = collect(LinRange(nextfloat(-1f0), 10f0, 20)) |
146 | | - bufferC = MtlArray(b) |
147 | | - vecC = Array(log1p.(bufferC)) |
148 | | - @test vecC ≈ log1p.(b) |
| 227 | + let # sincos |
| 228 | + N = 4 |
| 229 | + arr = rand(T, N) |
| 230 | + bufferA = MtlArray(arr) |
| 231 | + bufferB = MtlArray(arr) |
| 232 | + function intr_test3(arr_sin, arr_cos) |
| 233 | + idx = thread_position_in_grid_1d() |
| 234 | + sinres, cosres = sincos(arr_cos[idx]) |
| 235 | + arr_sin[idx] = sinres |
| 236 | + arr_cos[idx] = cosres |
| 237 | + return nothing |
| 238 | + end |
| 239 | + # Broken with Float16 |
| 240 | + if T == Float16 |
| 241 | + @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB) |
| 242 | + else |
| 243 | + Metal.@sync @metal threads = N intr_test3(bufferA, bufferB) |
| 244 | + @test Array(bufferA) ≈ sin.(arr) |
| 245 | + @test Array(bufferB) ≈ cos.(arr) |
| 246 | + end |
| 247 | + end |
149 | 248 |
|
| 249 | + let #pow |
| 250 | + N = 4 |
| 251 | + arr1 = rand(T, N) |
| 252 | + arr2 = rand(T, N) |
| 253 | + mtlarr1 = MtlArray(arr1) |
| 254 | + mtlarr2 = MtlArray(arr2) |
150 | 255 |
|
151 | | - d = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) |
152 | | - bufferD = MtlArray(d) |
153 | | - vecD = Array(SpecialFunctions.erf.(bufferD)) |
154 | | - @test vecD ≈ SpecialFunctions.erf.(d) |
| 256 | + mtlout = fill!(similar(mtlarr1), 0) |
155 | 257 |
|
| 258 | + function kernel(res, x, y) |
| 259 | + idx = thread_position_in_grid_1d() |
| 260 | + res[idx] = x[idx]^y[idx] |
| 261 | + return nothing |
| 262 | + end |
| 263 | + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) |
| 264 | + @test Array(mtlout) ≈ arr1 .^ arr2 |
| 265 | + end |
| 266 | + |
| 267 | + let #powr |
| 268 | + N = 4 |
| 269 | + arr1 = rand(T, N) |
| 270 | + arr2 = rand(T, N) |
| 271 | + mtlarr1 = MtlArray(arr1) |
| 272 | + mtlarr2 = MtlArray(arr2) |
| 273 | + |
| 274 | + mtlout = fill!(similar(mtlarr1), 0) |
| 275 | + |
| 276 | + function kernel(res, x, y) |
| 277 | + idx = thread_position_in_grid_1d() |
| 278 | + res[idx] = Metal.powr(x[idx], y[idx]) |
| 279 | + return nothing |
| 280 | + end |
| 281 | + Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) |
| 282 | + @test Array(mtlout) ≈ arr1 .^ arr2 |
| 283 | + end |
156 | 284 |
|
157 | | - e = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) |
158 | | - bufferE = MtlArray(e) |
159 | | - vecE = Array(SpecialFunctions.erfc.(bufferE)) |
160 | | - @test vecE ≈ SpecialFunctions.erfc.(e) |
| 285 | + let # log1p |
| 286 | + arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)) |
| 287 | + buffer = MtlArray(arr) |
| 288 | + vec = Array(log1p.(buffer)) |
| 289 | + @test vec ≈ log1p.(arr) |
| 290 | + end |
161 | 291 |
|
162 | | - f = collect(LinRange(-1f0, 1f0, 20)) |
163 | | - bufferF = MtlArray(f) |
164 | | - vecF = Array(SpecialFunctions.erfinv.(bufferF)) |
165 | | - @test vecF ≈ SpecialFunctions.erfinv.(f) |
| 292 | + let # erf |
| 293 | + arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) |
| 294 | + buffer = MtlArray(arr) |
| 295 | + vec = Array(SpecialFunctions.erf.(buffer)) |
| 296 | + @test vec ≈ SpecialFunctions.erf.(arr) |
| 297 | + end |
| 298 | + |
| 299 | + let # erfc |
| 300 | + arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) |
| 301 | + buffer = MtlArray(arr) |
| 302 | + vec = Array(SpecialFunctions.erfc.(buffer)) |
| 303 | + @test vec ≈ SpecialFunctions.erfc.(arr) |
| 304 | + end |
166 | 305 |
|
167 | | - f = collect(LinRange(nextfloat(-88f0), 88f0, 100)) |
168 | | - bufferF = MtlArray(f) |
169 | | - vecF = Array(expm1.(bufferF)) |
170 | | - @test vecF ≈ expm1.(f) |
| 306 | + let # erfinv |
| 307 | + arr = collect(LinRange(-1.0f0, 1.0f0, 20)) |
| 308 | + buffer = MtlArray(arr) |
| 309 | + vec = Array(SpecialFunctions.erfinv.(buffer)) |
| 310 | + @test vec ≈ SpecialFunctions.erfinv.(arr) |
| 311 | + end |
| 312 | + |
| 313 | + let # expm1 |
| 314 | + arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)) |
| 315 | + buffer = MtlArray(arr) |
| 316 | + vec = Array(expm1.(buffer)) |
| 317 | + @test vec ≈ expm1.(arr) |
| 318 | + end |
| 319 | +end |
171 | 320 | end |
172 | 321 |
|
173 | 322 | ############################################################################################ |
|
0 commit comments