@@ -2,12 +2,15 @@ if capability(device()) >= v"7.0"
22
33using CUDA. WMMA
44
5+ using BFloat16s: BFloat16
6+
# Map a PTX element-type suffix to a representative Julia fragment value used to
# fill WMMA fragments in the tests below. Every lane holds the value 42, packed
# into the register-level representation the corresponding PTX intrinsic expects:
#  - sub-byte integer types are packed four-per-Int32,
#  - bf16 values are packed two-per-UInt32,
#  - f16 fragments are 2-wide VecElement tuples.
map_ptx_to_jl_frag = Dict(
    "u8"   => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
    "s8"   => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
    "u32"  => UInt32(42),
    "s32"  => Int32(42),
    "f16"  => ntuple(i -> VecElement{Float16}(42), 2),
    "bf16" => reinterpret(UInt32, BFloat16(42) * ones(BFloat16, 2))[1],
    "f32"  => Float32(42)
)
1316# Return specific matrix shape given operation configuration
4851 startswith (elem_type, " u" ))
4952 continue
5053 end
54+ # Skip BFloat16 WMMA on pre-Ampere devices
55+ if capability (device ()) < v " 8.0" && elem_type == " bf16"
56+ continue
57+ end
5158
5259 shape = CUDA. WMMA. get_hl_shape (mnk[1 ], mnk[2 ], mnk[3 ])
5360
115122 startswith (elem_type, " u" ))
116123 continue
117124 end
125+ # Skip BFloat16 WMMA on pre-Ampere devices
126+ if capability (device ()) < v " 8.0" && elem_type == " bf16"
127+ continue
128+ end
118129
119130 shape = CUDA. WMMA. get_hl_shape (mnk[1 ], mnk[2 ], mnk[3 ])
120131
175186 startswith (ab_elem_type, " u" ))
176187 continue
177188 end
189+ # Skip BFloat16 WMMA on pre-Ampere devices
190+ if capability (device ()) < v " 8.0" && ab_elem_type == " bf16"
191+ continue
192+ end
178193
179194 # Type-dependent variables
180195 d_ty = CUDA. WMMA. map_ptx_to_jl_array[d_elem_type]
187202 lda_func = getfield (Main, Symbol (" llvm_wmma_load_a_$(a_layout) _$(shape) _global_stride_$(ab_elem_type) " ))
188203 ldb_func = getfield (Main, Symbol (" llvm_wmma_load_b_$(b_layout) _$(shape) _global_stride_$(ab_elem_type) " ))
189204 ldc_func = getfield (Main, Symbol (" llvm_wmma_load_c_col_$(shape) _global_stride_$(c_elem_type) " ))
190- # Account for half and int/subint mma different naming conventions
191- # Int/subint mma functions are distinguished by the a/b element type
192- mma_sym = d_ty == Int32 ? Symbol (" llvm_wmma_mma_$(a_layout) _$(b_layout) _$(shape) _$(ab_elem_type) " ) :
205+ # Account for half and int/subint/bf16 mma different naming conventions
206+ # Int/subint and bf16 mma functions are distinguished by the a/b element type
207+ mma_sym = ( d_ty == Int32 || ab_elem_type == " bf16 " ) ? Symbol (" llvm_wmma_mma_$(a_layout) _$(b_layout) _$(shape) _$(ab_elem_type) " ) :
193208 Symbol (" llvm_wmma_mma_$(a_layout) _$(b_layout) _$(shape) _$(d_elem_type) _$(c_elem_type) " )
194209 mma_func = getfield (Main, mma_sym)
195210 std_func = getfield (Main, Symbol (" llvm_wmma_store_d_col_$(shape) _global_stride_$(d_elem_type) " ))
227242 # Alter test depending on a/b element Type
228243 if ab_ty == Float16
229244 @test new_a * new_b + c ≈ Array (d_dev) rtol= Base. rtoldefault (Float16)
245+ elseif ab_ty == BFloat16
246+ @test Float32 .(new_a) * Float32 .(new_b) + c ≈ Array (d_dev) rtol= Base. rtoldefault (BFloat16)
230247 else # Cast a and b to prevent UInt8 rollover of resultant data
231248 @test Int32 .(new_a) * Int32 .(new_b) + c == Array (d_dev)
232249 end
@@ -256,12 +273,20 @@ end
256273 @test WMMA. unflatten (NTuple{8 , NTuple{2 , Int64}}, ntuple (i -> i, 2 * 8 )) == ntuple (i -> ntuple (j -> (i- 1 ) * 2 + j, 2 ), 8 )
257274 @test WMMA. unflatten (NTuple{8 , NTuple{2 , VecElement{Float16}}}, ntuple (i -> Float16 (i), 2 * 8 )) == ntuple (i -> ntuple (j -> VecElement {Float16} ((i- 1 ) * 2 + j), 2 ), 8 )
258275 end

@testset "BFloat16 packing/unpacking" begin
    # Round-trip 8 BFloat16 values through the packed representation:
    # two bf16 values per 32-bit word, so 8 scalars pack into 4 elements.
    bf_vals = ntuple(i -> BFloat16(i), 8)
    packed = WMMA.unflatten_bf16(bf_vals)
    @test length(packed) == 4
    # Unpacking must recover the original scalars exactly (packing is lossless).
    unpacked = WMMA.flatten_bf16(packed)
    @test unpacked == bf_vals
