Skip to content

Commit 9a7cbd2

Browse files
authored
Add BFloat16 WMMA (#3009)
1 parent 554031a commit 9a7cbd2

File tree

2 files changed

+203
-16
lines changed

2 files changed

+203
-16
lines changed

src/device/intrinsics/wmma.jl

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ module WMMA
44
import ..LLVM
55
using ..CUDA: AS
66
using Core: LLVMPtr
7+
using BFloat16s: BFloat16
78

89
################################################################################
910
# CONSTANTS
@@ -15,6 +16,7 @@ const map_ptx_to_jl_array = Dict(
1516
"s8" => Int8,
1617
"s32" => Int32,
1718
"f16" => Float16,
19+
"bf16" => BFloat16,
1820
"f32" => Float32
1921
)
2022

@@ -24,6 +26,7 @@ const map_ptx_to_jl_frag = Dict(
2426
"s8" => UInt32,
2527
"s32" => Int32,
2628
"f16" => NTuple{2, VecElement{Float16}},
29+
"bf16" => UInt32,
2730
"f32" => Float32
2831
)
2932

@@ -41,6 +44,10 @@ const map_frag_sizes = Dict(
4144
"a.f16.m16n16k16" => 8,
4245
"a.f16.m8n32k16" => 8,
4346
"a.f16.m32n8k16" => 8,
47+
48+
"a.bf16.m16n16k16" => 4,
49+
"a.bf16.m8n32k16" => 2,
50+
"a.bf16.m32n8k16" => 8,
4451
# B
4552
"b.u8.m16n16k16" => 2,
4653
"b.u8.m8n32k16" => 4,
@@ -53,6 +60,10 @@ const map_frag_sizes = Dict(
5360
"b.f16.m16n16k16" => 8,
5461
"b.f16.m8n32k16" => 8,
5562
"b.f16.m32n8k16" => 8,
63+
64+
"b.bf16.m16n16k16" => 4,
65+
"b.bf16.m8n32k16" => 8,
66+
"b.bf16.m32n8k16" => 2,
5667
# C
5768
"c.s32.m16n16k16" => 8,
5869
"c.s32.m8n32k16" => 8,
@@ -96,10 +107,13 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
96107
const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
97108
const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
98109
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
110+
# BFloat16 (requires Ampere+, only f32 accumulator supported)
111+
const ldst_bf16_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["bf16"]
112+
const wmma_bf16_ops = [(16,16,16), (32,8,16), (8,32,16)], ["bf16"], ["f32"], ["f32"]
99113

100114
const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
101-
ldst_int_ab_ops, ldst_int_cd_ops)
102-
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
115+
ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops)
116+
const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)
103117

104118
# Valid WMMA operation shapes
105119
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
@@ -319,12 +333,12 @@ for ops in all_wmma_ops,
319333
shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
320334

321335
# Name of the LLVM intrinsic
322-
# If integer/sub-byte/bit A/B types, name is determined by A/B types
323-
if d_elem_type == "s32"
336+
# If integer/sub-byte/bit/bf16 A/B types, name is determined by A/B types
337+
if d_elem_type == "s32" || a_elem_type == "bf16"
324338
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
325339
# Name of the Julia wrapper function
326340
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
327-
else # Name defined by D/C types
341+
else # f16: Name defined by D/C types
328342
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type"
329343
# Name of the Julia wrapper function
330344
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_"))
@@ -393,6 +407,28 @@ end
393407
@generated flatten(x::typ) where typ = Expr(:tuple, flatten_recurse(typ, :x)...)
394408
@generated unflatten(::Type{typ}, x) where typ = unflatten_recurse(typ, :x, 1)[1]
395409

410+
# BFloat16 packing/unpacking (UInt32 contains 2x BFloat16)
411+
@inline function unpack_bf16(x::UInt32)
412+
lo = reinterpret(BFloat16, UInt16(x & 0xFFFF))
413+
hi = reinterpret(BFloat16, UInt16(x >> 16))
414+
return (lo, hi)
415+
end
416+
417+
@inline function pack_bf16(lo::BFloat16, hi::BFloat16)
418+
return UInt32(reinterpret(UInt16, lo)) | (UInt32(reinterpret(UInt16, hi)) << 16)
419+
end
420+
421+
@inline function flatten_bf16(x::NTuple{N, UInt32}) where N
422+
ntuple(i -> begin
423+
lo, hi = unpack_bf16(x[(i+1)÷2])
424+
isodd(i) ? lo : hi
425+
end, Val(2N))
426+
end
427+
428+
@inline function unflatten_bf16(x::NTuple{N, BFloat16}) where N
429+
ntuple(i -> pack_bf16(x[2i-1], x[2i]), Val(N÷2))
430+
end
431+
396432
################################################################################
397433
# HIGH LEVEL (CUDA-STYLE API)
398434
################################################################################
@@ -513,6 +549,8 @@ const map_layout_ty_to_str = Dict(
513549
const map_num_elems = Dict(
514550
("a", Float16) => 16,
515551
("b", Float16) => 16,
552+
("a", BFloat16) => 8,
553+
("b", BFloat16) => 8,
516554
("c", Float16) => 8,
517555
("c", Float32) => 8,
518556
("d", Float16) => 8,
@@ -614,8 +652,9 @@ for mat in ["a", "b", "c"]
614652
# Name of the Julia wrapper
615653
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_"))
616654

655+
_flatten = T == BFloat16 ? flatten_bf16 : flatten
617656
return quote
618-
x = flatten($wrapper(addr, stride))
657+
x = $_flatten($wrapper(addr, stride))
619658
return Fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x)
620659
end
621660
end
@@ -656,19 +695,22 @@ mma
656695
b_layout = get_hl_layout(B_L)
657696
shape = get_hl_shape(M, N, K)
658697

659-
_, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T, shape)
698+
_, a_frag_sz, a_frag_ty, a_arr_str = get_hl_frag_info("a", A_T, shape)
660699
_, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape)
661700
_, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
662701
d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)
663702

