|  | 
| 1 | 1 | using SpecialFunctions | 
|  | 2 | +using BFloat16s | 
| 2 | 3 | using Metal: metal_support | 
| 3 | 4 | 
 | 
| 4 | 5 | @testset "arguments" begin | 
|  | 
| 308 | 309 | ############################################################################################ | 
| 309 | 310 | 
 | 
| 310 | 311 | @testset "simd intrinsics" begin | 
| 311 |  | - | 
| 312 |  | -@testset "shuffle($typ)" for typ in [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8] | 
|  | 312 | +types = [Float32, Float16, Int32, UInt32, Int16, UInt16, Int8, UInt8] | 
|  | 313 | +metal_support() >= v"3.1" && push!(types, BFloat16) | 
|  | 314 | +@testset "shuffle($typ)" for typ in types | 
| 313 | 315 |     function kernel(a::MtlDeviceVector{T}, b::MtlDeviceVector{T}) where T | 
| 314 | 316 |         idx = thread_position_in_grid_1d() | 
| 315 | 317 |         idx_in_simd = thread_index_in_simdgroup() | 
|  | 
| 344 | 346 | end | 
| 345 | 347 | 
 | 
| 346 | 348 | @testset "matrix functions" begin | 
| 347 |  | -    @testset "load_store($typ)" for typ in [Float16, Float32] | 
|  | 349 | +    simdgroup_types = [Float16, Float32] | 
|  | 350 | +    metal_support() >= v"3.1" && push!(simdgroup_types, BFloat16) | 
|  | 351 | +    @testset "load_store($typ)" for typ in simdgroup_types | 
| 348 | 352 |         function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, | 
| 349 | 353 |                             origin_a=(1, 1), origin_b=(1, 1)) where {T} | 
| 350 | 354 |             sg_a = simdgroup_load(a, origin_a) | 
|  | 
| 367 | 371 |         end | 
| 368 | 372 |     end | 
| 369 | 373 | 
 | 
| 370 |  | -    @testset "load_store_tg($typ)" for typ in [Float16, Float32] | 
|  | 374 | +    @testset "load_store_tg($typ)" for typ in simdgroup_types | 
| 371 | 375 |         function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}) where {T} | 
| 372 | 376 |             pos = thread_position_in_threadgroup_2d() | 
| 373 | 377 | 
 | 
|  | 
| 391 | 395 |         @test Array(a) == Array(b) | 
| 392 | 396 |     end | 
| 393 | 397 | 
 | 
| 394 |  | -    @testset "mul($typ)" for typ in [Float16, Float32] | 
|  | 398 | +    @testset "mul($typ)" for typ in simdgroup_types | 
| 395 | 399 |         function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}) where {T} | 
| 396 | 400 |             sg_a = simdgroup_load(a) | 
| 397 | 401 |             sg_b = simdgroup_load(b) | 
|  | 
| 407 | 411 |         @test Array(a) * Array(b) ≈ Array(c) | 
| 408 | 412 |     end | 
| 409 | 413 | 
 | 
| 410 |  | -    @testset "mad($typ)" for typ in [Float16, Float32] | 
|  | 414 | +    @testset "mad($typ)" for typ in simdgroup_types | 
| 411 | 415 |         function kernel(a::MtlDeviceArray{T}, b::MtlDeviceArray{T}, c::MtlDeviceArray{T}, | 
| 412 | 416 |                     d::MtlDeviceArray{T}) where {T} | 
| 413 | 417 |             sg_a = simdgroup_load(a) | 
|  | 
0 commit comments