diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl
index b40bbffe2d..50cb21f6b5 100644
--- a/src/device/intrinsics/wmma.jl
+++ b/src/device/intrinsics/wmma.jl
@@ -4,6 +4,7 @@ module WMMA
 import ..LLVM
 using ..CUDA: AS
 using Core: LLVMPtr
+using BFloat16s: BFloat16
 
 ################################################################################
 # CONSTANTS
@@ -15,6 +16,7 @@ const map_ptx_to_jl_array = Dict(
     "s8" => Int8,
     "s32" => Int32,
     "f16" => Float16,
+    "bf16" => BFloat16,
     "f32" => Float32
 )
 
@@ -24,6 +26,7 @@ const map_ptx_to_jl_frag = Dict(
     "s8" => UInt32,
     "s32" => Int32,
     "f16" => NTuple{2, VecElement{Float16}},
+    "bf16" => UInt32,
     "f32" => Float32
 )
 
@@ -41,6 +44,10 @@ const map_frag_sizes = Dict(
     "a.f16.m16n16k16" => 8,
     "a.f16.m8n32k16" => 8,
     "a.f16.m32n8k16" => 8,
+
+    "a.bf16.m16n16k16" => 4,
+    "a.bf16.m8n32k16" => 2,
+    "a.bf16.m32n8k16" => 8,
     # B
     "b.u8.m16n16k16" => 2,
     "b.u8.m8n32k16" => 4,
@@ -53,6 +60,10 @@ const map_frag_sizes = Dict(
     "b.f16.m16n16k16" => 8,
     "b.f16.m8n32k16" => 8,
     "b.f16.m32n8k16" => 8,
+
+    "b.bf16.m16n16k16" => 4,
+    "b.bf16.m8n32k16" => 8,
+    "b.bf16.m32n8k16" => 2,
     # C
     "c.s32.m16n16k16" => 8,
     "c.s32.m8n32k16" => 8,
@@ -96,10 +107,13 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
 const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
 const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
 const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
+# BFloat16 (requires Ampere+, only f32 accumulator supported)
+const ldst_bf16_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["bf16"]
+const wmma_bf16_ops = [(16,16,16), (32,8,16), (8,32,16)], ["bf16"], ["f32"], ["f32"]
 
 const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
-                          ldst_int_ab_ops, ldst_int_cd_ops)
-const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops)
+                          ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops)
+const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)
 
 # Valid WMMA operation shapes
 const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
@@ -319,12 +333,12 @@ for ops in all_wmma_ops,
     shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
 
     # Name of the LLVM intrinsic
-    # If integer/sub-byte/bit A/B types, name is determined by A/B types
-    if d_elem_type == "s32"
+    # If integer/sub-byte/bit/bf16 A/B types, the name is determined by the A/B types
+    if d_elem_type == "s32" || a_elem_type == "bf16"
         llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
         # Name of the Julia wrapper function
         func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
-    else # Name defined by D/C types
+    else # f16: name determined by the D/C types
         llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type"
         # Name of the Julia wrapper function
         func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_"))
@@ -393,6 +407,28 @@ end
 @generated flatten(x::typ) where typ = Expr(:tuple, flatten_recurse(typ, :x)...)
 @generated unflatten(::Type{typ}, x) where typ = unflatten_recurse(typ, :x, 1)[1]
 
+# BFloat16 packing/unpacking (a UInt32 contains 2x BFloat16)
+@inline function unpack_bf16(x::UInt32)
+    lo = reinterpret(BFloat16, UInt16(x & 0xFFFF))
+    hi = reinterpret(BFloat16, UInt16(x >> 16))
+    return (lo, hi)
+end
+
+@inline function pack_bf16(lo::BFloat16, hi::BFloat16)
+    return UInt32(reinterpret(UInt16, lo)) | (UInt32(reinterpret(UInt16, hi)) << 16)
+end
+
+@inline function flatten_bf16(x::NTuple{N, UInt32}) where N
+    ntuple(i -> begin
+        lo, hi = unpack_bf16(x[(i+1)÷2])
+        isodd(i) ? lo : hi
+    end, Val(2N))
+end
+
+@inline function unflatten_bf16(x::NTuple{N, BFloat16}) where N
+    ntuple(i -> pack_bf16(x[2i-1], x[2i]), Val(N÷2))
+end
+
 ################################################################################
 # HIGH LEVEL (CUDA-STYLE API)
 ################################################################################
