30 changes: 26 additions & 4 deletions python/test/unit/language/test_core.py
@@ -1490,18 +1490,30 @@ def kernel(X):
for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
for axis in [0, 1]
for num_ctas in num_ctas_list
for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
check_type_supported(dtype_x_str, device)
if is_interpreter() and dtype_x_str == 'float16':
pytest.skip('float16 atomic_add does not work in the interpreter mode')
shape0, shape1 = shape
# triton kernel

@triton.jit
def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr):
off0 = tl.arange(0, SHAPE0)
off1 = tl.arange(0, SHAPE1)
x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])

if DTYPE == tl.float16:
# sum can have bad numerics when accumulating in float16.
# if we're dealing with float16, do the sum in float32.
x = x.to(tl.float32)

z = tl.sum(x, axis=AXIS)

if DTYPE == tl.float16:
z = z.to(DTYPE)

if AXIS == 1:
old = tl.atomic_add(Z + off0, z)
tl.store(OLD + off0, old)
@@ -1515,13 +1527,23 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
# reference results
z_ref = z + np.sum(x, axis=axis, keepdims=False)
if x.dtype == np.float16:
# do the sum in float32 to reduce numerical variation
z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype)
else:
z_ref = z + np.sum(x, axis=axis, keepdims=False)
old_ref = np.copy(z)
# triton result
x_tri = to_triton(x, device=device)
z_tri = to_triton(z, device=device)
old_tri = to_triton(old, device=device)
kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)

def torch_to_triton_dtype(t):
if t == torch.float16:
return tl.float16
return None

kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas)
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
np.testing.assert_equal(old_ref, to_numpy(old_tri))

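Note (not part of the diff): the float32 accumulation added above matters because summing many float16 values directly rounds after every partial sum. A minimal NumPy sketch of the effect, with an arbitrary shape chosen for illustration:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((64, 64)).astype(np.float16)

# Accumulating directly in float16 rounds after each partial sum.
sum_f16 = np.sum(x, axis=0, dtype=np.float16)

# Accumulating in float32 and casting back mirrors the kernel and the reference above.
sum_f32 = np.sum(x.astype(np.float32), axis=0).astype(np.float16)

# The two results can differ noticeably for larger reductions, which is why
# both the kernel and the NumPy reference promote to float32 before summing.
print(np.abs(sum_f16.astype(np.float32) - sum_f32.astype(np.float32)).max())
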
4 changes: 3 additions & 1 deletion python/triton/compiler/compiler.py
@@ -15,6 +15,7 @@
import re
import functools
import os
import sysconfig

# - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
# and any following whitespace
@@ -151,7 +152,8 @@ def triton_key():

# backend
libtriton_hash = hashlib.sha256()
with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
with open(os.path.join(TRITON_PATH, f"_C/libtriton.{ext}"), "rb") as f:
while True:
chunk = f.read(1024**2)
if not chunk:
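Note (not part of the diff): sysconfig.get_config_var("EXT_SUFFIX") returns the platform-specific extension-module suffix, so taking the last dot-separated component gives "so" on Linux and "pyd" on Windows. A minimal sketch of the derivation used above:

import os
import sysconfig

# EXT_SUFFIX is e.g. ".cpython-311-x86_64-linux-gnu.so" on Linux
# or ".cp311-win_amd64.pyd" on Windows.
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
print(os.path.join("_C", f"libtriton.{ext}"))
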
32 changes: 32 additions & 0 deletions test/Conversion/amd/tritongpu_to_llvm.mlir
@@ -62,3 +62,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
tt.return
}
}

// -----

#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_f16
tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
%base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
// CHECK: llvm.cond_br
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
tt.return
}
}

// -----

#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
// CHECK-LABEL: atomic_add_bf16
tt.func @atomic_add_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
%range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
%base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
%ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
// CHECK: llvm.cond_br
// CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
%0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
tt.return
}
}
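For orientation (not part of the diff): these lit tests exercise the lowering directly, but the same path is reached from Python by an ordinary tl.atomic_add on a float16 or bfloat16 tensor. A hedged sketch of such a kernel; the names, sizes, and device string are illustrative, not taken from the PR:

import torch
import triton
import triton.language as tl

@triton.jit
def add_f16_kernel(out_ptr, val_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    vals = tl.load(val_ptr + offs)
    # Lowers to tt.atomic_rmw fadd on f16, which the AMD backend can now emit
    # as a packed vector<2xf16> llvm.atomicrmw rather than per-element atomics.
    tl.atomic_add(out_ptr + offs, vals)

out = torch.zeros(256, dtype=torch.float16, device="cuda")
vals = torch.randn(256, dtype=torch.float16, device="cuda")
add_f16_kernel[(1, )](out, vals, N=256)
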
41 changes: 20 additions & 21 deletions third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -768,7 +768,11 @@ struct AtomicRMWOpConversion
// tensor
if (tensorTy) {
auto valTy = cast<RankedTensorType>(val.getType());
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
Type elTy = valTy.getElementType();
vec = std::min<unsigned>(vec, llvm::isa<FloatType>(elTy) &&
elTy.getIntOrFloatBitWidth() == 16
? 2
: 1);
// mask
numElems = tensorTy.getNumElements();
}
@@ -783,13 +787,22 @@
auto vecTy = vec_ty(valueElemTy, vec);
auto retType = vec == 1 ? valueElemTy : vecTy;
SmallVector<Value> resultVals(elemsPerThread);
const bool f16v2 = vec == 2 && valueElemTy.isF16();
for (size_t i = 0; i < elemsPerThread; i += vec) {
Value rmwPtr = ptrElements[i];
// TODO: in case llMask is zero we can create only one branch for all
// elemsPerThread.
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;

Value operand;
if (vec == 1) {
operand = valElements[i];
} else {
operand = undef(vecTy);
for (size_t ii = 0; ii < vec; ++ii)
operand =
insert_element(vecTy, operand, valElements[i + ii], i32_val(ii));
}

Value undefVal = undef(retType);
// Build blocks to bypass the atomic instruction for ~rmwMask.
auto *curBlock = rewriter.getInsertionBlock();
@@ -806,25 +819,11 @@
auto maybeKind = matchAtomicOp(atomicRmwAttr);
// TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
// atomics for MI-* series of AMD GPU.
Value atom = rewriter
.create<LLVM::AtomicRMWOp>(
loc, *maybeKind, rmwPtr, valElements[i],
atomicMemOrdering, StringRef("agent"))
.getResult();

// NV for the f16v2 case generates one packed instruction. We have to
// create two separate instructions since LLVM::AtomicRMWOp doesn't
// support this. Can be optimized out with rocdl.raw.buffer.atomic.
if (f16v2) {
Value atom2 =
rewriter
.create<LLVM::AtomicRMWOp>(
loc, *maybeKind, ptrElements[i + 1], valElements[i + 1],
atomicMemOrdering, StringRef("agent"))
.getResult();
auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
}
Value atom =
rewriter
.create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr, operand,
atomicMemOrdering, StringRef("agent"))
.getResult();
if (!tensorTy) {
if (atomicNeedsSharedMemory(op.getResult())) {
Value atomPtr =
6 changes: 4 additions & 2 deletions third_party/nvidia/backend/driver.py
@@ -1,5 +1,6 @@
import functools
import os
import sysconfig
import hashlib
import subprocess
import tempfile
@@ -48,15 +49,16 @@ def library_dirs():
def compile_module_from_src(src, name):
key = hashlib.sha256(src.encode("utf-8")).hexdigest()
cache = get_cache_manager(key)
cache_path = cache.get_file(f"{name}.so")
ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
cache_path = cache.get_file(f"{name}.{ext}")
if cache_path is None:
with tempfile.TemporaryDirectory() as tmpdir:
src_path = os.path.join(tmpdir, "main.c")
with open(src_path, "w") as f:
f.write(src)
so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
with open(so, "rb") as f:
cache_path = cache.put(f.read(), f"{name}.so", binary=True)
cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
import importlib.util
spec = importlib.util.spec_from_file_location(name, cache_path)
mod = importlib.util.module_from_spec(spec)