Add BFloat16 WMMA #3009
Conversation
|
Your PR requires formatting changes to meet the project's style guidelines. The suggested changes:
diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl
index 50cb21f6b..a3ddb92d8 100644
--- a/src/device/intrinsics/wmma.jl
+++ b/src/device/intrinsics/wmma.jl
@@ -4,7 +4,7 @@ module WMMA
import ..LLVM
using ..CUDA: AS
using Core: LLVMPtr
-using BFloat16s: BFloat16
+ using BFloat16s: BFloat16
################################################################################
# CONSTANTS
@@ -16,7 +16,7 @@ const map_ptx_to_jl_array = Dict(
"s8" => Int8,
"s32" => Int32,
"f16" => Float16,
- "bf16" => BFloat16,
+ "bf16" => BFloat16,
"f32" => Float32
)
@@ -26,7 +26,7 @@ const map_ptx_to_jl_frag = Dict(
"s8" => UInt32,
"s32" => Int32,
"f16" => NTuple{2, VecElement{Float16}},
- "bf16" => UInt32,
+ "bf16" => UInt32,
"f32" => Float32
)
@@ -45,9 +45,9 @@ const map_frag_sizes = Dict(
"a.f16.m8n32k16" => 8,
"a.f16.m32n8k16" => 8,
- "a.bf16.m16n16k16" => 4,
- "a.bf16.m8n32k16" => 2,
- "a.bf16.m32n8k16" => 8,
+ "a.bf16.m16n16k16" => 4,
+ "a.bf16.m8n32k16" => 2,
+ "a.bf16.m32n8k16" => 8,
# B
"b.u8.m16n16k16" => 2,
"b.u8.m8n32k16" => 4,
@@ -61,9 +61,9 @@ const map_frag_sizes = Dict(
"b.f16.m8n32k16" => 8,
"b.f16.m32n8k16" => 8,
- "b.bf16.m16n16k16" => 4,
- "b.bf16.m8n32k16" => 8,
- "b.bf16.m32n8k16" => 2,
+ "b.bf16.m16n16k16" => 4,
+ "b.bf16.m8n32k16" => 8,
+ "b.bf16.m32n8k16" => 2,
# C
"c.s32.m16n16k16" => 8,
"c.s32.m8n32k16" => 8,
@@ -107,13 +107,14 @@ const wmma_half_ops = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
const wmma_int_ops = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
-# BFloat16 (requires Ampere+, only f32 accumulator supported)
-const ldst_bf16_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["bf16"]
-const wmma_bf16_ops = [(16,16,16), (32,8,16), (8,32,16)], ["bf16"], ["f32"], ["f32"]
+ # BFloat16 (requires Ampere+, only f32 accumulator supported)
+ const ldst_bf16_ab_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["a", "b"], ["bf16"]
+ const wmma_bf16_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["bf16"], ["f32"], ["f32"]
const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
- ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops)
-const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)
+ ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops
+ )
+ const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)
# Valid WMMA operation shapes
const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
@@ -333,12 +334,12 @@ for ops in all_wmma_ops,
shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
# Name of the LLVM intrinsic
- # If integer/sub-byte/bit/bf16 A/B types, name is determined by A/B types
- if d_elem_type == "s32" || a_elem_type == "bf16"
+ # If integer/sub-byte/bit/bf16 A/B types, name is determined by A/B types
+ if d_elem_type == "s32" || a_elem_type == "bf16"
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
# Name of the Julia wrapper function
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
- else # f16: Name defined by D/C types
+ else # f16: Name defined by D/C types
llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type"
# Name of the Julia wrapper function
func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_"))
@@ -407,27 +408,29 @@ end
@generated flatten(x::typ) where typ = Expr(:tuple, flatten_recurse(typ, :x)...)
@generated unflatten(::Type{typ}, x) where typ = unflatten_recurse(typ, :x, 1)[1]
-# BFloat16 packing/unpacking (UInt32 contains 2x BFloat16)
-@inline function unpack_bf16(x::UInt32)
- lo = reinterpret(BFloat16, UInt16(x & 0xFFFF))
- hi = reinterpret(BFloat16, UInt16(x >> 16))
- return (lo, hi)
-end
+ # BFloat16 packing/unpacking (UInt32 contains 2x BFloat16)
+ @inline function unpack_bf16(x::UInt32)
+ lo = reinterpret(BFloat16, UInt16(x & 0xFFFF))
+ hi = reinterpret(BFloat16, UInt16(x >> 16))
+ return (lo, hi)
+ end
-@inline function pack_bf16(lo::BFloat16, hi::BFloat16)
- return UInt32(reinterpret(UInt16, lo)) | (UInt32(reinterpret(UInt16, hi)) << 16)
-end
+ @inline function pack_bf16(lo::BFloat16, hi::BFloat16)
+ return UInt32(reinterpret(UInt16, lo)) | (UInt32(reinterpret(UInt16, hi)) << 16)
+ end
-@inline function flatten_bf16(x::NTuple{N, UInt32}) where N
- ntuple(i -> begin
- lo, hi = unpack_bf16(x[(i+1)÷2])
- isodd(i) ? lo : hi
- end, Val(2N))
-end
+ @inline function flatten_bf16(x::NTuple{N, UInt32}) where {N}
+ return ntuple(
+ i -> begin
+ lo, hi = unpack_bf16(x[(i + 1) ÷ 2])
+ isodd(i) ? lo : hi
+ end, Val(2N)
+ )
+ end
-@inline function unflatten_bf16(x::NTuple{N, BFloat16}) where N
- ntuple(i -> pack_bf16(x[2i-1], x[2i]), Val(N÷2))
-end
+ @inline function unflatten_bf16(x::NTuple{N, BFloat16}) where {N}
+ return ntuple(i -> pack_bf16(x[2i - 1], x[2i]), Val(N ÷ 2))
+ end
################################################################################
# HIGH LEVEL (CUDA-STYLE API)
@@ -549,8 +552,8 @@ const map_layout_ty_to_str = Dict(
const map_num_elems = Dict(
("a", Float16) => 16,
("b", Float16) => 16,
- ("a", BFloat16) => 8,
- ("b", BFloat16) => 8,
+ ("a", BFloat16) => 8,
+ ("b", BFloat16) => 8,
("c", Float16) => 8,
("c", Float32) => 8,
("d", Float16) => 8,
@@ -652,9 +655,9 @@ for mat in ["a", "b", "c"]
# Name of the Julia wrapper
wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_"))
- _flatten = T == BFloat16 ? flatten_bf16 : flatten
+ _flatten = T == BFloat16 ? flatten_bf16 : flatten
return quote
- x = $_flatten($wrapper(addr, stride))
+ x = $_flatten($wrapper(addr, stride))
return Fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x)
end
end
@@ -695,22 +698,22 @@ mma
b_layout = get_hl_layout(B_L)
shape = get_hl_shape(M, N, K)
- _, a_frag_sz, a_frag_ty, a_arr_str = get_hl_frag_info("a", A_T, shape)
+ _, a_frag_sz, a_frag_ty, a_arr_str = get_hl_frag_info("a", A_T, shape)
_, b_frag_sz, b_frag_ty, _ = get_hl_frag_info("b", B_T, shape)
_, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
d_num_els, _, _, d_arr_str = get_hl_frag_info("d", D_T, shape)
- names = ["llvm", "wmma", "mma", a_layout, b_layout, shape]
- # bf16 uses input type in intrinsic name, f16 uses d/c types
- A_T === BFloat16 ? push!(names, a_arr_str) : push!(names, d_arr_str, c_arr_str)
- wrapper = Symbol(join(filter(!isempty, names), "_"))
+ names = ["llvm", "wmma", "mma", a_layout, b_layout, shape]
+ # bf16 uses input type in intrinsic name, f16 uses d/c types
+ A_T === BFloat16 ? push!(names, a_arr_str) : push!(names, d_arr_str, c_arr_str)
+ wrapper = Symbol(join(filter(!isempty, names), "_"))
- a_unfl_expr = A_T === BFloat16 ? :(unflatten_bf16(a.x)) : :(unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x))
- b_unfl_expr = B_T === BFloat16 ? :(unflatten_bf16(b.x)) : :(unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x))
+ a_unfl_expr = A_T === BFloat16 ? :(unflatten_bf16(a.x)) : :(unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x))
+ b_unfl_expr = B_T === BFloat16 ? :(unflatten_bf16(b.x)) : :(unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x))
return quote
- a_unfl = $a_unfl_expr
- b_unfl = $b_unfl_expr
+ a_unfl = $a_unfl_expr
+ b_unfl = $b_unfl_expr
c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x)
x = flatten($wrapper(a_unfl, b_unfl, c_unfl))
diff --git a/test/core/device/intrinsics/wmma.jl b/test/core/device/intrinsics/wmma.jl
index 3295a62de..11338c4bb 100644
--- a/test/core/device/intrinsics/wmma.jl
+++ b/test/core/device/intrinsics/wmma.jl
@@ -2,7 +2,7 @@ if capability(device()) >= v"7.0"
using CUDA.WMMA
-using BFloat16s: BFloat16
+ using BFloat16s: BFloat16
map_ptx_to_jl_frag = Dict(
"u8" => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
@@ -10,7 +10,7 @@ map_ptx_to_jl_frag = Dict(
"u32" => UInt32(42),
"s32" => Int32(42),
"f16" => ntuple(i -> VecElement{Float16}(42), 2),
- "bf16" => reinterpret(UInt32, BFloat16(42) * ones(BFloat16, 2))[1],
+ "bf16" => reinterpret(UInt32, BFloat16(42) * ones(BFloat16, 2))[1],
"f32" => Float32(42)
)
# Return specific matrix shape given operation configuration
@@ -51,10 +51,10 @@ end
startswith(elem_type, "u"))
continue
end
- # Skip BFloat16 WMMA on pre-Ampere devices
- if capability(device()) < v"8.0" && elem_type == "bf16"
- continue
- end
+ # Skip BFloat16 WMMA on pre-Ampere devices
+ if capability(device()) < v"8.0" && elem_type == "bf16"
+ continue
+ end
shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
@@ -122,10 +122,10 @@ end
startswith(elem_type, "u"))
continue
end
- # Skip BFloat16 WMMA on pre-Ampere devices
- if capability(device()) < v"8.0" && elem_type == "bf16"
- continue
- end
+ # Skip BFloat16 WMMA on pre-Ampere devices
+ if capability(device()) < v"8.0" && elem_type == "bf16"
+ continue
+ end
shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
@@ -186,10 +186,10 @@ end
startswith(ab_elem_type, "u"))
continue
end
- # Skip BFloat16 WMMA on pre-Ampere devices
- if capability(device()) < v"8.0" && ab_elem_type == "bf16"
- continue
- end
+ # Skip BFloat16 WMMA on pre-Ampere devices
+ if capability(device()) < v"8.0" && ab_elem_type == "bf16"
+ continue
+ end
# Type-dependent variables
d_ty = CUDA.WMMA.map_ptx_to_jl_array[d_elem_type]
@@ -202,9 +202,9 @@ end
lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_$(shape)_global_stride_$(ab_elem_type)"))
ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_$(shape)_global_stride_$(ab_elem_type)"))
ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_$(shape)_global_stride_$(c_elem_type)"))
- # Account for half and int/subint/bf16 mma different naming conventions
- # Int/subint and bf16 mma functions are distinguished by the a/b element type
- mma_sym = (d_ty == Int32 || ab_elem_type == "bf16") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
+ # Account for half and int/subint/bf16 mma different naming conventions
+ # Int/subint and bf16 mma functions are distinguished by the a/b element type
+ mma_sym = (d_ty == Int32 || ab_elem_type == "bf16") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)")
mma_func = getfield(Main, mma_sym)
std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)"))
@@ -242,8 +242,8 @@ end
# Alter test depending on a/b element Type
if ab_ty == Float16
@test new_a * new_b + c ≈ Array(d_dev) rtol=Base.rtoldefault(Float16)
- elseif ab_ty == BFloat16
- @test Float32.(new_a) * Float32.(new_b) + c ≈ Array(d_dev) rtol=Base.rtoldefault(BFloat16)
+ elseif ab_ty == BFloat16
+ @test Float32.(new_a) * Float32.(new_b) + c ≈ Array(d_dev) rtol = Base.rtoldefault(BFloat16)
else # Cast a and b to prevent UInt8 rollover of resultant data
@test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev)
end
@@ -274,19 +274,19 @@ end
@test WMMA.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)
end
- @testset "BFloat16 packing/unpacking" begin
- bf_vals = ntuple(i -> BFloat16(i), 8)
- packed = WMMA.unflatten_bf16(bf_vals)
- @test length(packed) == 4
- unpacked = WMMA.flatten_bf16(packed)
- @test unpacked == bf_vals
- end
+ @testset "BFloat16 packing/unpacking" begin
+ bf_vals = ntuple(i -> BFloat16(i), 8)
+ packed = WMMA.unflatten_bf16(bf_vals)
+ @test length(packed) == 4
+ unpacked = WMMA.flatten_bf16(packed)
+ @test unpacked == bf_vals
+ end
end
################################################################################
@testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5],
- ty = [Float16, Float32, BFloat16]
+ ty in [Float16, Float32, BFloat16]
@test ty(5) .* Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 * i), sz))
@test ty(5) .+ Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 + i), sz))
end
@@ -356,125 +356,125 @@ end
################################################################################
-if capability(device()) >= v"8.0"
-@testset "CUDA C-style API (BFloat16)" begin
- @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
- b_layout in [ColMajor, RowMajor],
- c_layout in [ColMajor, RowMajor],
- d_layout in [ColMajor, RowMajor],
- do_mac in [true, false]
-
- a = rand(BFloat16, (16, 16))
- b = rand(BFloat16, (16, 16))
- c = rand(Float32, (16, 16))
- d = Array{Float32}(undef, (16, 16))
-
- a_dev = CuArray(a)
- b_dev = CuArray(b)
- c_dev = CuArray(c)
- d_dev = CuArray(d)
-
- # Note: BFloat16 fragment broadcasting (alpha .* a_frag) requires native bf16
- # scalar ops which aren't available on all architectures, so we skip scaling
- @eval function kernel_bf16(a_dev, b_dev, c_dev, d_dev)
- conf = Config{16, 16, 16, Float32}
-
- a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
- b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
-
- if $do_mac
- c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
- else
- c_frag = fill_c(Float32(0), conf)
- end
+ if capability(device()) >= v"8.0"
+ @testset "CUDA C-style API (BFloat16)" begin
+ @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
+ b_layout in [ColMajor, RowMajor],
+ c_layout in [ColMajor, RowMajor],
+ d_layout in [ColMajor, RowMajor],
+ do_mac in [true, false]
+
+ a = rand(BFloat16, (16, 16))
+ b = rand(BFloat16, (16, 16))
+ c = rand(Float32, (16, 16))
+ d = Array{Float32}(undef, (16, 16))
+
+ a_dev = CuArray(a)
+ b_dev = CuArray(b)
+ c_dev = CuArray(c)
+ d_dev = CuArray(d)
+
+ # Note: BFloat16 fragment broadcasting (alpha .* a_frag) requires native bf16
+ # scalar ops which aren't available on all architectures, so we skip scaling
+ @eval function kernel_bf16(a_dev, b_dev, c_dev, d_dev)
+ conf = Config{16, 16, 16, Float32}
+
+ a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
+ b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
+
+ if $do_mac
+ c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
+ else
+ c_frag = fill_c(Float32(0), conf)
+ end
- d_frag = mma(a_frag, b_frag, c_frag, conf)
+ d_frag = mma(a_frag, b_frag, c_frag, conf)
- store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
+ store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
- return
- end
+ return
+ end
- @cuda threads=32 kernel_bf16(a_dev, b_dev, c_dev, d_dev)
- d = Array(d_dev)
+ @cuda threads = 32 kernel_bf16(a_dev, b_dev, c_dev, d_dev)
+ d = Array(d_dev)
- new_a = (a_layout == ColMajor) ? a : transpose(a)
- new_b = (b_layout == ColMajor) ? b : transpose(b)
- new_c = (c_layout == ColMajor) ? c : transpose(c)
- new_d = (d_layout == ColMajor) ? d : transpose(d)
+ new_a = (a_layout == ColMajor) ? a : transpose(a)
+ new_b = (b_layout == ColMajor) ? b : transpose(b)
+ new_c = (c_layout == ColMajor) ? c : transpose(c)
+ new_d = (d_layout == ColMajor) ? d : transpose(d)
- if do_mac
- @test Float32.(new_a) * Float32.(new_b) + new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
- else
- @test Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
+ if do_mac
+ @test Float32.(new_a) * Float32.(new_b) + new_c ≈ new_d rtol = Base.rtoldefault(BFloat16)
+ else
+ @test Float32.(new_a) * Float32.(new_b) ≈ new_d rtol = Base.rtoldefault(BFloat16)
+ end
+ end
end
end
-end
-end
-
-# BFloat16 fragment broadcasting requires native bf16 scalar ops (CC 8.9+)
-# On earlier architectures, frag[i] returns UInt32 (packed), causing type mismatch
-if capability(device()) >= v"8.9"
-@testset "CUDA C-style API (BFloat16 with scaling)" begin
- @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
- b_layout in [ColMajor, RowMajor],
- c_layout in [ColMajor, RowMajor],
- d_layout in [ColMajor, RowMajor],
- do_mac in [true, false]
-
- a = rand(BFloat16, (16, 16))
- b = rand(BFloat16, (16, 16))
- c = rand(Float32, (16, 16))
- d = Array{Float32}(undef, (16, 16))
-
- a_dev = CuArray(a)
- b_dev = CuArray(b)
- c_dev = CuArray(c)
- d_dev = CuArray(d)
- alpha = rand(BFloat16)
- beta = rand(Float32)
-
- @eval function kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
- conf = Config{16, 16, 16, Float32}
-
- a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
- b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
-
- if $do_mac
- c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
- else
- c_frag = fill_c(Float32(0), conf)
- end
+ # BFloat16 fragment broadcasting requires native bf16 scalar ops (CC 8.9+)
+ # On earlier architectures, frag[i] returns UInt32 (packed), causing type mismatch
+ if capability(device()) >= v"8.9"
+ @testset "CUDA C-style API (BFloat16 with scaling)" begin
+ @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
+ b_layout in [ColMajor, RowMajor],
+ c_layout in [ColMajor, RowMajor],
+ d_layout in [ColMajor, RowMajor],
+ do_mac in [true, false]
+
+ a = rand(BFloat16, (16, 16))
+ b = rand(BFloat16, (16, 16))
+ c = rand(Float32, (16, 16))
+ d = Array{Float32}(undef, (16, 16))
+
+ a_dev = CuArray(a)
+ b_dev = CuArray(b)
+ c_dev = CuArray(c)
+ d_dev = CuArray(d)
+
+ alpha = rand(BFloat16)
+ beta = rand(Float32)
+
+ @eval function kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
+ conf = Config{16, 16, 16, Float32}
+
+ a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
+ b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
+
+ if $do_mac
+ c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
+ else
+ c_frag = fill_c(Float32(0), conf)
+ end
- a_frag = alpha .* a_frag
- c_frag = beta .* c_frag
+ a_frag = alpha .* a_frag
+ c_frag = beta .* c_frag
- d_frag = mma(a_frag, b_frag, c_frag, conf)
+ d_frag = mma(a_frag, b_frag, c_frag, conf)
- store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
+ store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
- return
- end
+ return
+ end
- @cuda threads=32 kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
- d = Array(d_dev)
+ @cuda threads = 32 kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
+ d = Array(d_dev)
- new_a = (a_layout == ColMajor) ? a : transpose(a)
- new_b = (b_layout == ColMajor) ? b : transpose(b)
- new_c = (c_layout == ColMajor) ? c : transpose(c)
- new_d = (d_layout == ColMajor) ? d : transpose(d)
+ new_a = (a_layout == ColMajor) ? a : transpose(a)
+ new_b = (b_layout == ColMajor) ? b : transpose(b)
+ new_c = (c_layout == ColMajor) ? c : transpose(c)
+ new_d = (d_layout == ColMajor) ? d : transpose(d)
- if do_mac
- @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) + beta * new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
- else
- @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
+ if do_mac
+ @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) + beta * new_c ≈ new_d rtol = Base.rtoldefault(BFloat16)
+ else
+ @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) ≈ new_d rtol = Base.rtoldefault(BFloat16)
+ end
+ end
end
end
-end
-end
-################################################################################
+ ################################################################################
@testset "Codegen addressing" begin
@testset "Global" begin |
CUDA.jl Benchmarks
Details
| Benchmark suite | Current: 4a9c9f9 | Previous: 5d9474a | Ratio |
|---|---|---|---|
| latency/precompile | 55528201490 ns | 55510377029.5 ns | 1.00 |
| latency/ttfp | 7811188183.5 ns | 7790703567 ns | 1.00 |
| latency/import | 4133329528 ns | 4122189304 ns | 1.00 |
| integration/volumerhs | 9627263 ns | 9624973 ns | 1.00 |
| integration/byval/slices=1 | 147065 ns | 147064 ns | 1.00 |
| integration/byval/slices=3 | 426264 ns | 425893 ns | 1.00 |
| integration/byval/reference | 145174 ns | 145082 ns | 1.00 |
| integration/byval/slices=2 | 286633 ns | 286384 ns | 1.00 |
| integration/cudadevrt | 103611 ns | 103602 ns | 1.00 |
| kernel/indexing | 14234 ns | 14225 ns | 1.00 |
| kernel/indexing_checked | 15096 ns | 14969 ns | 1.01 |
| kernel/occupancy | 795.8971962616822 ns | 732.5227272727273 ns | 1.09 |
| kernel/launch | 2218.3333333333335 ns | 2249.4444444444443 ns | 0.99 |
| kernel/rand | 17364 ns | 18642 ns | 0.93 |
| array/reverse/1d | 20091 ns | 19990 ns | 1.01 |
| array/reverse/2dL_inplace | 66806 ns | 66917 ns | 1.00 |
| array/reverse/1dL | 70286 ns | 70158 ns | 1.00 |
| array/reverse/2d | 21785 ns | 21954 ns | 0.99 |
| array/reverse/1d_inplace | 9770 ns | 9677 ns | 1.01 |
| array/reverse/2d_inplace | 13317 ns | 11077 ns | 1.20 |
| array/reverse/2dL | 73973 ns | 74051.5 ns | 1.00 |
| array/reverse/1dL_inplace | 66896 ns | 66880 ns | 1.00 |
| array/copy | 20472 ns | 20660 ns | 0.99 |
| array/iteration/findall/int | 157793 ns | 158373 ns | 1.00 |
| array/iteration/findall/bool | 139983 ns | 140139 ns | 1.00 |
| array/iteration/findfirst/int | 160728 ns | 161271 ns | 1.00 |
| array/iteration/findfirst/bool | 161750 ns | 162049 ns | 1.00 |
| array/iteration/scalar | 72145 ns | 72812.5 ns | 0.99 |
| array/iteration/logical | 215522.5 ns | 216894.5 ns | 0.99 |
| array/iteration/findmin/1d | 50436 ns | 50981 ns | 0.99 |
| array/iteration/findmin/2d | 96311 ns | 96704 ns | 1.00 |
| array/reductions/reduce/Int64/1d | 43462 ns | 43491 ns | 1.00 |
| array/reductions/reduce/Int64/dims=1 | 55142 ns | 52642.5 ns | 1.05 |
| array/reductions/reduce/Int64/dims=2 | 61310 ns | 61484 ns | 1.00 |
| array/reductions/reduce/Int64/dims=1L | 89219 ns | 88879 ns | 1.00 |
| array/reductions/reduce/Int64/dims=2L | 87813 ns | 87977 ns | 1.00 |
| array/reductions/reduce/Float32/1d | 36984 ns | 37248.5 ns | 0.99 |
| array/reductions/reduce/Float32/dims=1 | 43469 ns | 43278 ns | 1.00 |
| array/reductions/reduce/Float32/dims=2 | 59972 ns | 60066 ns | 1.00 |
| array/reductions/reduce/Float32/dims=1L | 52580 ns | 52282 ns | 1.01 |
| array/reductions/reduce/Float32/dims=2L | 72503 ns | 72365.5 ns | 1.00 |
| array/reductions/mapreduce/Int64/1d | 43417 ns | 43561 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=1 | 44553 ns | 44306 ns | 1.01 |
| array/reductions/mapreduce/Int64/dims=2 | 61764 ns | 61482 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=1L | 89214 ns | 89001 ns | 1.00 |
| array/reductions/mapreduce/Int64/dims=2L | 88378 ns | 88320 ns | 1.00 |
| array/reductions/mapreduce/Float32/1d | 38348 ns | 38092.5 ns | 1.01 |
| array/reductions/mapreduce/Float32/dims=1 | 41977.5 ns | 41962 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=2 | 59781 ns | 60039 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=1L | 52818 ns | 52636 ns | 1.00 |
| array/reductions/mapreduce/Float32/dims=2L | 72504 ns | 72310 ns | 1.00 |
| array/broadcast | 20204 ns | 20127 ns | 1.00 |
| array/copyto!/gpu_to_gpu | 12855 ns | 12738 ns | 1.01 |
| array/copyto!/cpu_to_gpu | 215351 ns | 217857 ns | 0.99 |
| array/copyto!/gpu_to_cpu | 285712 ns | 287088 ns | 1.00 |
| array/accumulate/Int64/1d | 125156 ns | 124778 ns | 1.00 |
| array/accumulate/Int64/dims=1 | 83876 ns | 83708 ns | 1.00 |
| array/accumulate/Int64/dims=2 | 157963 ns | 158367 ns | 1.00 |
| array/accumulate/Int64/dims=1L | 1720455 ns | 1710164 ns | 1.01 |
| array/accumulate/Int64/dims=2L | 967823 ns | 967254 ns | 1.00 |
| array/accumulate/Float32/1d | 109198 ns | 109314 ns | 1.00 |
| array/accumulate/Float32/dims=1 | 81088 ns | 80184 ns | 1.01 |
| array/accumulate/Float32/dims=2 | 148055.5 ns | 147922 ns | 1.00 |
| array/accumulate/Float32/dims=1L | 1618605 ns | 1618786 ns | 1.00 |
| array/accumulate/Float32/dims=2L | 700756.5 ns | 698724 ns | 1.00 |
| array/construct | 1247 ns | 1295.5 ns | 0.96 |
| array/random/randn/Float32 | 43284 ns | 47861 ns | 0.90 |
| array/random/randn!/Float32 | 25002 ns | 24875 ns | 1.01 |
| array/random/rand!/Int64 | 27480 ns | 27408 ns | 1.00 |
| array/random/rand!/Float32 | 8852.333333333334 ns | 8909.666666666666 ns | 0.99 |
| array/random/rand/Int64 | 30038 ns | 30055 ns | 1.00 |
| array/random/rand/Float32 | 13305 ns | 13184 ns | 1.01 |
| array/permutedims/4d | 55999 ns | 55109 ns | 1.02 |
| array/permutedims/2d | 54238.5 ns | 53832 ns | 1.01 |
| array/permutedims/3d | 55209 ns | 54841 ns | 1.01 |
| array/sorting/1d | 2757318 ns | 2757534 ns | 1.00 |
| array/sorting/by | 3368851 ns | 3344541 ns | 1.01 |
| array/sorting/2d | 1084980 ns | 1081521 ns | 1.00 |
| cuda/synchronization/stream/auto | 1041.9 ns | 1036.5 ns | 1.01 |
| cuda/synchronization/stream/nonblocking | 7094.5 ns | 7410.8 ns | 0.96 |
| cuda/synchronization/stream/blocking | 812.4375 ns | 820.6336633663366 ns | 0.99 |
| cuda/synchronization/context/auto | 1189.2 ns | 1154.3 ns | 1.03 |
| cuda/synchronization/context/nonblocking | 7663 ns | 7124.4 ns | 1.08 |
| cuda/synchronization/context/blocking | 912.2083333333334 ns | 887.4107142857143 ns | 1.03 |
This comment was automatically generated by workflow using github-action-benchmark.
|
Tests fail on CC 8.0. I've narrowed it down locally: the C-style API for BFloat16 fails up through CC 8.6, starts passing on CC 8.9, and also passes on CC 12.0. EDIT: Fixed. The |
|
Testing the example from #1425:

julia> function kernel_wmma_bf16_lowlevel(a_dev, b_dev, c_dev, d_dev)
a_frag = WMMA.llvm_wmma_load_a_col_m16n16k16_global_stride_bf16(pointer(a_dev), 16)
b_frag = WMMA.llvm_wmma_load_b_col_m16n16k16_global_stride_bf16(pointer(b_dev), 16)
c_frag = WMMA.llvm_wmma_load_c_col_m16n16k16_global_stride_f32(pointer(c_dev), 16)
d_frag = WMMA.llvm_wmma_mma_col_col_m16n16k16_bf16(a_frag, b_frag, c_frag)
WMMA.llvm_wmma_store_d_col_m16n16k16_global_stride_f32(pointer(d_dev), d_frag, 16)
return nothing
end
kernel_wmma_bf16_lowlevel (generic function with 1 method)
julia> function call_kernel()
m = n = k = 16
dtype_a = dtype_b = CUDA.BFloat16
dtype_c = dtype_d = Float32
d_a = CUDA.rand(dtype_a, m, k)
d_b = CUDA.rand(dtype_b, k, n)
d_c = CUDA.rand(dtype_c, m, n)
d_d = CUDA.zeros(dtype_d, m, n)
CUDA.@sync @cuda threads=32 kernel_wmma_bf16_lowlevel(d_a, d_b, d_c, d_d)
expected = Float32.(Array(d_a)) * Float32.(Array(d_b)) + Array(d_c)
actual = Array(d_d)
return (; err=maximum(abs, expected - actual), expected, actual)
end
call_kernel (generic function with 1 method)
julia> call_kernel()
(err = 4.7683716f-7, expected = Float32[4.347838 3.3425827 … 4.157455 3.3207057; 6.44958 4.2143664 … 5.513032 3.220647; … ; 3.962544 2.700366 … 4.6662664 2.5463972; 3.6238153 3.7428212 … 4.1363244 2.9692328], actual = Float32[4.347838 3.3425827 … 4.157455 3.3207054; 6.4495797 4.214366 … 5.513032 3.2206469; … ; 3.9625437 2.700366 … 4.666266 2.5463972; 3.6238153 3.7428212 … 4.1363244 2.9692328])

Carsten's errors from 4 years ago appear to have come from BFloat16 initialization issues in Julia 1.8. BFloat16 seems generally stable now on 1.12. |
Codecov Report ✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

@@            Coverage Diff            @@
##           master    #3009       +/-   ##
===========================================
+ Coverage   76.53%   89.33%   +12.80%
===========================================
  Files         148      148
  Lines       12860    12947       +87
===========================================
+ Hits         9842    11566     +1724
+ Misses       3018     1381     -1637

☔ View full report in Codecov by Sentry.
|
|
LGTM, thanks. Out of curiosity, what are you using WMMA for? It's not particularly great, it's hard to maintain across architectures, and I'm hoping we can instead rely on cuTile to target tensor cores more idiomatically. |
|
@pxl-th and I (mostly him) have been working on https://github.com/FluxML/NNop.jl. I was curious to see if we could utilize tensor cores (FluxML/NNop.jl#17), and he made the effort to get it working for CUDA in FluxML/NNop.jl#26 using WMMA, but without this PR, switching would mean no BFloat16. I'm currently doing nanoGPT speedrunning in Julia, and am thus trying to maximize FLOPs and tokens/sec. I am indeed aware of #2991, and I'm eagerly awaiting it! 🙏 Would be happy to test once ready (or any time, really)! |
Supersedes #1425
BFloat16 WMMA intrinsics (CC 8.0+) follow a different naming scheme than Float16. Like integer WMMA, BFloat16 MMA intrinsics are named by their input type (.bf16) rather than their accumulator types (.f32.f32), and they only seem to support Float32 for accumulation.
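A minimal sketch of the naming difference, using the same Symbol/join construction that generates the wrapper names in src/device/intrinsics/wmma.jl (the col/col m16n16k16 combination is just one example, not the full set):

```julia
using CUDA  # wrappers live in CUDA.WMMA

# f16 WMMA: the wrapper name carries the D/C (accumulator) element types
f16_name  = Symbol(join(["llvm", "wmma", "mma", "col", "col", "m16n16k16", "f32", "f32"], "_"))
# bf16 WMMA: the wrapper name carries the A/B (input) element type; the accumulator is always f32
bf16_name = Symbol(join(["llvm", "wmma", "mma", "col", "col", "m16n16k16", "bf16"], "_"))

@assert isdefined(CUDA.WMMA, f16_name)
@assert isdefined(CUDA.WMMA, bf16_name)  # only defined with this PR applied
```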
BFloat16 fragments are also packed differently, with two BFloat16 values per UInt32 (vs Float16's <2 x half> vectors), requiring custom flatten_bf16/unflatten_bf16 functions. I tried a simpler `reinterpret`, but that was too dynamic, I suppose.
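For illustration, a standalone sketch of that packing scheme, mirroring what pack_bf16/unpack_bf16 do in this PR (the values here are arbitrary):

```julia
using BFloat16s: BFloat16

# Two BFloat16 values share one UInt32 slot in a bf16 fragment:
lo, hi = BFloat16(1.5), BFloat16(-2.0)

# pack: low 16 bits hold `lo`, high 16 bits hold `hi`
packed = UInt32(reinterpret(UInt16, lo)) | (UInt32(reinterpret(UInt16, hi)) << 16)

# unpack: mask/shift the halves back out and reinterpret as BFloat16
@assert reinterpret(BFloat16, UInt16(packed & 0xFFFF)) == lo
@assert reinterpret(BFloat16, UInt16(packed >> 16)) == hi
```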