|  | 
| 1 |  | -using SpecialFunctions | 
| 2 | 1 | using Metal: metal_support | 
|  | 2 | +using Random | 
|  | 3 | +using SpecialFunctions | 
| 3 | 4 | 
 | 
| 4 | 5 | @testset "arguments" begin | 
| 5 | 6 |     @on_device dispatch_quadgroups_per_threadgroup() | 
| @@ -103,71 +104,261 @@ end | 
| 103 | 104 | 
 | 
| 104 | 105 | ############################################################################################ | 
| 105 | 106 | 
 | 
|  | 107 | +MATH_INTR_FUNCS_1_ARG = [ | 
|  | 108 | +    # Common functions | 
|  | 109 | +    # saturate, # T saturate(T x) Clamp between 0.0 and 1.0 | 
|  | 110 | +    sign, # T sign(T x) returns 0.0 if x is NaN | 
|  | 111 | + | 
|  | 112 | +    # float math | 
|  | 113 | +    acos, # T acos(T x) | 
|  | 114 | +    asin, # T asin(T x) | 
|  | 115 | +    asinh, # T asinh(T x) | 
|  | 116 | +    atan, # T atan(T x) | 
|  | 117 | +    atanh, # T atanh(T x) | 
|  | 118 | +    ceil, # T ceil(T x) | 
|  | 119 | +    cos, # T cos(T x) | 
|  | 120 | +    cosh, # T cosh(T x) | 
|  | 121 | +    cospi, # T cospi(T x) | 
|  | 122 | +    exp, # T exp(T x) | 
|  | 123 | +    exp2, # T exp2(T x) | 
|  | 124 | +    exp10, # T exp10(T x) | 
|  | 125 | +    abs, # T [f]abs(T x) | 
|  | 126 | +    floor, # T floor(T x) | 
|  | 127 | +    Metal.fract, # T fract(T x) | 
|  | 128 | +    # ilogb, # Ti ilogb(T x) | 
|  | 129 | +    log, # T log(T x) | 
|  | 130 | +    log2, # T log2(T x) | 
|  | 131 | +    log10, # T log10(T x) | 
|  | 132 | +    # Metal.rint, # T rint(T x) # TODO: Add test. Not sure what the behaviour actually is | 
|  | 133 | +    round, # T round(T x) | 
|  | 134 | +    Metal.rsqrt, # T rsqrt(T x) | 
|  | 135 | +    sin, # T sin(T x) | 
|  | 136 | +    sinh, # T sinh(T x) | 
|  | 137 | +    sinpi, # T sinpi(T x) | 
|  | 138 | +    sqrt, # T sqrt(T x) | 
|  | 139 | +    tan, # T tan(T x) | 
|  | 140 | +    tanh, # T tanh(T x) | 
|  | 141 | +    tanpi, # T tanpi(T x) | 
|  | 142 | +    trunc, # T trunc(T x) | 
|  | 143 | +] | 
|  | 144 | +Metal.rsqrt(x::Float16) = 1 / sqrt(x) | 
|  | 145 | +Metal.rsqrt(x::Float32) = 1 / sqrt(x) | 
|  | 146 | +Metal.fract(x::Float16) = mod(x, 1) | 
|  | 147 | +Metal.fract(x::Float32) = mod(x, 1) | 
|  | 148 | + | 
|  | 149 | +MATH_INTR_FUNCS_2_ARG = [ | 
|  | 150 | +    # Common functions | 
|  | 151 | +    # step, # T step(T edge, T x) Returns 0.0 if x < edge, otherwise it returns 1.0 | 
|  | 152 | + | 
|  | 153 | +    # float math | 
|  | 154 | +    atan, # T atan2(T y, T x) Compute arc tangent of y over x. | 
|  | 155 | +    # fdim, # T fdim(T x, T y) | 
|  | 156 | +    max, # T [f]max(T x, T y) | 
|  | 157 | +    min, # T [f]min(T x, T y) | 
|  | 158 | +    # fmod, # T fmod(T x, T y) | 
|  | 159 | +    # frexp, # T frexp(T x, Ti &exponent) | 
|  | 160 | +    # ldexp, # T ldexp(T x, Ti k) | 
|  | 161 | +    # modf, # T modf(T x, T &intval) | 
|  | 162 | +    # nextafter, # T nextafter(T x, T y) # Metal 3.1+ | 
|  | 163 | +    # sincos, | 
|  | 164 | +    hypot, # Not an MSL intrinsic, but tested the same way | 
|  | 165 | +] | 
|  | 166 | + | 
|  | 167 | +MATH_INTR_FUNCS_3_ARG = [ | 
|  | 168 | +    # Common functions | 
|  | 169 | +    # mix, # T mix(T x, T y, T a) # x+(y-x)*a | 
|  | 170 | +    # smoothstep, # T smoothstep(T edge0, T edge1, T x) | 
|  | 171 | +    fma, # T fma(T a, T b, T c) | 
|  | 172 | +    max, # T max3(T x, T y, T z) | 
|  | 173 | +    # median3, # T median3(T x, T y, T z) | 
|  | 174 | +    min, # T min3(T x, T y, T z) | 
|  | 175 | +] | 
|  | 176 | + | 
| 106 | 177 | @testset "math" begin | 
| 107 |  | -    a = ones(Float32,1) | 
| 108 |  | -    a .* Float32(3.14) | 
| 109 |  | -    bufferA = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a) | 
| 110 |  | -    vecA = unsafe_wrap(Vector{Float32}, pointer(bufferA), 1) | 
|  | 178 | +# 1-arg functions | 
|  | 179 | +@testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16) | 
|  | 180 | +    cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt] | 
|  | 181 | +        rand(T, 4) | 
|  | 182 | +    else | 
|  | 183 | +        T[0.0, -0.0, rand(T), -rand(T)] | 
|  | 184 | +    end | 
|  | 185 | + | 
|  | 186 | +    mtlarr = MtlArray(cpuarr) | 
|  | 187 | + | 
|  | 188 | +    mtlout = fill!(similar(mtlarr), 0) | 
| 111 | 189 | 
 | 