end
259284end
260285
261286# ###############################################################################
262287
# Scalar broadcasting over WMMA fragments must apply element-wise, for every
# supported element type (BFloat16 included since the bf16 WMMA additions).
@testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5],
    ty = [Float16, Float32, BFloat16]

    @test ty(5) .* Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 * i), sz))
    @test ty(5) .+ Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 + i), sz))
end
@@ -331,6 +356,126 @@ end
331356
332357# ###############################################################################
333358
# BFloat16 WMMA requires Ampere (compute capability 8.0) or newer.
if capability(device()) >= v"8.0"
@testset "CUDA C-style API (BFloat16)" begin
    @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
        b_layout in [ColMajor, RowMajor],
        c_layout in [ColMajor, RowMajor],
        d_layout in [ColMajor, RowMajor],
        do_mac in [true, false]

        # 16x16x16 GEMM: BFloat16 A/B inputs, Float32 accumulator and output.
        a = rand(BFloat16, (16, 16))
        b = rand(BFloat16, (16, 16))
        c = rand(Float32, (16, 16))
        d = Array{Float32}(undef, (16, 16))

        a_dev = CuArray(a)
        b_dev = CuArray(b)
        c_dev = CuArray(c)
        d_dev = CuArray(d)

        # Note: BFloat16 fragment broadcasting (alpha .* a_frag) requires native bf16
        # scalar ops which aren't available on all architectures, so we skip scaling
        # here; the scaled variant below is gated on a higher compute capability.
        # Layouts are interpolated into the kernel via @eval so each combination
        # compiles to a specialized kernel.
        @eval function kernel_bf16(a_dev, b_dev, c_dev, d_dev)
            conf = Config{16, 16, 16, Float32}

            a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
            b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)

            if $do_mac
                c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
            else
                # MUL variant: accumulate into a zero fragment instead of C.
                c_frag = fill_c(Float32(0), conf)
            end

            d_frag = mma(a_frag, b_frag, c_frag, conf)

            store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)

            return
        end

        @cuda threads=32 kernel_bf16(a_dev, b_dev, c_dev, d_dev)
        d = Array(d_dev)

        # Undo any row-major storage so the host-side reference multiply
        # compares like with like.
        new_a = (a_layout == ColMajor) ? a : transpose(a)
        new_b = (b_layout == ColMajor) ? b : transpose(b)
        new_c = (c_layout == ColMajor) ? c : transpose(c)
        new_d = (d_layout == ColMajor) ? d : transpose(d)

        # Reference computed in Float32; tolerance follows bf16 precision.
        if do_mac
            @test Float32.(new_a) * Float32.(new_b) + new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
        else
            @test Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
        end
    end
end
end
414+
# BFloat16 fragment broadcasting requires native bf16 scalar ops (CC 8.9+).
# On earlier architectures, frag[i] returns UInt32 (packed), causing type mismatch.
if capability(device()) >= v"8.9"
@testset "CUDA C-style API (BFloat16 with scaling)" begin
    @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
        b_layout in [ColMajor, RowMajor],
        c_layout in [ColMajor, RowMajor],
        d_layout in [ColMajor, RowMajor],
        do_mac in [true, false]

        # Same 16x16x16 GEMM as the unscaled testset, plus scalar scaling of
        # the A fragment (bf16 alpha) and C fragment (f32 beta) in-kernel.
        a = rand(BFloat16, (16, 16))
        b = rand(BFloat16, (16, 16))
        c = rand(Float32, (16, 16))
        d = Array{Float32}(undef, (16, 16))

        a_dev = CuArray(a)
        b_dev = CuArray(b)
        c_dev = CuArray(c)
        d_dev = CuArray(d)

        alpha = rand(BFloat16)
        beta = rand(Float32)

        @eval function kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
            conf = Config{16, 16, 16, Float32}

            a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
            b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)

            if $do_mac
                c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
            else
                # MUL variant: accumulate into a zero fragment instead of C.
                c_frag = fill_c(Float32(0), conf)
            end

            # Scale fragments element-wise; exercises bf16 scalar broadcasting.
            a_frag = alpha .* a_frag
            c_frag = beta .* c_frag

            d_frag = mma(a_frag, b_frag, c_frag, conf)

            store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)

            return
        end

        @cuda threads=32 kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
        d = Array(d_dev)

        # Undo any row-major storage for the host-side reference comparison.
        new_a = (a_layout == ColMajor) ? a : transpose(a)
        new_b = (b_layout == ColMajor) ? b : transpose(b)
        new_c = (c_layout == ColMajor) ? c : transpose(c)
        new_d = (d_layout == ColMajor) ? d : transpose(d)

        # Reference computed in Float32; tolerance follows bf16 precision.
        if do_mac
            @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) + beta * new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
        else
            @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
        end
    end
end
end
476+
477+ # ###############################################################################
478+
334479@testset " Codegen addressing" begin
335480 @testset " Global" begin
336481 function kernel (d)
0 commit comments