diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 5397bce010..77f09e6fa1 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -1490,18 +1490,30 @@ def kernel(X):
                           for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
                           for axis in [0, 1]
                           for num_ctas in num_ctas_list
-                          for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
+                          for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_interpreter() and dtype_x_str == 'float16':
+        pytest.skip('float16 atomic_add does not work in the interpreter mode')
     shape0, shape1 = shape
     # triton kernel
 
     @triton.jit
-    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+
+        if DTYPE == tl.float16:
+            # sum can have bad numerics when accumulating in float16.
+            # if we're dealing with float16, do the sum in float32.
+            x = x.to(tl.float32)
+
         z = tl.sum(x, axis=AXIS)
+
+        if DTYPE == tl.float16:
+            z = z.to(DTYPE)
+
         if AXIS == 1:
             old = tl.atomic_add(Z + off0, z)
             tl.store(OLD + off0, old)
@@ -1515,13 +1527,23 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
     old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
     # reference results
-    z_ref = z + np.sum(x, axis=axis, keepdims=False)
+    if x.dtype == np.float16:
+        # do the sum in float32 to reduce numerical variation
+        z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype)
+    else:
+        z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
     x_tri = to_triton(x, device=device)
     z_tri = to_triton(z, device=device)
     old_tri = to_triton(old, device=device)
-    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)
+
+    def torch_to_triton_dtype(t):
+        if t == torch.float16:
+            return tl.float16
+        return None
+
+    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     np.testing.assert_equal(old_ref, to_numpy(old_tri))
 
diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py
index 77be3a2331..cafca3e123 100644
--- a/python/triton/compiler/compiler.py
+++ b/python/triton/compiler/compiler.py
@@ -15,6 +15,7 @@
 import re
 import functools
 import os
+import sysconfig
 
 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 #   and any following whitespace
@@ -151,7 +152,8 @@ def triton_key():
 
     # backend
     libtriton_hash = hashlib.sha256()
-    with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    with open(os.path.join(TRITON_PATH, f"_C/libtriton.{ext}"), "rb") as f:
         while True:
             chunk = f.read(1024**2)
             if not chunk:
diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir
index ef67338457..de0eb140e2 100644
--- a/test/Conversion/amd/tritongpu_to_llvm.mlir
+++ b/test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -62,3 +62,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
"triton_gpu.num-warps" = 4 : tt.return } } + +// ----- + +#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_f16 + tt.func @atomic_add_f16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked1> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked1>, tensor<256xi32, #blocked1> + // CHECK: llvm.cond_br + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16> + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1> + tt.return + } +} + +// ----- + +#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}> +module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} { + // CHECK-LABEL: atomic_add_bf16 + tt.func @atomic_add_bf16(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) { + %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2> + %base_ptr = tt.splat %arg0 : !tt.ptr -> tensor<256x!tt.ptr, #blocked2> + %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr, #blocked2>, tensor<256xi32, #blocked2> + // CHECK: llvm.cond_br + // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16> + %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index a45efd4a79..5265f631ad 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -768,7 +768,11 @@ struct AtomicRMWOpConversion // tensor if (tensorTy) { auto valTy = cast(val.getType()); - vec = std::min(vec, valTy.getElementType().isF16() ? 2 : 1); + Type elTy = valTy.getElementType(); + vec = std::min(vec, llvm::isa(elTy) && + elTy.getIntOrFloatBitWidth() == 16 + ? 2 + : 1); // mask numElems = tensorTy.getNumElements(); } @@ -783,13 +787,22 @@ struct AtomicRMWOpConversion auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; SmallVector resultVals(elemsPerThread); - const bool f16v2 = vec == 2 && valueElemTy.isF16(); for (size_t i = 0; i < elemsPerThread; i += vec) { Value rmwPtr = ptrElements[i]; // TODO: in case llMask is zero we can create only one branch for all // elemsPerThread. Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask; + Value operand; + if (vec == 1) { + operand = valElements[i]; + } else { + operand = undef(vecTy); + for (size_t ii = 0; ii < vec; ++ii) + operand = + insert_element(vecTy, operand, valElements[i + ii], i32_val(ii)); + } + Value undefVal = undef(retType); // Build blocks to bypass the atomic instruction for ~rmwMask. 
       auto *curBlock = rewriter.getInsertionBlock();
@@ -806,25 +819,11 @@ struct AtomicRMWOpConversion
       auto maybeKind = matchAtomicOp(atomicRmwAttr);
       // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
       // atomics for MI-* series of AMD GPU.
-      Value atom = rewriter
-                       .create<LLVM::AtomicRMWOp>(
-                           loc, *maybeKind, rmwPtr, valElements[i],
-                           atomicMemOrdering, StringRef("agent"))
-                       .getResult();
-
-      // NV for the f16v2 case generates one packed instruction. We have to
-      // create two separate instructions since LLVM::AtomicRMWOp doesn't
-      // support this. Can be optimized out with rocdl.raw.buffer.atomic.
-      if (f16v2) {
-        Value atom2 =
-            rewriter
-                .create<LLVM::AtomicRMWOp>(
-                    loc, *maybeKind, ptrElements[i + 1], valElements[i + 1],
-                    atomicMemOrdering, StringRef("agent"))
-                .getResult();
-        auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
-        atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
-      }
+      Value atom =
+          rewriter
+              .create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr, operand,
+                                         atomicMemOrdering, StringRef("agent"))
+              .getResult();
       if (!tensorTy) {
         if (atomicNeedsSharedMemory(op.getResult())) {
           Value atomPtr =
diff --git a/third_party/nvidia/backend/driver.py b/third_party/nvidia/backend/driver.py
index 38ce62b0c2..fa832f68ef 100644
--- a/third_party/nvidia/backend/driver.py
+++ b/third_party/nvidia/backend/driver.py
@@ -1,5 +1,6 @@
 import functools
 import os
+import sysconfig
 import hashlib
 import subprocess
 import tempfile
@@ -48,7 +49,8 @@ def library_dirs():
 def compile_module_from_src(src, name):
     key = hashlib.sha256(src.encode("utf-8")).hexdigest()
     cache = get_cache_manager(key)
-    cache_path = cache.get_file(f"{name}.so")
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    cache_path = cache.get_file(f"{name}.{ext}")
     if cache_path is None:
         with tempfile.TemporaryDirectory() as tmpdir:
             src_path = os.path.join(tmpdir, "main.c")
@@ -56,7 +58,7 @@ def compile_module_from_src(src, name):
                 f.write(src)
             so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
             with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
+                cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
     import importlib.util
     spec = importlib.util.spec_from_file_location(name, cache_path)
     mod = importlib.util.module_from_spec(spec)