| 112 |  | -    function intr_test(arr) | 
|  | 190 | +    function kernel(res, arr) | 
| 113 | 191 |         idx = thread_position_in_grid_1d() | 
| 114 |  | -        arr[idx] = cos(arr[idx]) | 
|  | 192 | +        res[idx] = fun(arr[idx]) | 
| 115 | 193 |         return nothing | 
| 116 | 194 |     end | 
| 117 |  | -    @metal intr_test(bufferA) | 
| 118 |  | -    synchronize() | 
| 119 |  | -    @test vecA ≈ cos.(a) | 
|  | 195 | +    Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr) | 
|  | 196 | +    @eval @test Array($mtlout) ≈ $fun.($cpuarr) | 
|  | 197 | +end | 
|  | 198 | +# 2-arg functions | 
|  | 199 | +@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG | 
|  | 200 | +    N = 4 | 
|  | 201 | +    arr1 = randn(T, N) | 
|  | 202 | +    arr2 = randn(T, N) | 
|  | 203 | +    mtlarr1 = MtlArray(arr1) | 
|  | 204 | +    mtlarr2 = MtlArray(arr2) | 
|  | 205 | + | 
|  | 206 | +    mtlout = fill!(similar(mtlarr1), 0) | 
| 120 | 207 | 
 | 
| 121 |  | -    function intr_test2(arr) | 
|  | 208 | +    function kernel(res, x, y) | 
| 122 | 209 |         idx = thread_position_in_grid_1d() | 
| 123 |  | -        arr[idx] = Metal.rsqrt(arr[idx]) | 
|  | 210 | +        res[idx] = fun(x[idx], y[idx]) | 
| 124 | 211 |         return nothing | 
| 125 | 212 |     end | 
| 126 |  | -    @metal intr_test2(bufferA) | 
| 127 |  | -    synchronize() | 
|  | 213 | +    Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) | 
|  | 214 | +    @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2) | 
|  | 215 | +end | 
|  | 216 | +# 3-arg functions | 
|  | 217 | +@testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG | 
|  | 218 | +    N = 4 | 
|  | 219 | +    arr1 = randn(T, N) | 
|  | 220 | +    arr2 = randn(T, N) | 
|  | 221 | +    arr3 = randn(T, N) | 
| 128 | 222 | 
 | 
| 129 |  | -    bufferB = MtlArray{eltype(a),length(size(a)),Metal.SharedStorage}(a) | 
| 130 |  | -    vecB = unsafe_wrap(Vector{Float32}, pointer(bufferB), 1) | 
|  | 223 | +    mtlarr1 = MtlArray(arr1) | 
|  | 224 | +    mtlarr2 = MtlArray(arr2) | 
|  | 225 | +    mtlarr3 = MtlArray(arr3) | 
| 131 | 226 | 
 | 