703+
names = ["llvm", "wmma", "mma", a_layout, b_layout, shape]
704+
# bf16 uses input type in intrinsic name, f16 uses d/c types
705+
A_T === BFloat16 ? push!(names, a_arr_str) : push!(names, d_arr_str, c_arr_str)
706+
wrapper = Symbol(join(filter(!isempty, names), "_"))
664707

665-
666-
# Name of the Julia wrapper
667-
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_"))
708+
a_unfl_expr = A_T === BFloat16 ? :(unflatten_bf16(a.x)) : :(unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x))
709+
b_unfl_expr = B_T === BFloat16 ? :(unflatten_bf16(b.x)) : :(unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x))
668710

669711
return quote
670-
a_unfl = unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x)
671-
b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x)
712+
a_unfl = $a_unfl_expr
713+
b_unfl = $b_unfl_expr
672714
c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x)
673715

674716
x = flatten($wrapper(a_unfl, b_unfl, c_unfl))

test/core/device/intrinsics/wmma.jl

Lines changed: 149 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@ if capability(device()) >= v"7.0"
22

33
using CUDA.WMMA
44

5+
using BFloat16s: BFloat16
6+
57
map_ptx_to_jl_frag = Dict(
68
"u8" => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
79
"s8" => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
810
"u32" => UInt32(42),
911
"s32" => Int32(42),
1012
"f16" => ntuple(i -> VecElement{Float16}(42), 2),
13+
"bf16" => reinterpret(UInt32, BFloat16(42) * ones(BFloat16, 2))[1],
1114
"f32" => Float32(42)
1215
)
1316
# Return specific matrix shape given operation configuration
@@ -48,6 +51,10 @@ end
4851
startswith(elem_type, "u"))
4952
continue
5053
end
54+
# Skip BFloat16 WMMA on pre-Ampere devices
55+
if capability(device()) < v"8.0" && elem_type == "bf16"
56+
continue
57+
end
5158

5259
shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
5360

@@ -115,6 +122,10 @@ end
115122
startswith(elem_type, "u"))
116123
continue
117124
end
125+
# Skip BFloat16 WMMA on pre-Ampere devices
126+
if capability(device()) < v"8.0" && elem_type == "bf16"
127+
continue
128+
end
118129

119130
shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
120131

@@ -175,6 +186,10 @@ end
175186
startswith(ab_elem_type, "u"))
176187
continue
177188
end
189+
# Skip BFloat16 WMMA on pre-Ampere devices
190+
if capability(device()) < v"8.0" && ab_elem_type == "bf16"
191+
continue
192+
end
178193

179194
# Type-dependent variables
180195
d_ty = CUDA.WMMA.map_ptx_to_jl_array[d_elem_type]
@@ -187,9 +202,9 @@ end
187202
lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_$(shape)_global_stride_$(ab_elem_type)"))
188203
ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_$(shape)_global_stride_$(ab_elem_type)"))
189204
ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_$(shape)_global_stride_$(c_elem_type)"))
190-
# Account for half and int/subint mma different naming conventions
191-
# Int/subint mma functions are distinguished by the a/b element type
192-
mma_sym = d_ty == Int32 ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
205+
# Account for half and int/subint/bf16 mma different naming conventions
206+
# Int/subint and bf16 mma functions are distinguished by the a/b element type
207+
mma_sym = (d_ty == Int32 || ab_elem_type == "bf16") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
193208
Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)")
194209
mma_func = getfield(Main, mma_sym)
195210
std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)"))
@@ -227,6 +242,8 @@ end
227242
# Alter test depending on a/b element Type
228243
if ab_ty == Float16
229244
@test new_a * new_b + c ≈ Array(d_dev) rtol=Base.rtoldefault(Float16)
245+
elseif ab_ty == BFloat16
246+
@test Float32.(new_a) * Float32.(new_b) + c ≈ Array(d_dev) rtol=Base.rtoldefault(BFloat16)
230247
else # Cast a and b to prevent UInt8 rollover of resultant data
231248
@test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev)
232249
end
@@ -256,12 +273,20 @@ end
256273
@test WMMA.unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8)
257274
@test WMMA.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)
258275
end
276+
277+
@testset "BFloat16 packing/unpacking" begin
278+
bf_vals = ntuple(i -> BFloat16(i), 8)
279+
packed = WMMA.unflatten_bf16(bf_vals)
280+
@test length(packed) == 4
281+
unpacked = WMMA.flatten_bf16(packed)
282+
@test unpacked == bf_vals
283+
end
259284
end
260285