@@ -513,6 +549,8 @@ const map_layout_ty_to_str = Dict(
 const map_num_elems = Dict(
     ("a", Float16) => 16,
     ("b", Float16) => 16,
+    ("a", BFloat16) => 8,
+    ("b", BFloat16) => 8,
     ("c", Float16) => 8,
     ("c", Float32) => 8,
     ("d", Float16) => 8,
@@ -614,8 +652,9 @@ for mat in ["a", "b", "c"]
         # Name of the Julia wrapper
         wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_"))
+        _flatten = T == BFloat16 ? flatten_bf16 : flatten
 
         return quote
-            x = flatten($wrapper(addr, stride))
+            x = $_flatten($wrapper(addr, stride))
             return Fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x)
         end
     end
@@ -656,19 +695,22 @@ mma
     b_layout = get_hl_layout(B_L)
     shape = get_hl_shape(M, N, K)
 
-    _, a_frag_sz, a_frag_ty, _ = get_hl_frag_info("a", A_T, shape)
+    _, a_frag_sz, a_frag_ty, a_arr_str = get_hl_frag_info("a", A_T, shape)
     _, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape)
     _, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
     d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)
 
+    names = ["llvm", "wmma", "mma", a_layout, b_layout, shape]
+    # bf16 uses the input (a/b) type in the intrinsic name, f16 uses the d/c types
+    A_T === BFloat16 ? push!(names, a_arr_str) : push!(names, d_arr_str, c_arr_str)
+    wrapper = Symbol(join(filter(!isempty, names), "_"))
-
-    # Name of the Julia wrapper
-    wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_arr_str, c_arr_str]), "_"))
+    a_unfl_expr = A_T === BFloat16 ? :(unflatten_bf16(a.x)) : :(unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x))
+    b_unfl_expr = B_T === BFloat16 ?
+        :(unflatten_bf16(b.x)) : :(unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x))
 
     return quote
-        a_unfl = unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x)
-        b_unfl = unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x)
+        a_unfl = $a_unfl_expr
+        b_unfl = $b_unfl_expr
         c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x)
 
         x = flatten($wrapper(a_unfl, b_unfl, c_unfl))
diff --git a/test/core/device/intrinsics/wmma.jl b/test/core/device/intrinsics/wmma.jl
index cc7db1c0bb..3295a62ded 100644
--- a/test/core/device/intrinsics/wmma.jl
+++ b/test/core/device/intrinsics/wmma.jl
@@ -2,12 +2,15 @@ if capability(device()) >= v"7.0"
 
 using CUDA.WMMA
 
+using BFloat16s: BFloat16
+
 map_ptx_to_jl_frag = Dict(
     "u8" => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
     "s8" => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
     "u32" => UInt32(42),
     "s32" => Int32(42),
     "f16" => ntuple(i -> VecElement{Float16}(42), 2),
+    "bf16" => reinterpret(UInt32, BFloat16(42) * ones(BFloat16, 2))[1],
     "f32" => Float32(42)
 )
 # Return specific matrix shape given operation configuration
@@ -48,6 +51,10 @@ end
             startswith(elem_type, "u"))
             continue
         end
+        # Skip BFloat16 WMMA on pre-Ampere devices
+        if capability(device()) < v"8.0" && elem_type == "bf16"
+            continue
+        end
 
         shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
 
@@ -115,6 +122,10 @@ end
            startswith(elem_type, "u"))
            continue
        end
+        # Skip BFloat16 WMMA on pre-Ampere devices
+        if capability(device()) < v"8.0" && elem_type == "bf16"
+            continue
+        end
 
         shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
 
@@ -175,6 +186,10 @@ end
            startswith(ab_elem_type, "u"))
            continue
        end
+        # Skip BFloat16 WMMA on pre-Ampere devices
+        if capability(device()) < v"8.0" && ab_elem_type == "bf16"
+            continue
+        end
 
         # Type-dependent variables
         d_ty = CUDA.WMMA.map_ptx_to_jl_array[d_elem_type]
@@ -187,9 +202,9 @@ end
         lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_$(shape)_global_stride_$(ab_elem_type)"))
         ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_$(shape)_global_stride_$(ab_elem_type)"))
         ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_$(shape)_global_stride_$(c_elem_type)"))
-        # Account for half and int/subint mma different naming conventions
-        # Int/subint mma functions are distinguished by the a/b element type
-        mma_sym = d_ty == Int32 ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
+        # Account for the different naming conventions of half vs. int/subint/bf16 mma:
+        # int/subint and bf16 mma functions are distinguished by the a/b element type
+        mma_sym = (d_ty == Int32 || ab_elem_type == "bf16") ?
+                  Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
                   Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)")
         mma_func = getfield(Main, mma_sym)
         std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)"))
@@ -227,6 +242,8 @@ end
         # Alter test depending on a/b element Type
         if ab_ty == Float16
             @test new_a * new_b + c ≈ Array(d_dev) rtol=Base.rtoldefault(Float16)
+        elseif ab_ty == BFloat16
+            @test Float32.(new_a) * Float32.(new_b) + c ≈ Array(d_dev) rtol=Base.rtoldefault(BFloat16)
         else # Cast a and b to prevent UInt8 rollover of resultant data
             @test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev)
         end