| 132 |  | -    function intr_test3(arr_sin, arr_cos) | 
|  | 227 | +    mtlout = fill!(similar(mtlarr1), 0) | 
|  | 228 | + | 
|  | 229 | +    function kernel(res, x, y, z) | 
| 133 | 230 |         idx = thread_position_in_grid_1d() | 
| 134 |  | -        s, c = sincos(arr_cos[idx]) | 
| 135 |  | -        arr_sin[idx] = s | 
| 136 |  | -        arr_cos[idx] = c | 
|  | 231 | +        res[idx] = fun(x[idx], y[idx], z[idx]) | 
| 137 | 232 |         return nothing | 
| 138 | 233 |     end | 
|  | 234 | +    Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3) | 
|  | 235 | +    @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3) | 
|  | 236 | +end | 
|  | 237 | +end | 
| 139 | 238 | 
 | 
| 140 |  | -    @metal intr_test3(bufferA, bufferB) | 
| 141 |  | -    synchronize() | 
| 142 |  | -    @test vecA ≈ sin.(a) | 
| 143 |  | -    @test vecB ≈ cos.(a) | 
|  | 239 | +@testset "unique math" begin | 
|  | 240 | +@testset "$T" for T in (Float32, Float16) | 
|  | 241 | +    let # acosh | 
|  | 242 | +        arr = T[0, rand(T, 3)...] .+ T(1) | 
|  | 243 | +        buffer = MtlArray(arr) | 
|  | 244 | +        vec = acosh.(buffer) | 
|  | 245 | +        @test Array(vec) ≈ acosh.(arr) | 
|  | 246 | +    end | 
| 144 | 247 | 
 | 
| 145 |  | -    b = collect(LinRange(nextfloat(-1f0), 10f0, 20)) | 
| 146 |  | -    bufferC = MtlArray(b) | 
| 147 |  | -    vecC = Array(log1p.(bufferC)) | 
| 148 |  | -    @test vecC ≈ log1p.(b) | 
|  | 248 | +    let # sincos | 
|  | 249 | +        N = 4 | 
|  | 250 | +        arr = rand(T, N) | 
|  | 251 | +        bufferA = MtlArray(arr) | 
|  | 252 | +        bufferB = MtlArray(arr) | 
|  | 253 | +        function intr_test3(arr_sin, arr_cos) | 
|  | 254 | +            idx = thread_position_in_grid_1d() | 
|  | 255 | +            sinres, cosres = sincos(arr_cos[idx]) | 
|  | 256 | +            arr_sin[idx] = sinres | 
|  | 257 | +            arr_cos[idx] = cosres | 
|  | 258 | +            return nothing | 
|  | 259 | +        end | 
|  | 260 | +        # sincos is currently broken with Float16, so the launch itself is marked broken | 
|  | 261 | +        if T == Float16 | 
|  | 262 | +            @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB) | 
|  | 263 | +        else | 
|  | 264 | +            Metal.@sync @metal threads = N intr_test3(bufferA, bufferB) | 
|  | 265 | +            @test Array(bufferA) ≈ sin.(arr) | 
|  | 266 | +            @test Array(bufferB) ≈ cos.(arr) | 
|  | 267 | +        end | 
|  | 268 | +    end | 
| 149 | 269 | 
 | 
|  | 270 | +    let # clamp | 
|  | 271 | +        N = 4 | 
|  | 272 | +        in = randn(T, N) | 
|  | 273 | +        minval = fill(T(-0.6), N) | 
|  | 274 | +        maxval = fill(T(0.6), N) | 
| 150 | 275 | 
 | 
| 151 |  | -    d = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) | 
| 152 |  | -    bufferD = MtlArray(d) | 
| 153 |  | -    vecD = Array(SpecialFunctions.erf.(bufferD)) | 
| 154 |  | -    @test vecD ≈ SpecialFunctions.erf.(d) | 
|  | 276 | +        mtlin = MtlArray(in) | 
|  | 277 | +        mtlminval = MtlArray(minval) | 
|  | 278 | +        mtlmaxval = MtlArray(maxval) | 
| 155 | 279 | 
 | 
|  | 280 | +        mtlout = fill!(similar(mtlin), 0) | 
|  | 281 | + | 
|  | 282 | +        function kernel(res, x, y, z) | 
|  | 283 | +            idx = thread_position_in_grid_1d() | 
|  | 284 | +            res[idx] = clamp(x[idx], y[idx], z[idx]) | 
|  | 285 | +            return nothing | 
|  | 286 | +        end | 
|  | 287 | +        Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval) | 
|  | 288 | +        @test Array(mtlout) == clamp.(in, minval, maxval) | 
|  | 289 | +    end | 
| 156 | 290 | 
 | 
| 157 |  | -    e = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) | 
| 158 |  | -    bufferE = MtlArray(e) | 
| 159 |  | -    vecE = Array(SpecialFunctions.erfc.(bufferE)) | 
| 160 |  | -    @test vecE ≈ SpecialFunctions.erfc.(e) | 
|  | 291 | +    let # pow | 
|  | 292 | +        N = 4 | 
|  | 293 | +        arr1 = rand(T, N) | 
|  | 294 | +        arr2 = rand(T, N) | 
|  | 295 | +        mtlarr1 = MtlArray(arr1) | 
|  | 296 | +        mtlarr2 = MtlArray(arr2) | 
| 161 | 297 | 
 | 