261286
################################################################################
262287

263288
@testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5],
264-
ty = [Float16, Float32]
289+
ty = [Float16, Float32, BFloat16]
265290
@test ty(5) .* Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 * i), sz))
266291
@test ty(5) .+ Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 + i), sz))
267292
end
@@ -331,6 +356,126 @@ end
331356

332357
################################################################################
333358

359+
if capability(device()) >= v"8.0"
360+
@testset "CUDA C-style API (BFloat16)" begin
361+
@testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
362+
b_layout in [ColMajor, RowMajor],
363+
c_layout in [ColMajor, RowMajor],
364+
d_layout in [ColMajor, RowMajor],
365+
do_mac in [true, false]
366+
367+
a = rand(BFloat16, (16, 16))
368+
b = rand(BFloat16, (16, 16))
369+
c = rand(Float32, (16, 16))
370+
d = Array{Float32}(undef, (16, 16))
371+
372+
a_dev = CuArray(a)
373+
b_dev = CuArray(b)
374+
c_dev = CuArray(c)
375+
d_dev = CuArray(d)
376+
377+
# Note: BFloat16 fragment broadcasting (alpha .* a_frag) requires native bf16
378+
# scalar ops which aren't available on all architectures, so we skip scaling
379+
@eval function kernel_bf16(a_dev, b_dev, c_dev, d_dev)
380+
conf = Config{16, 16, 16, Float32}
381+
382+
a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
383+
b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
384+
385+
if $do_mac
386+
c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
387+
else
388+
c_frag = fill_c(Float32(0), conf)
389+
end
390+
391+
d_frag = mma(a_frag, b_frag, c_frag, conf)
392+
393+
store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
394+
395+
return
396+
end
397+
398+
@cuda threads=32 kernel_bf16(a_dev, b_dev, c_dev, d_dev)
399+
d = Array(d_dev)
400+
401+
new_a = (a_layout == ColMajor) ? a : transpose(a)
402+
new_b = (b_layout == ColMajor) ? b : transpose(b)
403+
new_c = (c_layout == ColMajor) ? c : transpose(c)
404+
new_d = (d_layout == ColMajor) ? d : transpose(d)
405+
406+
if do_mac
407+
@test Float32.(new_a) * Float32.(new_b) + new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
408+
else
409+
@test Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
410+
end
411+
end
412+
end
413+
end
414+
415+
# BFloat16 fragment broadcasting requires native bf16 scalar ops (CC 8.9+)
416+
# On earlier architectures, frag[i] returns UInt32 (packed), causing type mismatch
417+
if capability(device()) >= v"8.9"
418+
@testset "CUDA C-style API (BFloat16 with scaling)" begin
419+
@testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
420+
b_layout in [ColMajor, RowMajor],
421+
c_layout in [ColMajor, RowMajor],
422+
d_layout in [ColMajor, RowMajor],
423+
do_mac in [true, false]
424+
425+
a = rand(BFloat16, (16, 16))
426+
b = rand(BFloat16, (16, 16))
427+
c = rand(Float32, (16, 16))
428+
d = Array{Float32}(undef, (16, 16))
429+
430+
a_dev = CuArray(a)
431+
b_dev = CuArray(b)
432+
c_dev = CuArray(c)
433+
d_dev = CuArray(d)
434+
435+
alpha = rand(BFloat16)
436+
beta = rand(Float32)
437+
438+
@eval function kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
439+
conf = Config{16, 16, 16, Float32}
440+
441+
a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
442+
b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
443+
444+
if $do_mac
445+
c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
446+
else
447+
c_frag = fill_c(Float32(0), conf)
448+
end
449+
450+
a_frag = alpha .* a_frag
451+
c_frag = beta .* c_frag
452+
453+
d_frag = mma(a_frag, b_frag, c_frag, conf)
454+
455+
store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
456+
457+
return
458+
end
459+
460+
@cuda threads=32 kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
461+
d = Array(d_dev)
462+
463+
new_a = (a_layout == ColMajor) ? a : transpose(a)
464+
new_b = (b_layout == ColMajor) ? b : transpose(b)
465+
new_c = (c_layout == ColMajor) ? c : transpose(c)
466+
new_d = (d_layout == ColMajor) ? d : transpose(d)
467+
468+
if do_mac
469+
@test Float32(alpha) * Float32.(new_a) * Float32.(new_b) + beta * new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
470+
else
471+
@test Float32(alpha) * Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
472+
end
473+
end
474+
end
475+
end
476+
477+
################################################################################
478+
334479
@testset "Codegen addressing" begin
335480
@testset "Global" begin
336481
function kernel(d)

0 commit comments

Comments
 (0)