@@ -256,12 +273,20 @@ end
         @test WMMA.unflatten(NTuple{8, NTuple{2, Int64}}, ntuple(i -> i, 2 * 8)) == ntuple(i -> ntuple(j -> (i-1) * 2 + j, 2), 8)
         @test WMMA.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)
     end
+
+    @testset "BFloat16 packing/unpacking" begin
+        bf_vals = ntuple(i -> BFloat16(i), 8)
+        packed = WMMA.unflatten_bf16(bf_vals)
+        @test length(packed) == 4
+        unpacked = WMMA.flatten_bf16(packed)
+        @test unpacked == bf_vals
+    end
 end
 
 ################################################################################
 
 @testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5],
-    ty = [Float16, Float32]
+    ty = [Float16, Float32, BFloat16]
 
     @test ty(5) .* Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 * i), sz))
     @test ty(5) .+ Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 + i), sz))
 end
@@ -331,6 +356,126 @@ end
 
 ################################################################################
 
+if capability(device()) >= v"8.0"
+@testset "CUDA C-style API (BFloat16)" begin
+    @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
+        b_layout in [ColMajor, RowMajor],
+        c_layout in [ColMajor, RowMajor],
+        d_layout in [ColMajor, RowMajor],
+        do_mac in [true, false]
+
+        a = rand(BFloat16, (16, 16))
+        b = rand(BFloat16, (16, 16))
+        c = rand(Float32, (16, 16))
+        d = Array{Float32}(undef, (16, 16))
+
+        a_dev = CuArray(a)
+        b_dev = CuArray(b)
+        c_dev = CuArray(c)
+        d_dev = CuArray(d)
+
+        # Note: BFloat16 fragment broadcasting (alpha .* a_frag) requires native bf16
+        # scalar ops which aren't available on all architectures, so we skip scaling
+        @eval function kernel_bf16(a_dev, b_dev, c_dev, d_dev)
+            conf = Config{16, 16, 16, Float32}
+
+            a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
+            b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
+
+            if $do_mac
+                c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
+            else
+                c_frag = fill_c(Float32(0), conf)
+            end
+
+            d_frag = mma(a_frag, b_frag, c_frag, conf)
+
+            store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
+
+            return
+        end
+
+        @cuda threads=32 kernel_bf16(a_dev, b_dev, c_dev, d_dev)
+        d = Array(d_dev)
+
+        new_a = (a_layout == ColMajor) ? a : transpose(a)
+        new_b = (b_layout == ColMajor) ? b : transpose(b)
+        new_c = (c_layout == ColMajor) ? c : transpose(c)
+        new_d = (d_layout == ColMajor) ? d : transpose(d)
+
+        if do_mac
+            @test Float32.(new_a) * Float32.(new_b) + new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
+        else
+            @test Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
+        end
+    end
+end
+end
+
+# BFloat16 fragment broadcasting requires native bf16 scalar ops (CC 8.9+)
+# On earlier architectures, frag[i] returns UInt32 (packed), causing type mismatch
+if capability(device()) >= v"8.9"
+@testset "CUDA C-style API (BFloat16 with scaling)" begin
+    @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
+        b_layout in [ColMajor, RowMajor],
+        c_layout in [ColMajor, RowMajor],
+        d_layout in [ColMajor, RowMajor],
+        do_mac in [true, false]
+
+        a = rand(BFloat16, (16, 16))
+        b = rand(BFloat16, (16, 16))
+        c = rand(Float32, (16, 16))
+        d = Array{Float32}(undef, (16, 16))
+
+        a_dev = CuArray(a)
+        b_dev = CuArray(b)
+        c_dev = CuArray(c)
+        d_dev = CuArray(d)
+
+        alpha = rand(BFloat16)
+        beta = rand(Float32)
+
+        @eval function kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
+            conf = Config{16, 16, 16, Float32}
+
+            a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
+            b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
+
+            if $do_mac
+                c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
+            else
+                c_frag = fill_c(Float32(0), conf)
+            end
+
+            a_frag = alpha .* a_frag
+            c_frag = beta .* c_frag
+
+            d_frag = mma(a_frag, b_frag, c_frag, conf)
+
+            store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
+
+            return
+        end
+
+        @cuda threads=32 kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
+        d = Array(d_dev)
+
+        new_a = (a_layout == ColMajor) ? a : transpose(a)
+        new_b = (b_layout == ColMajor) ? b : transpose(b)
+        new_c = (c_layout == ColMajor) ? c : transpose(c)
+        new_d = (d_layout == ColMajor) ? d : transpose(d)
+
+        if do_mac
+            @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) + beta * new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
+        else
+            @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
+        end
+    end
+end
+end
+
+################################################################################
+
 @testset "Codegen addressing" begin
     @testset "Global" begin
         function kernel(d)