
Commit 0df7d80

Revert "Merge OpenAI Triton commit 78c8054 (#2595)" (#2602)
This reverts commit ef407fb. I accidentally squash-merged instead of creating a merge commit. After this revert I will simply repeat #2595 as a proper merge. Sorry for that.
1 parent ef407fb commit 0df7d80

File tree

5 files changed (+28, -85 lines)


python/test/unit/language/test_core.py

Lines changed: 4 additions & 26 deletions
@@ -1490,30 +1490,18 @@ def kernel(X):
                           for shape in [(2, 2), (2, 8), (8, 2), (8, 8), (32, 32), (64, 64)]
                           for axis in [0, 1]
                           for num_ctas in num_ctas_list
-                          for dtype_x_str in ['float16', 'float32', 'uint64', 'int64', 'float64']])
+                          for dtype_x_str in ['float32', 'uint64', 'int64', 'float64']])
 def test_tensor_atomic_rmw(shape, axis, num_ctas, dtype_x_str, device):
     check_type_supported(dtype_x_str, device)
-    if is_interpreter() and dtype_x_str == 'float16':
-        pytest.skip('float16 atomic_add does not work in the interpreter mode')
     shape0, shape1 = shape
     # triton kernel

     @triton.jit
-    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr, DTYPE: tl.constexpr):
+    def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
         off0 = tl.arange(0, SHAPE0)
         off1 = tl.arange(0, SHAPE1)
         x = tl.load(X + off0[:, None] * SHAPE1 + off1[None, :])
-
-        if DTYPE == tl.float16:
-            # sum can have bad numerics when accumulating in float16.
-            # if we're dealing with float16, do the sum in float32.
-            x = x.to(tl.float32)
-
         z = tl.sum(x, axis=AXIS)
-
-        if DTYPE == tl.float16:
-            z = z.to(DTYPE)
-
         if AXIS == 1:
             old = tl.atomic_add(Z + off0, z)
             tl.store(OLD + off0, old)
@@ -1527,23 +1515,13 @@ def kernel(Z, X, OLD, AXIS: tl.constexpr, SHAPE0: tl.constexpr, SHAPE1: tl.const
     z = numpy_random(z_shape, dtype_str=dtype_x_str, rs=rs)
     old = np.zeros(z_shape, dtype=getattr(np, dtype_x_str))
     # reference results
-    if x.dtype == np.float16:
-        # do the sum in float32 to reduce numerical variation
-        z_ref = z + np.sum(x.astype(np.float32), axis=axis, keepdims=False).astype(x.dtype)
-    else:
-        z_ref = z + np.sum(x, axis=axis, keepdims=False)
+    z_ref = z + np.sum(x, axis=axis, keepdims=False)
     old_ref = np.copy(z)
     # triton result
     x_tri = to_triton(x, device=device)
     z_tri = to_triton(z, device=device)
     old_tri = to_triton(old, device=device)
-
-    def torch_to_triton_dtype(t):
-        if t == torch.float16:
-            return tl.float16
-        return None
-
-    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, torch_to_triton_dtype(x_tri.dtype), num_ctas=num_ctas)
+    kernel[(1, )](z_tri, x_tri, old_tri, axis, shape0, shape1, num_ctas=num_ctas)
     np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
     np.testing.assert_equal(old_ref, to_numpy(old_tri))
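The float16 branch removed above summed the inputs in float32 because accumulating many values directly in float16 loses precision. A minimal NumPy sketch of that effect, with illustrative sizes and values (not part of this change):

import numpy as np

x = np.full(4096, 0.1, dtype=np.float16)

# Naive float16 accumulation: every partial sum is rounded back to float16,
# so once the running total is large enough the 0.1 addend is lost entirely.
acc = np.float16(0.0)
for v in x:
    acc = np.float16(acc + v)

# Accumulate in float32 and round once at the end, as the removed test code did.
ref = x.astype(np.float32).sum().astype(np.float16)

print(acc, ref)  # the float16-accumulated sum stalls far below ref (~409.5)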

python/triton/compiler/compiler.py

Lines changed: 1 addition & 3 deletions
@@ -15,7 +15,6 @@
 import re
 import functools
 import os
-import sysconfig

 # - ^\s*tt\.func\s+ : match the start of the string, any leading whitespace, the keyword func,
 # and any following whitespace
@@ -152,8 +151,7 @@ def triton_key():

     # backend
     libtriton_hash = hashlib.sha256()
-    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
-    with open(os.path.join(TRITON_PATH, f"_C/libtriton.{ext}"), "rb") as f:
+    with open(os.path.join(TRITON_PATH, "_C/libtriton.so"), "rb") as f:
         while True:
             chunk = f.read(1024**2)
             if not chunk:
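The removed lines derived the extension-module suffix from the running interpreter instead of hard-coding ".so". A minimal sketch of what that lookup yields, assuming CPython (the example values are illustrative):

import sysconfig

# e.g. ".cpython-311-x86_64-linux-gnu.so" on Linux or ".cp311-win_amd64.pyd" on Windows
ext_suffix = sysconfig.get_config_var("EXT_SUFFIX")

# The removed code kept only the final component ("so", "pyd", ...);
# the restored code assumes "so" unconditionally.
ext = ext_suffix.split(".")[-1]
print(f"_C/libtriton.{ext}")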

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 0 additions & 32 deletions
@@ -62,35 +62,3 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
-
-// -----
-
-#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: atomic_add_f16
-  tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
-    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
-    %base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
-    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
-    // CHECK: llvm.cond_br
-    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
-    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
-    tt.return
-  }
-}
-
-// -----
-
-#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
-module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: atomic_add_bf16
-  tt.func @atomic_add_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
-    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
-    %base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
-    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
-    // CHECK: llvm.cond_br
-    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
-    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
-    tt.return
-  }
-}

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 21 additions & 20 deletions
@@ -768,11 +768,7 @@ struct AtomicRMWOpConversion
     // tensor
     if (tensorTy) {
       auto valTy = cast<RankedTensorType>(val.getType());
-      Type elTy = valTy.getElementType();
-      vec = std::min<unsigned>(vec, llvm::isa<FloatType>(elTy) &&
-                                            elTy.getIntOrFloatBitWidth() == 16
-                                        ? 2
-                                        : 1);
+      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
       // mask
       numElems = tensorTy.getNumElements();
     }
@@ -787,22 +783,13 @@
     auto vecTy = vec_ty(valueElemTy, vec);
     auto retType = vec == 1 ? valueElemTy : vecTy;
     SmallVector<Value> resultVals(elemsPerThread);
+    const bool f16v2 = vec == 2 && valueElemTy.isF16();
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwPtr = ptrElements[i];
       // TODO: in case llMask is zero we can create only one branch for all
       // elemsPerThread.
       Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;

-      Value operand;
-      if (vec == 1) {
-        operand = valElements[i];
-      } else {
-        operand = undef(vecTy);
-        for (size_t ii = 0; ii < vec; ++ii)
-          operand =
-              insert_element(vecTy, operand, valElements[i + ii], i32_val(ii));
-      }
-
       Value undefVal = undef(retType);
       // Build blocks to bypass the atomic instruction for ~rmwMask.
       auto *curBlock = rewriter.getInsertionBlock();
@@ -819,11 +806,25 @@
       auto maybeKind = matchAtomicOp(atomicRmwAttr);
       // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
      // atomics for MI-* series of AMD GPU.
-      Value atom =
-          rewriter
-              .create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr, operand,
-                                         atomicMemOrdering, StringRef("agent"))
-              .getResult();
+      Value atom = rewriter
+                       .create<LLVM::AtomicRMWOp>(
+                           loc, *maybeKind, rmwPtr, valElements[i],
+                           atomicMemOrdering, StringRef("agent"))
+                       .getResult();
+
+      // NV for the f16v2 case generates one packed instruction. We have to
+      // create two separate instructions since LLVM::AtomicRMWOp doesn't
+      // support this. Can be optimized out with rocdl.raw.buffer.atomic.
+      if (f16v2) {
+        Value atom2 =
+            rewriter
+                .create<LLVM::AtomicRMWOp>(
+                    loc, *maybeKind, ptrElements[i + 1], valElements[i + 1],
+                    atomicMemOrdering, StringRef("agent"))
+                .getResult();
+        auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
+        atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
+      }
       if (!tensorTy) {
         if (atomicNeedsSharedMemory(op.getResult())) {
           Value atomPtr =

third_party/nvidia/backend/driver.py

Lines changed: 2 additions & 4 deletions
@@ -1,6 +1,5 @@
 import functools
 import os
-import sysconfig
 import hashlib
 import subprocess
 import tempfile
@@ -49,16 +48,15 @@ def library_dirs():
 def compile_module_from_src(src, name):
     key = hashlib.sha256(src.encode("utf-8")).hexdigest()
     cache = get_cache_manager(key)
-    ext = sysconfig.get_config_var("EXT_SUFFIX").split(".")[-1]
-    cache_path = cache.get_file(f"{name}.{ext}")
+    cache_path = cache.get_file(f"{name}.so")
     if cache_path is None:
         with tempfile.TemporaryDirectory() as tmpdir:
             src_path = os.path.join(tmpdir, "main.c")
             with open(src_path, "w") as f:
                 f.write(src)
             so = _build(name, src_path, tmpdir, library_dirs(), include_dir, libraries)
             with open(so, "rb") as f:
-                cache_path = cache.put(f.read(), f"{name}.{ext}", binary=True)
+                cache_path = cache.put(f.read(), f"{name}.so", binary=True)
     import importlib.util
     spec = importlib.util.spec_from_file_location(name, cache_path)
     mod = importlib.util.module_from_spec(spec)
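compile_module_from_src keys its cache on a SHA-256 of the generated C source, so the extension is rebuilt only when that source changes. A minimal sketch of the keying idea with a plain directory layout (get_cache_manager and _build are Triton internals; the cache root and example arguments below are purely illustrative):

import hashlib
import os

def cached_module_path(src: str, name: str, cache_root: str = "/tmp/triton-cache") -> str:
    # Same keying scheme as compile_module_from_src: identical source text
    # always maps to the same "<name>.so" location in the cache.
    key = hashlib.sha256(src.encode("utf-8")).hexdigest()
    return os.path.join(cache_root, key, f"{name}.so")

print(cached_module_path("int main() { return 0; }", "example_module"))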
