Commit c5beb57

Merge OpenAI Triton commit 78c8054 (#2604)
This PR changes the Triton base from 152ef2d to 78c8054 (Oct 27). Pass rate: `99.84%`. Please do not squash and merge this PR. Repeating #2595.
2 parents c84ee6b + c7607d2 commit c5beb57

File tree

5 files changed (+85 / -28 lines)


python/test/unit/language/test_core.py

Lines changed: 26 additions & 4 deletions
@@ -1490,18 +1490,30 @@ def kernel(X):
                           for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
                           for axis in [0, 1]
                           for num_ctas in num_ctas_list
-                          for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
+                          for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
+    if is_interpreter() and dtype_x_str == 'float16':
+        pytest.skip('float16 atomic_add does not work in the interpreter mode')
     shape0, shape1 = shape
     # triton kernel
 
     @triton.jit
-    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
+    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
+
+        if DTYPE == tl.float16:
+            # sum can have bad numerics when accumulating in float16.
+            # if we're dealing with float16, do the sum in float32.
+            x = x.to(tl.float32)
+
         z = tl.sum(x, axis=AXIS)
+
+        if DTYPE == tl.float16:
+            z = z.to(DTYPE)
+
         if AXIS == 1:
             old = tl.atomic_add(Z + off0, z)
             tl.store(OLD + off0, old)
@@ -1515,13 +1527,23 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
     old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
     # reference results
-    z_ref = z + np.sum(x, axis=axis, keepdims=False)
+    if x.dtype == np.float16:
+        # do the sum in float32 to reduce numerical variation
+        z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype)
+    else:
+        z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
     x_tri = to_triton(x, device=device)
     z_tri = to_triton(z, device=device)
     old_tri = to_triton(old, device=device)
-    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)
+
+    def torch_to_triton_dtype(t):
+        if t == torch.float16:
+            return tl.float16
+        return None
+
+    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     np.testing.assert_equal(old_ref, to_numpy(old_tri))

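For context on the float32 accumulation added to the test above, here is a small stand-alone NumPy sketch (not from the repository; shape and seed are chosen arbitrarily). It prints the worst-case absolute gap between summing a 32x32 float16 tensor directly in float16 and accumulating in float32 before rounding back, which is the drift the rtol=1e-4 check needs to avoid.

# Stand-alone sketch (not part of the PR): why the test sums float16 inputs in float32.
import numpy as np

rs = np.random.RandomState(0)
x = rs.uniform(-1, 1, size=(32, 32)).astype(np.float16)

# Accumulate in float16 (dtype=np.float16 keeps the accumulator in float16).
naive = np.sum(x, axis=0, dtype=np.float16)
# Accumulate in float32, then round back to float16, as the updated test does.
accurate = np.sum(x.astype(np.float32), axis=0).astype(np.float16)

# Worst-case absolute difference between the two accumulation strategies.
print(np.max(np.abs(naive.astype(np.float32) - accurate.astype(np.float32))))
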
python/triton/compiler/compiler.py

Lines changed: 3 additions & 1 deletion
@@ -15,6 +15,7 @@
 import re
 import functools
 import os
+import sysconfig
 
 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 #   and any following whitespace
@@ -151,7 +152,8 @@ def triton_key():
 
     # backend
     libtriton_hash = hashlib.sha256()
-    with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    with open(os.path.join(TRITON_PATH, f"_C/libtriton.{ext}"), "rb") as f:
        while True:
            chunk = f.read(1024**2)
            if not chunk:
test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 32 additions & 0 deletions
@@ -62,3 +62,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_f16
+  tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
+    // CHECK: llvm.cond_br
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_bf16
+  tt.func @atomic_add_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
+    // CHECK: llvm.cond_br
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
+    tt.return
+  }
+}

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 20 additions & 21 deletions
@@ -768,7 +768,11 @@ struct AtomicRMWOpConversion
     // tensor
     if (tensorTy) {
       auto valTy = cast<RankedTensorType>(val.getType());
-      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
+      Type elTy = valTy.getElementType();
+      vec = std::min<unsigned>(vec, llvm::isa<FloatType>(elTy) &&
+                                            elTy.getIntOrFloatBitWidth() == 16
+                                        ? 2
+                                        : 1);
       // mask
       numElems = tensorTy.getNumElements();
     }
@@ -783,13 +787,22 @@
     auto vecTy = vec_ty(valueElemTy, vec);
     auto retType = vec == 1 ? valueElemTy : vecTy;
     SmallVector<Value> resultVals(elemsPerThread);
-    const bool f16v2 = vec == 2 && valueElemTy.isF16();
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwPtr = ptrElements[i];
       // TODO: in case llMask is zero we can create only one branch for all
       // elemsPerThread.
       Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
 
+      Value operand;
+      if (vec == 1) {
+        operand = valElements[i];
+      } else {
+        operand = undef(vecTy);
+        for (size_t ii = 0; ii < vec; ++ii)
+          operand =
+              insert_element(vecTy, operand, valElements[i + ii], i32_val(ii));
+      }
+
       Value undefVal = undef(retType);
       // Build blocks to bypass the atomic instruction for ~rmwMask.
       auto *curBlock = rewriter.getInsertionBlock();
@@ -806,25 +819,11 @@
       auto maybeKind = matchAtomicOp(atomicRmwAttr);
       // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
       // atomics for MI-* series of AMD GPU.
-      Value atom = rewriter
-                       .create<LLVM::AtomicRMWOp>(
-                           loc, *maybeKind, rmwPtr, valElements[i],
-                           atomicMemOrdering, StringRef("agent"))
-                       .getResult();
-
-      // NV for the f16v2 case generates one packed instruction. We have to
-      // create two separate instructions since LLVM::AtomicRMWOp doesn't
-      // support this. Can be optimized out with rocdl.raw.buffer.atomic.
-      if (f16v2) {
-        Value atom2 =
-            rewriter
-                .create<LLVM::AtomicRMWOp>(
-                    loc, *maybeKind, ptrElements[i + 1], valElements[i + 1],
-                    atomicMemOrdering, StringRef("agent"))
-                .getResult();
-        auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
-        atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
-      }
+      Value atom =
+          rewriter
+              .create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr, operand,
+                                         atomicMemOrdering, StringRef("agent"))
+              .getResult();
       if (!tensorTy) {
         if (atomicNeedsSharedMemory(op.getResult())) {
           Value atomPtr =

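For orientation, the lowering change above packs pairs of 16-bit float lanes into a single llvm.atomicrmw on a vector<2xf16> / vector<2xbf16> operand instead of issuing two scalar atomics. A minimal Triton kernel of the kind that could hit this path is sketched below; the kernel, names, and launch parameters are illustrative only and not part of the PR, and whether the packed form is emitted depends on the layout and backend.

# Hypothetical example (not from the PR): an fp16 tensor atomic_add that is a
# candidate for the vectorized AtomicRMWOp lowering above on AMD backends.
import torch
import triton
import triton.language as tl


@triton.jit
def atomic_add_f16_kernel(out_ptr, val_ptr, N: tl.constexpr):
    offs = tl.arange(0, N)
    vals = tl.load(val_ptr + offs)
    # Contiguous fp16 atomic adds can be packed two lanes per atomicrmw.
    tl.atomic_add(out_ptr + offs, vals)


# Usage sketch (assumes a GPU device is available):
# out = torch.zeros(256, dtype=torch.float16, device="cuda")
# vals = torch.ones(256, dtype=torch.float16, device="cuda")
# atomic_add_f16_kernel[(1, )](out, vals, N=256)
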
third_party/nvidia/backend/driver.py

Lines changed: 4 additions & 2 deletions
@@ -1,5 +1,6 @@
 import functools
 import os
+import sysconfig
 import hashlib
 import subprocess
 import tempfile
@@ -48,15 +49,16 @@ def library_dirs():
 def compile_module_from_src(src, name):
     key = hashlib.sha256(src.encode("utf-8")).hexdigest()
     cache = get_cache_manager(key)
-    cache_path = cache.get_file(f"{name}.so")
+    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
+    cache_path = cache.get_file(f"{name}.{ext}")
     if cache_path is None:
         with tempfile.TemporaryDirectory() as tmpdir:
             src_path = os.path.join(tmpdir, "main.c")
             with open(src_path, "w") as f:
                 f.write(src)
             so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
             with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
+                cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
     import importlib.util
     spec = importlib.util.spec_from_file_location(name, cache_path)
     mod = importlib.util.module_from_spec(spec)

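The tail of compile_module_from_src is truncated in this view. As a generic reference (not part of the diff), the importlib pattern for loading a compiled extension module from an explicit path looks like this; the module name and file path below are placeholders:

# Generic sketch of loading a compiled extension from a known path.
import importlib.util

path = "/tmp/cache/cuda_utils.cpython-311-x86_64-linux-gnu.so"  # placeholder path
spec = importlib.util.spec_from_file_location("cuda_utils", path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)  # executes the module so its attributes become usable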