| 162 |  | -    f = collect(LinRange(-1f0, 1f0, 20)) | 
| 163 |  | -    bufferF = MtlArray(f) | 
| 164 |  | -    vecF = Array(SpecialFunctions.erfinv.(bufferF)) | 
| 165 |  | -    @test vecF ≈ SpecialFunctions.erfinv.(f) | 
|  | 298 | +        mtlout = fill!(similar(mtlarr1), 0) | 
| 166 | 299 | 
 | 
| 167 |  | -    f = collect(LinRange(nextfloat(-88f0), 88f0, 100)) | 
| 168 |  | -    bufferF = MtlArray(f) | 
| 169 |  | -    vecF = Array(expm1.(bufferF)) | 
| 170 |  | -    @test vecF ≈ expm1.(f) | 
|  | 300 | +        function kernel(res, x, y) | 
|  | 301 | +            idx = thread_position_in_grid_1d() | 
|  | 302 | +            res[idx] = x[idx]^y[idx] | 
|  | 303 | +            return nothing | 
|  | 304 | +        end | 
|  | 305 | +        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) | 
|  | 306 | +        @test Array(mtlout) ≈ arr1 .^ arr2 | 
|  | 307 | +    end | 
|  | 308 | + | 
|  | 309 | +    let # powr | 
|  | 310 | +        N = 4 | 
|  | 311 | +        arr1 = rand(T, N) | 
|  | 312 | +        arr2 = rand(T, N) | 
|  | 313 | +        mtlarr1 = MtlArray(arr1) | 
|  | 314 | +        mtlarr2 = MtlArray(arr2) | 
|  | 315 | + | 
|  | 316 | +        mtlout = fill!(similar(mtlarr1), 0) | 
|  | 317 | + | 
|  | 318 | +        function kernel(res, x, y) | 
|  | 319 | +            idx = thread_position_in_grid_1d() | 
|  | 320 | +            res[idx] = Metal.powr(x[idx], y[idx]) | 
|  | 321 | +            return nothing | 
|  | 322 | +        end | 
|  | 323 | +        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2) | 
|  | 324 | +        @test Array(mtlout) ≈ arr1 .^ arr2 | 
|  | 325 | +    end | 
|  | 326 | + | 
|  | 327 | +    let # log1p | 
|  | 328 | +        arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20)) | 
|  | 329 | +        buffer = MtlArray(arr) | 
|  | 330 | +        vec = Array(log1p.(buffer)) | 
|  | 331 | +        @test vec ≈ log1p.(arr) | 
|  | 332 | +    end | 
|  | 333 | + | 
|  | 334 | +    let # erf | 
|  | 335 | +        arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) | 
|  | 336 | +        buffer = MtlArray(arr) | 
|  | 337 | +        vec = Array(SpecialFunctions.erf.(buffer)) | 
|  | 338 | +        @test vec ≈ SpecialFunctions.erf.(arr) | 
|  | 339 | +    end | 
|  | 340 | + | 
|  | 341 | +    let # erfc | 
|  | 342 | +        arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20)) | 
|  | 343 | +        buffer = MtlArray(arr) | 
|  | 344 | +        vec = Array(SpecialFunctions.erfc.(buffer)) | 
|  | 345 | +        @test vec ≈ SpecialFunctions.erfc.(arr) | 
|  | 346 | +    end | 
|  | 347 | + | 
|  | 348 | +    let # erfinv | 
|  | 349 | +        arr = collect(LinRange(-1.0f0, 1.0f0, 20)) | 
|  | 350 | +        buffer = MtlArray(arr) | 
|  | 351 | +        vec = Array(SpecialFunctions.erfinv.(buffer)) | 
|  | 352 | +        @test vec ≈ SpecialFunctions.erfinv.(arr) | 
|  | 353 | +    end | 
|  | 354 | + | 
|  | 355 | +    let # expm1 | 
|  | 356 | +        arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100)) | 
|  | 357 | +        buffer = MtlArray(arr) | 
|  | 358 | +        vec = Array(expm1.(buffer)) | 
|  | 359 | +        @test vec ≈ expm1.(arr) | 
|  | 360 | +    end | 
|  | 361 | +end | 
| 171 | 362 | end | 
| 172 | 363 | 
 | 
| 173 | 364 | ############################################################################################ | 
|  | 
0 commit comments