|
1 | 1 | using SpecialFunctions |
| 2 | +using BFloat16s |
2 | 3 | using Metal: metal_support |
3 | 4 |
|
4 | 5 | @testset "arguments" begin |
|
308 | 309 | ############################################################################################ |
309 | 310 |
|
310 | 311 | @testset "simd intrinsics" begin |
311 | | - |
312 | | -@testset "shuffle($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8] |
| 312 | +types = [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8] |
| 313 | +metal_support() >= v"3.1" && push!(types, BFloat16) |
| 314 | +@testset "shuffle($typ)" for typ in types |
313 | 315 | function kernel(a::MtlDeviceVector{T}, b::MtlDeviceVector{T}) where T |
314 | 316 | idx = thread_position_in_grid_1d() |
315 | 317 | idx_in_simd = thread_index_in_simdgroup() |
|
344 | 346 | end |
345 | 347 |
|
346 | 348 | @testset "matrix functions" begin |
347 | | - @testset "load_store($typ)" for typ in [Float16, Float32] |
| 349 | + simdgroup_types = [Float16, Float32] |
| 350 | + metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16) |
| 351 | + @testset "load_store($typ)" for typ in simdgroup_types |
348 | 352 | function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, |
349 | 353 | origin_a=(1, 1), origin_b=(1, 1)) where {T} |
350 | 354 | sg_a = simdgroup_load(a, origin_a) |
|
367 | 371 | end |
368 | 372 | end |
369 | 373 |
|
370 | | - @testset "load_store_tg($typ)" for typ in [Float16, Float32] |
| 374 | + @testset "load_store_tg($typ)" for typ in simdgroup_types |
371 | 375 | function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T} |
372 | 376 | pos = thread_position_in_threadgroup_2d() |
373 | 377 |
|
|
391 | 395 | @test Array(a) == Array(b) |
392 | 396 | end |
393 | 397 |
|
394 | | - @testset "mul($typ)" for typ in [Float16, Float32] |
| 398 | + @testset "mul($typ)" for typ in simdgroup_types |
395 | 399 | function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T} |
396 | 400 | sg_a = simdgroup_load(a) |
397 | 401 | sg_b = simdgroup_load(b) |
|
407 | 411 | @test Array(a) * Array(b) ≈ Array(c) |
408 | 412 | end |
409 | 413 |
|
410 | | - @testset "mad($typ)" for typ in [Float16, Float32] |
| 414 | + @testset "mad($typ)" for typ in simdgroup_types |
411 | 415 | function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}, |
412 | 416 | d::MtlDeviceArray{T}) where {T} |
413 | 417 | sg_a = simdgroup_load(a) |
|
0 commit comments