# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #

from gpu.host import DeviceContext
from gpu.mma import mma
from testing import assert_true


fn test_mma_bf16_kernel(c_ptr: UnsafePointer[Float32]):
    """BF16×BF16+FP32→FP32 MMA test kernel.

    This test performs a matrix multiply-accumulate using BF16 inputs and an
    FP32 accumulator. BFloat16 (BF16) is critical for modern LLM inference: it
    is used by Llama 3, Mixtral, and most contemporary transformer models.

    On different GPU architectures, this operation maps to:
    - NVIDIA: Uses tensor core wmma or mma.sync instructions
    - AMD CDNA: Uses mfma instructions
    - AMD RDNA3+: Uses v_wmma_f32_16x16x16_bf16 instructions
    - AMD RDNA1/2: Falls back to scalar operations (no WMMA support)

    IMPORTANT - RDNA3 WMMA Bug (Fixed October 2025):
    RDNA3 WMMA instructions were broken in all LLVM versions 15.0.0-22.0.0git
    for compute kernels (the amdgpu_kernel calling convention). Graphics
    shaders worked, but HIP/ROCm compute kernels failed with "Cannot select
    intrinsic".

    Mojo 25.5.0's LLVM is confirmed to have this bug: `mojo build -o llvm`
    fails during IR generation, which prevents workarounds via an external
    llc.

    LLVM Fix Status:
    Submitted upstream: https://github.com/llvm/llvm-project/pull/164036
    Expected path: Modular will backport the fix to Mojo's LLVM

    This test requires one of the following:
    1. LLVM 23+ with the upstreamed fix (after the PR merges), OR
    2. Mojo's LLVM with the backported fix (expected), OR
    3. ROCm's LLVM (TheRock), which already has the fix

    See RDNA3_WMMA_PROJECT_STATUS.md for complete details.

    The test validates that the mma() intrinsic correctly lowers to the
    appropriate hardware instructions for the target platform.

    Why BF16 is Important:
    BF16 keeps FP32's 8-bit exponent (and so its dynamic range) while using
    half the bits, making it ideal for deep learning; see the
    _bf16_roundtrip_demo sketch below this kernel for a concrete precision
    example. Major models using BF16:
    - Meta Llama 3.1/3.2 (8B, 70B, 405B)
    - Mistral 7B v0.3 / Mixtral 8x7B / 8x22B
    - Google Gemma 2B/7B
    - IBM Granite 3.0 8B/20B

    Args:
        c_ptr: Output buffer for results (4 FP32 values).
    """
    var a_reg = SIMD[DType.bfloat16, 4](1.0, 2.0, 3.0, 4.0)
    var b_reg = SIMD[DType.bfloat16, 4](1.0, 1.0, 1.0, 1.0)
    var c_reg = SIMD[DType.float32, 4](0.0, 0.0, 0.0, 0.0)
    var d_reg = SIMD[DType.float32, 4](0.0, 0.0, 0.0, 0.0)

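    # d = a * b + c with FP32 accumulation; on supported targets this lowers
    # to a single tensor-core / WMMA / MFMA instruction.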
    mma(d_reg, a_reg, b_reg, c_reg)

    # Every lane writes the same four slots; this smoke test only checks that
    # the kernel ran and wrote the buffer, not specific numeric values.
    c_ptr[0] = d_reg[0]
    c_ptr[1] = d_reg[1]
    c_ptr[2] = d_reg[2]
    c_ptr[3] = d_reg[3]


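# Illustrative sketch, not exercised by the test: BF16 reuses FP32's 8
# exponent bits but truncates the mantissa from 23 bits to 7, so it keeps
# FP32's dynamic range while giving up precision. This helper is an assumed
# addition for documentation purposes only (not part of the original test);
# it shows the effect with a round trip through BF16 using the standard
# SIMD cast method.
fn _bf16_roundtrip_demo() -> Float32:
    # 3.0 is exactly representable in BF16 and survives the round trip;
    # 3.14159 rounds to the nearest BF16 value, approximately 3.140625.
    var x = Float32(3.14159)
    var as_bf16 = x.cast[DType.bfloat16]()
    return as_bf16.cast[DType.float32]()

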
def main():
    """Test the BF16 matrix multiply-accumulate operation."""
    with DeviceContext() as ctx:
        var c_device = ctx.enqueue_create_buffer[DType.float32](4)
        var c_host = UnsafePointer[Float32].alloc(4)

        # Seed the output with a sentinel so we can tell whether the kernel
        # actually wrote to it.
        for i in range(4):
            c_host[i] = -1.0

        ctx.enqueue_copy(c_device, c_host)

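        # Launch one block of 64 threads: a full 64-lane wavefront on AMD
        # CDNA, two wave32 waves on RDNA3, or two warps on NVIDIA.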
        alias kernel = test_mma_bf16_kernel

        ctx.enqueue_function_checked[kernel, kernel](
            c_device,
            grid_dim=1,
            block_dim=64,
        )

        ctx.enqueue_copy(c_host, c_device)
        ctx.synchronize()

        # The kernel accumulates sums of positive products, so every slot
        # should have been overwritten away from the -1.0 sentinel.
        for i in range(4):
            assert_true(c_host[i] != -1.0)

        # Keep the device buffer alive until the enqueued work has completed.
        _ = c_device
        c_host.free()