Conversation

@AntonOresten
Contributor

Supersedes #1425

BFloat16 WMMA intrinsics (CC 8.0+) follow a different naming scheme than Float16. Like integer WMMA, the BFloat16 MMA intrinsics are named by their input type (.bf16) rather than their accumulator types (.f32.f32), and they appear to support only Float32 accumulation.
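
For example, for one shape/layout combination (intrinsic names constructed the same way as in src/device/intrinsics/wmma.jl):

llvm.nvvm.wmma.m16n16k16.mma.col.col.bf16      # bf16 (and integer) MMA: named by the A/B input type
llvm.nvvm.wmma.m16n16k16.mma.col.col.f32.f32   # f16 MMA: named by the D/C accumulator types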

BFloat16 fragments are also packed differently, with two BFloat16 values per UInt32 (versus Float16's <2 x half> vectors), requiring custom flatten_bf16/unflatten_bf16 functions. I tried a simpler reinterpret, but that was too dynamic, I suppose.
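
For reference, a minimal sketch of the two-per-register packing, mirroring the pack_bf16/unpack_bf16 helpers added in this PR:

using BFloat16s: BFloat16

# Two BFloat16 values share one UInt32: the first in the low 16 bits, the second in the high 16 bits.
pack_bf16(lo::BFloat16, hi::BFloat16) =
    UInt32(reinterpret(UInt16, lo)) | (UInt32(reinterpret(UInt16, hi)) << 16)
unpack_bf16(x::UInt32) =
    (reinterpret(BFloat16, UInt16(x & 0xFFFF)), reinterpret(BFloat16, UInt16(x >> 16)))

# Round-trip check: unpack_bf16(pack_bf16(BFloat16(1), BFloat16(2))) == (BFloat16(1.0), BFloat16(2.0))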

@github-actions
Contributor

github-actions bot commented Jan 3, 2026

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/device/intrinsics/wmma.jl b/src/device/intrinsics/wmma.jl
index 50cb21f6b..a3ddb92d8 100644
--- a/src/device/intrinsics/wmma.jl
+++ b/src/device/intrinsics/wmma.jl
@@ -4,7 +4,7 @@ module WMMA
 import ..LLVM
 using ..CUDA: AS
 using Core: LLVMPtr
-using BFloat16s: BFloat16
+    using BFloat16s: BFloat16
 
 ################################################################################
 # CONSTANTS
@@ -16,7 +16,7 @@ const map_ptx_to_jl_array = Dict(
                                  "s8"  => Int8,
                                  "s32" => Int32,
                                  "f16" => Float16,
-                                 "bf16" => BFloat16,
+        "bf16" => BFloat16,
                                  "f32" => Float32
                                 )
 
@@ -26,7 +26,7 @@ const map_ptx_to_jl_frag = Dict(
                                 "s8"  => UInt32,
                                 "s32" => Int32,
                                 "f16" => NTuple{2, VecElement{Float16}},
-                                "bf16" => UInt32,
+        "bf16" => UInt32,
                                 "f32" => Float32
                                )
 
@@ -45,9 +45,9 @@ const map_frag_sizes = Dict(
                             "a.f16.m8n32k16"  => 8,
                             "a.f16.m32n8k16"  => 8,
 
-                            "a.bf16.m16n16k16" => 4,
-                            "a.bf16.m8n32k16"  => 2,
-                            "a.bf16.m32n8k16"  => 8,
+        "a.bf16.m16n16k16" => 4,
+        "a.bf16.m8n32k16" => 2,
+        "a.bf16.m32n8k16" => 8,
                             # B
                             "b.u8.m16n16k16"  => 2,
                             "b.u8.m8n32k16"   => 4,
@@ -61,9 +61,9 @@ const map_frag_sizes = Dict(
                             "b.f16.m8n32k16"  => 8,
                             "b.f16.m32n8k16"  => 8,
 
-                            "b.bf16.m16n16k16" => 4,
-                            "b.bf16.m8n32k16"  => 8,
-                            "b.bf16.m32n8k16"  => 2,
+        "b.bf16.m16n16k16" => 4,
+        "b.bf16.m8n32k16" => 8,
+        "b.bf16.m32n8k16" => 2,
                             # C
                             "c.s32.m16n16k16" => 8,
                             "c.s32.m8n32k16"  => 8,
@@ -107,13 +107,14 @@ const wmma_half_ops    = [(16,16,16), (32,8,16), (8,32,16)], ["f16"], ["f16", "f
 const ldst_int_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["u8", "s8"]
 const ldst_int_cd_ops = [(16,16,16), (32,8,16), (8,32,16)], ["c", "d"], ["s32"]
 const wmma_int_ops    = [(16,16,16), (32,8,16), (8,32,16)], ["s8", "u8"], ["s32"], ["s32"]
-# BFloat16 (requires Ampere+, only f32 accumulator supported)
-const ldst_bf16_ab_ops = [(16,16,16), (32,8,16), (8,32,16)], ["a", "b"], ["bf16"]
-const wmma_bf16_ops    = [(16,16,16), (32,8,16), (8,32,16)], ["bf16"], ["f32"], ["f32"]
+    # BFloat16 (requires Ampere+, only f32 accumulator supported)
+    const ldst_bf16_ab_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["a", "b"], ["bf16"]
+    const wmma_bf16_ops = [(16, 16, 16), (32, 8, 16), (8, 32, 16)], ["bf16"], ["f32"], ["f32"]
 
 const all_ldst_ops = vcat(ldst_half_ab_ops, ldst_half_cd_ops,
-                          ldst_int_ab_ops,  ldst_int_cd_ops, ldst_bf16_ab_ops)
-const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)
+        ldst_int_ab_ops, ldst_int_cd_ops, ldst_bf16_ab_ops
+    )
+    const all_wmma_ops = vcat(wmma_half_ops, wmma_int_ops, wmma_bf16_ops)
 
 # Valid WMMA operation shapes
 const valid_shapes = [(16, 16, 16), (32, 8, 16), (8, 32, 16)]
@@ -333,12 +334,12 @@ for ops in all_wmma_ops,
     shape = get_hl_shape(mnk[1], mnk[2], mnk[3])
 
     # Name of the LLVM intrinsic
-    # If integer/sub-byte/bit/bf16 A/B types, name is determined by A/B types
-    if d_elem_type == "s32" || a_elem_type == "bf16"
+        # If integer/sub-byte/bit/bf16 A/B types, name is determined by A/B types
+        if d_elem_type == "s32" || a_elem_type == "bf16"
         llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$a_elem_type"
         # Name of the Julia wrapper function
         func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, a_elem_type]), "_"))
-    else # f16: Name defined by D/C types
+        else # f16: Name defined by D/C types
         llvm_intr = "llvm.nvvm.wmma.$shape.mma.$a_layout.$b_layout.$d_elem_type.$c_elem_type"
         # Name of the Julia wrapper function
         func_name = Symbol(join(filter(!isempty, ["llvm", "wmma", "mma", a_layout, b_layout, shape, d_elem_type, c_elem_type]), "_"))
@@ -407,27 +408,29 @@ end
 @generated flatten(x::typ) where typ = Expr(:tuple, flatten_recurse(typ, :x)...)
 @generated unflatten(::Type{typ}, x) where typ = unflatten_recurse(typ, :x, 1)[1]
 
-# BFloat16 packing/unpacking (UInt32 contains 2x BFloat16)
-@inline function unpack_bf16(x::UInt32)
-    lo = reinterpret(BFloat16, UInt16(x & 0xFFFF))
-    hi = reinterpret(BFloat16, UInt16(x >> 16))
-    return (lo, hi)
-end
+    # BFloat16 packing/unpacking (UInt32 contains 2x BFloat16)
+    @inline function unpack_bf16(x::UInt32)
+        lo = reinterpret(BFloat16, UInt16(x & 0xFFFF))
+        hi = reinterpret(BFloat16, UInt16(x >> 16))
+        return (lo, hi)
+    end
 
-@inline function pack_bf16(lo::BFloat16, hi::BFloat16)
-    return UInt32(reinterpret(UInt16, lo)) | (UInt32(reinterpret(UInt16, hi)) << 16)
-end
+    @inline function pack_bf16(lo::BFloat16, hi::BFloat16)
+        return UInt32(reinterpret(UInt16, lo)) | (UInt32(reinterpret(UInt16, hi)) << 16)
+    end
 
-@inline function flatten_bf16(x::NTuple{N, UInt32}) where N
-    ntuple(i -> begin
-        lo, hi = unpack_bf16(x[(i+1)÷2])
-        isodd(i) ? lo : hi
-    end, Val(2N))
-end
+    @inline function flatten_bf16(x::NTuple{N, UInt32}) where {N}
+        return ntuple(
+            i -> begin
+                lo, hi = unpack_bf16(x[(i + 1) ÷ 2])
+                isodd(i) ? lo : hi
+            end, Val(2N)
+        )
+    end
 
-@inline function unflatten_bf16(x::NTuple{N, BFloat16}) where N
-    ntuple(i -> pack_bf16(x[2i-1], x[2i]), Val(N÷2))
-end
+    @inline function unflatten_bf16(x::NTuple{N, BFloat16}) where {N}
+        return ntuple(i -> pack_bf16(x[2i - 1], x[2i]), Val(N ÷ 2))
+    end
 
 ################################################################################
 # HIGH LEVEL (CUDA-STYLE API)
@@ -549,8 +552,8 @@ const map_layout_ty_to_str = Dict(
 const map_num_elems = Dict(
                            ("a", Float16) => 16,
                            ("b", Float16) => 16,
-                           ("a", BFloat16) => 8,
-                           ("b", BFloat16) => 8,
+        ("a", BFloat16) => 8,
+        ("b", BFloat16) => 8,
                            ("c", Float16) => 8,
                            ("c", Float32) => 8,
                            ("d", Float16) => 8,
@@ -652,9 +655,9 @@ for mat in ["a", "b", "c"]
         # Name of the Julia wrapper
         wrapper = Symbol(join(filter(!isempty, ["llvm", "wmma", "load", $mat, layout, shape, as_str, "stride", arr_str]), "_"))
 
-        _flatten = T == BFloat16 ? flatten_bf16 : flatten
+            _flatten = T == BFloat16 ? flatten_bf16 : flatten
         return quote
-            x = $_flatten($wrapper(addr, stride))
+                x = $_flatten($wrapper(addr, stride))
             return Fragment{$M, $N, $K, $num_els, $T, $L_ret, $U}(x)
         end
     end
@@ -695,22 +698,22 @@ mma
     b_layout = get_hl_layout(B_L)
     shape = get_hl_shape(M, N, K)
 
-    _, a_frag_sz, a_frag_ty, a_arr_str = get_hl_frag_info("a", A_T, shape)
+        _, a_frag_sz, a_frag_ty, a_arr_str = get_hl_frag_info("a", A_T, shape)
     _, b_frag_sz, b_frag_ty, _         = get_hl_frag_info("b", B_T, shape)
     _, c_frag_sz, c_frag_ty, c_arr_str = get_hl_frag_info("c", C_T, shape)
     d_num_els, _, _, d_arr_str         = get_hl_frag_info("d", D_T, shape)
 
-    names = ["llvm", "wmma", "mma", a_layout, b_layout, shape]
-    # bf16 uses input type in intrinsic name, f16 uses d/c types
-    A_T === BFloat16 ? push!(names, a_arr_str) : push!(names, d_arr_str, c_arr_str)
-    wrapper = Symbol(join(filter(!isempty, names), "_"))
+        names = ["llvm", "wmma", "mma", a_layout, b_layout, shape]
+        # bf16 uses input type in intrinsic name, f16 uses d/c types
+        A_T === BFloat16 ? push!(names, a_arr_str) : push!(names, d_arr_str, c_arr_str)
+        wrapper = Symbol(join(filter(!isempty, names), "_"))
 
-    a_unfl_expr = A_T === BFloat16 ? :(unflatten_bf16(a.x)) : :(unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x))
-    b_unfl_expr = B_T === BFloat16 ? :(unflatten_bf16(b.x)) : :(unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x))
+        a_unfl_expr = A_T === BFloat16 ? :(unflatten_bf16(a.x)) : :(unflatten(NTuple{$a_frag_sz, $a_frag_ty}, a.x))
+        b_unfl_expr = B_T === BFloat16 ? :(unflatten_bf16(b.x)) : :(unflatten(NTuple{$b_frag_sz, $b_frag_ty}, b.x))
 
     return quote
-        a_unfl = $a_unfl_expr
-        b_unfl = $b_unfl_expr
+            a_unfl = $a_unfl_expr
+            b_unfl = $b_unfl_expr
         c_unfl = unflatten(NTuple{$c_frag_sz, $c_frag_ty}, c.x)
 
         x = flatten($wrapper(a_unfl, b_unfl, c_unfl))
diff --git a/test/core/device/intrinsics/wmma.jl b/test/core/device/intrinsics/wmma.jl
index 3295a62de..11338c4bb 100644
--- a/test/core/device/intrinsics/wmma.jl
+++ b/test/core/device/intrinsics/wmma.jl
@@ -2,7 +2,7 @@ if capability(device()) >= v"7.0"
 
 using CUDA.WMMA
 
-using BFloat16s: BFloat16
+    using BFloat16s: BFloat16
 
 map_ptx_to_jl_frag = Dict(
                             "u8"  => reinterpret(Int32, UInt8(42) * ones(UInt8, 4))[1],
@@ -10,7 +10,7 @@ map_ptx_to_jl_frag = Dict(
                             "u32" => UInt32(42),
                             "s32" => Int32(42),
                             "f16" => ntuple(i -> VecElement{Float16}(42), 2),
-                            "bf16" => reinterpret(UInt32, BFloat16(42) * ones(BFloat16, 2))[1],
+        "bf16" => reinterpret(UInt32, BFloat16(42) * ones(BFloat16, 2))[1],
                             "f32" => Float32(42)
                             )
 # Return specific matrix shape given operation configuration
@@ -51,10 +51,10 @@ end
                                                  startswith(elem_type, "u"))
                 continue
             end
-            # Skip BFloat16 WMMA on pre-Ampere devices
-            if capability(device()) < v"8.0" && elem_type == "bf16"
-                continue
-            end
+                # Skip BFloat16 WMMA on pre-Ampere devices
+                if capability(device()) < v"8.0" && elem_type == "bf16"
+                    continue
+                end
 
             shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
 
@@ -122,10 +122,10 @@ end
                                                  startswith(elem_type, "u"))
                 continue
             end
-            # Skip BFloat16 WMMA on pre-Ampere devices
-            if capability(device()) < v"8.0" && elem_type == "bf16"
-                continue
-            end
+                # Skip BFloat16 WMMA on pre-Ampere devices
+                if capability(device()) < v"8.0" && elem_type == "bf16"
+                    continue
+                end
 
             shape = CUDA.WMMA.get_hl_shape(mnk[1], mnk[2], mnk[3])
 
@@ -186,10 +186,10 @@ end
                                                  startswith(ab_elem_type, "u"))
                 continue
             end
-            # Skip BFloat16 WMMA on pre-Ampere devices
-            if capability(device()) < v"8.0" && ab_elem_type == "bf16"
-                continue
-            end
+                # Skip BFloat16 WMMA on pre-Ampere devices
+                if capability(device()) < v"8.0" && ab_elem_type == "bf16"
+                    continue
+                end
 
             # Type-dependent variables
             d_ty  = CUDA.WMMA.map_ptx_to_jl_array[d_elem_type]
@@ -202,9 +202,9 @@ end
             lda_func = getfield(Main, Symbol("llvm_wmma_load_a_$(a_layout)_$(shape)_global_stride_$(ab_elem_type)"))
             ldb_func = getfield(Main, Symbol("llvm_wmma_load_b_$(b_layout)_$(shape)_global_stride_$(ab_elem_type)"))
             ldc_func = getfield(Main, Symbol("llvm_wmma_load_c_col_$(shape)_global_stride_$(c_elem_type)"))
-            # Account for half and int/subint/bf16 mma different naming conventions
-            # Int/subint and bf16 mma functions are distinguished by the a/b element type
-            mma_sym = (d_ty == Int32 || ab_elem_type == "bf16") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
+                # Account for half and int/subint/bf16 mma different naming conventions
+                # Int/subint and bf16 mma functions are distinguished by the a/b element type
+                mma_sym = (d_ty == Int32 || ab_elem_type == "bf16") ? Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(ab_elem_type)") :
                                       Symbol("llvm_wmma_mma_$(a_layout)_$(b_layout)_$(shape)_$(d_elem_type)_$(c_elem_type)")
             mma_func = getfield(Main, mma_sym)
             std_func = getfield(Main, Symbol("llvm_wmma_store_d_col_$(shape)_global_stride_$(d_elem_type)"))
@@ -242,8 +242,8 @@ end
             # Alter test depending on a/b element Type
             if ab_ty == Float16
                 @test new_a * new_b + c ≈ Array(d_dev) rtol=Base.rtoldefault(Float16)
-            elseif ab_ty == BFloat16
-                @test Float32.(new_a) * Float32.(new_b) + c ≈ Array(d_dev) rtol=Base.rtoldefault(BFloat16)
+                elseif ab_ty == BFloat16
+                    @test Float32.(new_a) * Float32.(new_b) + c ≈ Array(d_dev) rtol = Base.rtoldefault(BFloat16)
             else # Cast a and b to prevent UInt8 rollover of resultant data
                 @test Int32.(new_a) * Int32.(new_b) + c == Array(d_dev)
             end
@@ -274,19 +274,19 @@ end
         @test WMMA.unflatten(NTuple{8, NTuple{2, VecElement{Float16}}}, ntuple(i -> Float16(i), 2 * 8)) == ntuple(i -> ntuple(j -> VecElement{Float16}((i-1) * 2 + j), 2), 8)
     end
 
-    @testset "BFloat16 packing/unpacking" begin
-        bf_vals = ntuple(i -> BFloat16(i), 8)
-        packed = WMMA.unflatten_bf16(bf_vals)
-        @test length(packed) == 4
-        unpacked = WMMA.flatten_bf16(packed)
-        @test unpacked == bf_vals
-    end
+        @testset "BFloat16 packing/unpacking" begin
+            bf_vals = ntuple(i -> BFloat16(i), 8)
+            packed = WMMA.unflatten_bf16(bf_vals)
+            @test length(packed) == 4
+            unpacked = WMMA.flatten_bf16(packed)
+            @test unpacked == bf_vals
+        end
 end
 
 ################################################################################
 
 @testset "Broadcasting over fragments: size=$sz, type=$ty" for sz = [1, 2, 5],
-        ty = [Float16, Float32, BFloat16]
+            ty in [Float16, Float32, BFloat16]
         @test ty(5) .* Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 * i), sz))
         @test ty(5) .+ Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(i), sz)) == Fragment{16, 16, 16, sz, ty, RowMajor, MatrixA}(ntuple(i -> ty(5 + i), sz))
 end
@@ -356,125 +356,125 @@ end
 
 ################################################################################
 
-if capability(device()) >= v"8.0"
-@testset "CUDA C-style API (BFloat16)" begin
-    @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
-        b_layout in [ColMajor, RowMajor],
-        c_layout in [ColMajor, RowMajor],
-        d_layout in [ColMajor, RowMajor],
-        do_mac in [true, false]
-
-        a     = rand(BFloat16, (16, 16))
-        b     = rand(BFloat16, (16, 16))
-        c     = rand(Float32, (16, 16))
-        d     = Array{Float32}(undef, (16, 16))
-
-        a_dev = CuArray(a)
-        b_dev = CuArray(b)
-        c_dev = CuArray(c)
-        d_dev = CuArray(d)
-
-        # Note: BFloat16 fragment broadcasting (alpha .* a_frag) requires native bf16
-        # scalar ops which aren't available on all architectures, so we skip scaling
-        @eval function kernel_bf16(a_dev, b_dev, c_dev, d_dev)
-            conf = Config{16, 16, 16, Float32}
-
-            a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
-            b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
-
-            if $do_mac
-                c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
-            else
-                c_frag = fill_c(Float32(0), conf)
-            end
+    if capability(device()) >= v"8.0"
+        @testset "CUDA C-style API (BFloat16)" begin
+            @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
+                    b_layout in [ColMajor, RowMajor],
+                    c_layout in [ColMajor, RowMajor],
+                    d_layout in [ColMajor, RowMajor],
+                    do_mac in [true, false]
+
+                a = rand(BFloat16, (16, 16))
+                b = rand(BFloat16, (16, 16))
+                c = rand(Float32, (16, 16))
+                d = Array{Float32}(undef, (16, 16))
+
+                a_dev = CuArray(a)
+                b_dev = CuArray(b)
+                c_dev = CuArray(c)
+                d_dev = CuArray(d)
+
+                # Note: BFloat16 fragment broadcasting (alpha .* a_frag) requires native bf16
+                # scalar ops which aren't available on all architectures, so we skip scaling
+                @eval function kernel_bf16(a_dev, b_dev, c_dev, d_dev)
+                    conf = Config{16, 16, 16, Float32}
+
+                    a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
+                    b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
+
+                    if $do_mac
+                        c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
+                    else
+                        c_frag = fill_c(Float32(0), conf)
+                    end
 
-            d_frag = mma(a_frag, b_frag, c_frag, conf)
+                    d_frag = mma(a_frag, b_frag, c_frag, conf)
 
-            store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
+                    store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
 
-            return
-        end
+                    return
+                end
 
-        @cuda threads=32 kernel_bf16(a_dev, b_dev, c_dev, d_dev)
-        d = Array(d_dev)
+                @cuda threads = 32 kernel_bf16(a_dev, b_dev, c_dev, d_dev)
+                d = Array(d_dev)
 
-        new_a = (a_layout == ColMajor) ? a : transpose(a)
-        new_b = (b_layout == ColMajor) ? b : transpose(b)
-        new_c = (c_layout == ColMajor) ? c : transpose(c)
-        new_d = (d_layout == ColMajor) ? d : transpose(d)
+                new_a = (a_layout == ColMajor) ? a : transpose(a)
+                new_b = (b_layout == ColMajor) ? b : transpose(b)
+                new_c = (c_layout == ColMajor) ? c : transpose(c)
+                new_d = (d_layout == ColMajor) ? d : transpose(d)
 
-        if do_mac
-            @test Float32.(new_a) * Float32.(new_b) + new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
-        else
-            @test Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
+                if do_mac
+                    @test Float32.(new_a) * Float32.(new_b) + new_c ≈ new_d rtol = Base.rtoldefault(BFloat16)
+                else
+                    @test Float32.(new_a) * Float32.(new_b) ≈ new_d rtol = Base.rtoldefault(BFloat16)
+                end
+            end
         end
     end
-end
-end
-
-# BFloat16 fragment broadcasting requires native bf16 scalar ops (CC 8.9+)
-# On earlier architectures, frag[i] returns UInt32 (packed), causing type mismatch
-if capability(device()) >= v"8.9"
-@testset "CUDA C-style API (BFloat16 with scaling)" begin
-    @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
-        b_layout in [ColMajor, RowMajor],
-        c_layout in [ColMajor, RowMajor],
-        d_layout in [ColMajor, RowMajor],
-        do_mac in [true, false]
-
-        a     = rand(BFloat16, (16, 16))
-        b     = rand(BFloat16, (16, 16))
-        c     = rand(Float32, (16, 16))
-        d     = Array{Float32}(undef, (16, 16))
-
-        a_dev = CuArray(a)
-        b_dev = CuArray(b)
-        c_dev = CuArray(c)
-        d_dev = CuArray(d)
 
-        alpha = rand(BFloat16)
-        beta  = rand(Float32)
-
-        @eval function kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
-            conf = Config{16, 16, 16, Float32}
-
-            a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
-            b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
-
-            if $do_mac
-                c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
-            else
-                c_frag = fill_c(Float32(0), conf)
-            end
+    # BFloat16 fragment broadcasting requires native bf16 scalar ops (CC 8.9+)
+    # On earlier architectures, frag[i] returns UInt32 (packed), causing type mismatch
+    if capability(device()) >= v"8.9"
+        @testset "CUDA C-style API (BFloat16 with scaling)" begin
+            @testset "$(do_mac ? "MAC" : "MUL"): A: $a_layout, B: $b_layout, C: $c_layout, D: $d_layout" for a_layout in [ColMajor, RowMajor],
+                    b_layout in [ColMajor, RowMajor],
+                    c_layout in [ColMajor, RowMajor],
+                    d_layout in [ColMajor, RowMajor],
+                    do_mac in [true, false]
+
+                a = rand(BFloat16, (16, 16))
+                b = rand(BFloat16, (16, 16))
+                c = rand(Float32, (16, 16))
+                d = Array{Float32}(undef, (16, 16))
+
+                a_dev = CuArray(a)
+                b_dev = CuArray(b)
+                c_dev = CuArray(c)
+                d_dev = CuArray(d)
+
+                alpha = rand(BFloat16)
+                beta = rand(Float32)
+
+                @eval function kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
+                    conf = Config{16, 16, 16, Float32}
+
+                    a_frag = load_a(pointer(a_dev), 16, $a_layout, conf)
+                    b_frag = load_b(pointer(b_dev), 16, $b_layout, conf)
+
+                    if $do_mac
+                        c_frag = load_c(pointer(c_dev), 16, $c_layout, conf)
+                    else
+                        c_frag = fill_c(Float32(0), conf)
+                    end
 
-            a_frag = alpha .* a_frag
-            c_frag = beta .* c_frag
+                    a_frag = alpha .* a_frag
+                    c_frag = beta .* c_frag
 
-            d_frag = mma(a_frag, b_frag, c_frag, conf)
+                    d_frag = mma(a_frag, b_frag, c_frag, conf)
 
-            store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
+                    store_d(pointer(d_dev), d_frag, 16, $d_layout, conf)
 
-            return
-        end
+                    return
+                end
 
-        @cuda threads=32 kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
-        d = Array(d_dev)
+                @cuda threads = 32 kernel_bf16_scaled(a_dev, b_dev, c_dev, d_dev, alpha, beta)
+                d = Array(d_dev)
 
-        new_a = (a_layout == ColMajor) ? a : transpose(a)
-        new_b = (b_layout == ColMajor) ? b : transpose(b)
-        new_c = (c_layout == ColMajor) ? c : transpose(c)
-        new_d = (d_layout == ColMajor) ? d : transpose(d)
+                new_a = (a_layout == ColMajor) ? a : transpose(a)
+                new_b = (b_layout == ColMajor) ? b : transpose(b)
+                new_c = (c_layout == ColMajor) ? c : transpose(c)
+                new_d = (d_layout == ColMajor) ? d : transpose(d)
 
-        if do_mac
-            @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) + beta * new_c ≈ new_d rtol=Base.rtoldefault(BFloat16)
-        else
-            @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) ≈ new_d rtol=Base.rtoldefault(BFloat16)
+                if do_mac
+                    @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) + beta * new_c ≈ new_d rtol = Base.rtoldefault(BFloat16)
+                else
+                    @test Float32(alpha) * Float32.(new_a) * Float32.(new_b) ≈ new_d rtol = Base.rtoldefault(BFloat16)
+                end
+            end
         end
     end
-end
-end
 
-################################################################################
+    ################################################################################
 
 @testset "Codegen addressing" begin
     @testset "Global" begin

@github-actions github-actions bot left a comment

CUDA.jl Benchmarks

Benchmark suite Current: 4a9c9f9 Previous: 5d9474a Ratio
latency/precompile 55528201490 ns 55510377029.5 ns 1.00
latency/ttfp 7811188183.5 ns 7790703567 ns 1.00
latency/import 4133329528 ns 4122189304 ns 1.00
integration/volumerhs 9627263 ns 9624973 ns 1.00
integration/byval/slices=1 147065 ns 147064 ns 1.00
integration/byval/slices=3 426264 ns 425893 ns 1.00
integration/byval/reference 145174 ns 145082 ns 1.00
integration/byval/slices=2 286633 ns 286384 ns 1.00
integration/cudadevrt 103611 ns 103602 ns 1.00
kernel/indexing 14234 ns 14225 ns 1.00
kernel/indexing_checked 15096 ns 14969 ns 1.01
kernel/occupancy 795.8971962616822 ns 732.5227272727273 ns 1.09
kernel/launch 2218.3333333333335 ns 2249.4444444444443 ns 0.99
kernel/rand 17364 ns 18642 ns 0.93
array/reverse/1d 20091 ns 19990 ns 1.01
array/reverse/2dL_inplace 66806 ns 66917 ns 1.00
array/reverse/1dL 70286 ns 70158 ns 1.00
array/reverse/2d 21785 ns 21954 ns 0.99
array/reverse/1d_inplace 9770 ns 9677 ns 1.01
array/reverse/2d_inplace 13317 ns 11077 ns 1.20
array/reverse/2dL 73973 ns 74051.5 ns 1.00
array/reverse/1dL_inplace 66896 ns 66880 ns 1.00
array/copy 20472 ns 20660 ns 0.99
array/iteration/findall/int 157793 ns 158373 ns 1.00
array/iteration/findall/bool 139983 ns 140139 ns 1.00
array/iteration/findfirst/int 160728 ns 161271 ns 1.00
array/iteration/findfirst/bool 161750 ns 162049 ns 1.00
array/iteration/scalar 72145 ns 72812.5 ns 0.99
array/iteration/logical 215522.5 ns 216894.5 ns 0.99
array/iteration/findmin/1d 50436 ns 50981 ns 0.99
array/iteration/findmin/2d 96311 ns 96704 ns 1.00
array/reductions/reduce/Int64/1d 43462 ns 43491 ns 1.00
array/reductions/reduce/Int64/dims=1 55142 ns 52642.5 ns 1.05
array/reductions/reduce/Int64/dims=2 61310 ns 61484 ns 1.00
array/reductions/reduce/Int64/dims=1L 89219 ns 88879 ns 1.00
array/reductions/reduce/Int64/dims=2L 87813 ns 87977 ns 1.00
array/reductions/reduce/Float32/1d 36984 ns 37248.5 ns 0.99
array/reductions/reduce/Float32/dims=1 43469 ns 43278 ns 1.00
array/reductions/reduce/Float32/dims=2 59972 ns 60066 ns 1.00
array/reductions/reduce/Float32/dims=1L 52580 ns 52282 ns 1.01
array/reductions/reduce/Float32/dims=2L 72503 ns 72365.5 ns 1.00
array/reductions/mapreduce/Int64/1d 43417 ns 43561 ns 1.00
array/reductions/mapreduce/Int64/dims=1 44553 ns 44306 ns 1.01
array/reductions/mapreduce/Int64/dims=2 61764 ns 61482 ns 1.00
array/reductions/mapreduce/Int64/dims=1L 89214 ns 89001 ns 1.00
array/reductions/mapreduce/Int64/dims=2L 88378 ns 88320 ns 1.00
array/reductions/mapreduce/Float32/1d 38348 ns 38092.5 ns 1.01
array/reductions/mapreduce/Float32/dims=1 41977.5 ns 41962 ns 1.00
array/reductions/mapreduce/Float32/dims=2 59781 ns 60039 ns 1.00
array/reductions/mapreduce/Float32/dims=1L 52818 ns 52636 ns 1.00
array/reductions/mapreduce/Float32/dims=2L 72504 ns 72310 ns 1.00
array/broadcast 20204 ns 20127 ns 1.00
array/copyto!/gpu_to_gpu 12855 ns 12738 ns 1.01
array/copyto!/cpu_to_gpu 215351 ns 217857 ns 0.99
array/copyto!/gpu_to_cpu 285712 ns 287088 ns 1.00
array/accumulate/Int64/1d 125156 ns 124778 ns 1.00
array/accumulate/Int64/dims=1 83876 ns 83708 ns 1.00
array/accumulate/Int64/dims=2 157963 ns 158367 ns 1.00
array/accumulate/Int64/dims=1L 1720455 ns 1710164 ns 1.01
array/accumulate/Int64/dims=2L 967823 ns 967254 ns 1.00
array/accumulate/Float32/1d 109198 ns 109314 ns 1.00
array/accumulate/Float32/dims=1 81088 ns 80184 ns 1.01
array/accumulate/Float32/dims=2 148055.5 ns 147922 ns 1.00
array/accumulate/Float32/dims=1L 1618605 ns 1618786 ns 1.00
array/accumulate/Float32/dims=2L 700756.5 ns 698724 ns 1.00
array/construct 1247 ns 1295.5 ns 0.96
array/random/randn/Float32 43284 ns 47861 ns 0.90
array/random/randn!/Float32 25002 ns 24875 ns 1.01
array/random/rand!/Int64 27480 ns 27408 ns 1.00
array/random/rand!/Float32 8852.333333333334 ns 8909.666666666666 ns 0.99
array/random/rand/Int64 30038 ns 30055 ns 1.00
array/random/rand/Float32 13305 ns 13184 ns 1.01
array/permutedims/4d 55999 ns 55109 ns 1.02
array/permutedims/2d 54238.5 ns 53832 ns 1.01
array/permutedims/3d 55209 ns 54841 ns 1.01
array/sorting/1d 2757318 ns 2757534 ns 1.00
array/sorting/by 3368851 ns 3344541 ns 1.01
array/sorting/2d 1084980 ns 1081521 ns 1.00
cuda/synchronization/stream/auto 1041.9 ns 1036.5 ns 1.01
cuda/synchronization/stream/nonblocking 7094.5 ns 7410.8 ns 0.96
cuda/synchronization/stream/blocking 812.4375 ns 820.6336633663366 ns 0.99
cuda/synchronization/context/auto 1189.2 ns 1154.3 ns 1.03
cuda/synchronization/context/nonblocking 7663 ns 7124.4 ns 1.08
cuda/synchronization/context/blocking 912.2083333333334 ns 887.4107142857143 ns 1.03

This comment was automatically generated by workflow using github-action-benchmark.

@AntonOresten
Contributor Author

AntonOresten commented Jan 3, 2026

Tests fail on CC 8.0. I've narrowed it down locally: the C-style API for BFloat16 fails up through CC 8.6 and starts passing on CC 8.9; it also passes on CC 12.0.

EDIT: Fixed. The alpha .* a_frag broadcasting failed because bf16 fragments from load_a store packed UInt32 values internally, so the broadcast hit a BFloat16 * UInt32 type mismatch. I removed scaling from the bf16 test and added a separate CC 8.9+ test for it. Core WMMA should work on CC 8.0+.
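
For reference, a sketch of the scaling pattern in question, assuming the same setup as the CC 8.9+ test above (a_dev a 16×16 BFloat16 CuArray, alpha a BFloat16 scalar):

conf = Config{16, 16, 16, Float32}
a_frag = load_a(pointer(a_dev), 16, ColMajor, conf)  # BFloat16 fragment; on CC < 8.9 its elements are the packed UInt32 values
a_frag = alpha .* a_frag                             # this broadcast is what mismatched on pre-8.9 devices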

@AntonOresten
Contributor Author

Testing the example from #1425:

julia> function kernel_wmma_bf16_lowlevel(a_dev, b_dev, c_dev, d_dev)
           a_frag = WMMA.llvm_wmma_load_a_col_m16n16k16_global_stride_bf16(pointer(a_dev), 16)
           b_frag = WMMA.llvm_wmma_load_b_col_m16n16k16_global_stride_bf16(pointer(b_dev), 16)
           c_frag = WMMA.llvm_wmma_load_c_col_m16n16k16_global_stride_f32(pointer(c_dev), 16)
           d_frag = WMMA.llvm_wmma_mma_col_col_m16n16k16_bf16(a_frag, b_frag, c_frag)
           WMMA.llvm_wmma_store_d_col_m16n16k16_global_stride_f32(pointer(d_dev), d_frag, 16)
           return nothing
       end
kernel_wmma_bf16_lowlevel (generic function with 1 method)

julia> function call_kernel()
           m = n = k = 16
           dtype_a = dtype_b = CUDA.BFloat16
           dtype_c = dtype_d = Float32

           d_a = CUDA.rand(dtype_a, m, k)
           d_b = CUDA.rand(dtype_b, k, n)
           d_c = CUDA.rand(dtype_c, m, n)
           d_d = CUDA.zeros(dtype_d, m, n)

           CUDA.@sync @cuda threads=32 kernel_wmma_bf16_lowlevel(d_a, d_b, d_c, d_d)

           expected = Float32.(Array(d_a)) * Float32.(Array(d_b)) + Array(d_c)
           actual = Array(d_d)
           return (; err=maximum(abs, expected - actual), expected, actual)
       end
call_kernel (generic function with 1 method)

julia> call_kernel()
(err = 4.7683716f-7, expected = Float32[4.347838 3.3425827 … 4.157455 3.3207057; 6.44958 4.2143664 … 5.513032 3.220647; … ; 3.962544 2.700366 … 4.6662664 2.5463972; 3.6238153 3.7428212 … 4.1363244 2.9692328], actual = Float32[4.347838 3.3425827 … 4.157455 3.3207054; 6.4495797 4.214366 … 5.513032 3.2206469; … ; 3.9625437 2.700366 … 4.666266 2.5463972; 3.6238153 3.7428212 … 4.1363244 2.9692328])

Carsten's errors from 4 years ago appear to have come from BFloat16 initialization issues in Julia 1.8; BFloat16 seems generally stable now on Julia 1.12.
Beyond the initialization issues, #1425 had map_ptx_to_jl_frag["bf16"] = Float32, but BFloat16 WMMA fragments use a packed i32 representation (like integer WMMA), not floats. This PR uses UInt32 instead, matching the LLVM intrinsic signatures.
Tested on CC 8.6, 8.9, and 12.0, all producing correct results with ~5e-7 max error (expected for bf16→f32 tensor core precision).
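
For context, the relevant fragment-storage mapping (an excerpt of map_ptx_to_jl_frag from src/device/intrinsics/wmma.jl; the frag_storage name below is just for illustration, and other entries are omitted):

frag_storage = Dict(
    "f16"  => NTuple{2, VecElement{Float16}},  # <2 x half> vector per register
    "bf16" => UInt32,                          # two BFloat16 values packed per register
    "f32"  => Float32,
)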

@codecov

codecov bot commented Jan 3, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 89.33%. Comparing base (ca67075) to head (4a9c9f9).
⚠️ Report is 5 commits behind head on master.

Additional details and impacted files
@@             Coverage Diff             @@
##           master    #3009       +/-   ##
===========================================
+ Coverage   76.53%   89.33%   +12.80%     
===========================================
  Files         148      148               
  Lines       12860    12947       +87     
===========================================
+ Hits         9842    11566     +1724     
+ Misses       3018     1381     -1637     

☔ View full report in Codecov by Sentry.

@maleadt
Member

maleadt commented Jan 3, 2026

LGTM, thanks.

Out of curiosity, what are you using WMMA for? It's not particularly great and hard to maintain across architectures, and I'm hoping we can instead rely on cuTile to target tensor cores more idiomatically.

@maleadt maleadt merged commit 9a7cbd2 into JuliaGPU:master Jan 3, 2026
3 checks passed
@AntonOresten
Contributor Author

AntonOresten commented Jan 3, 2026

@pxl-th and I (mostly him) have been working on https://github.com/FluxML/NNop.jl. I was curious to see if we could utilize tensor cores (FluxML/NNop.jl#17) and he made the effort to get it working for CUDA in FluxML/NNop.jl#26 using WMMA, but without this PR, switching would mean no BFloat16.

I'm currently doing nanoGPT speedrunning in Julia, and am thus trying to maximize FLOPs and tokens/sec.

I am indeed aware of #2991, and I'm eagerly awaiting it! 🙏 I'd be happy to test once it's ready (or any time, really)!
