
Commit 62ece72

Merge OpenAI Triton commit 19277de (#4639)
This PR changes the Triton base from 00d5ca7 to 19277de (Jul 8). Pass rate: 97.08%.
2 parents 148b0a9 + e35f0a0 commit 62ece72

File tree: 14 files changed, +375 -67 lines


include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -516,6 +516,11 @@ SmallVector<SmallVector<Value>>
 emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
             Attribute layout, RankedTensorType type, bool withCTAOffset);
 
+// Emits the required padding in elements for the given shared memory offset
+Value emitPadding(Location loc, RewriterBase &rewriter,
+                  triton::gpu::PaddedSharedEncodingAttr layout,
+                  Value smemOffset);
+
 // Emits IR to load data from shared memory into registers, or to store data
 // from registers into shared memory.
 //

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 53 additions & 8 deletions
@@ -113,6 +113,7 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
   auto nRow = A.getTotalOutDimSizeLog2();
   SmallVector<int32_t> matrix = flatten(A.getBases().begin()->second);
   assert(matrix.size() == nCol);
+
   // We iterate the matrix following the diagonals
   // The idea here is that we want to generate code of the form:
   // \xor_i (x & mask_i) << s_i
@@ -133,15 +134,50 @@ Value matrixVectorProd(TritonLLVMOpBuilder &b, const LinearLayout &A, Value x) {
     return mask;
   };
 
+  uint32_t explicitCols = 0;
+
+  {
+    SmallVector<uint32_t> masks;
+    for (int i = -nRow + 1; i < nCol; i++) {
+      masks.push_back(getMask(i));
+    }
+    bool reachedFixedPoint = false;
+    while (!reachedFixedPoint) {
+      reachedFixedPoint = true;
+      for (uint32_t m : masks) {
+        uint32_t c = m & ~explicitCols;
+        if ((c != 0) && ((c & (c - 1)) == 0)) {
+          // found a single-element diagonal
+          explicitCols |= c;
+          reachedFixedPoint = false;
+        }
+      }
+    }
+  }
+
+  // handle any diagonals that have survived
   Value ret = b.i32_val(0);
   for (int i = -nRow + 1; i < nCol; i++) {
-    auto mask = getMask(i);
+    auto mask = getMask(i) & ~explicitCols;
     if (mask == 0)
       continue;
     auto masked = b.and_(x, b.i32_val(mask));
     ret = b.xor_(ret, i >= 0 ? Value(b.lshr(masked, b.i32_val(i)))
                              : Value(b.shl(masked, b.i32_val(-i))));
   }
+
+  // handle any explicit columns:
+  Value zero = b.i32_val(0);
+  for (int i = 0; i < nCol; i++) {
+    if ((explicitCols >> i) & 1) {
+      Value bit = b.and_(x, b.i32_val(1 << i));
+      Value bit_is_zero = b.icmp_eq(bit, zero);
+      int32_t basis = matrix[i];
+      if (basis == 0)
+        continue;
+      ret = b.xor_(ret, b.select(bit_is_zero, zero, b.i32_val(basis)));
+    }
+  }
   return ret;
 }
 
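The sketch below is a rough Python model, not Triton code, of the decomposition the comments in this hunk describe: the GF(2) matrix-vector product is lowered to an xor of masked-and-shifted diagonals, and the new explicitCols fixed-point loop peels off diagonals whose mask has a single set bit so they can be lowered to a single select on that input bit instead of a mask/shift/xor. The helper names and the toy matrix are illustrative only.

def matvec_reference(matrix, x):
    # y = xor of matrix[c] over every set input bit c of x (GF(2) product).
    # matrix[c] is the basis vector contributed by input bit c.
    y = 0
    for c, basis in enumerate(matrix):
        if (x >> c) & 1:
            y ^= basis
    return y

def matvec_diagonals(matrix, x, n_rows):
    # Same result, computed as xor_i ((x & mask_i) shifted by i), where
    # diagonal i collects the matrix entries with col - row == i.  A
    # diagonal whose mask has a single set bit is what the explicitCols
    # loop above detects and lowers to a select.
    n_cols = len(matrix)
    y = 0
    for i in range(-n_rows + 1, n_cols):
        mask = 0
        for c in range(n_cols):
            r = c - i
            if 0 <= r < n_rows and (matrix[c] >> r) & 1:
                mask |= 1 << c
        if mask == 0:
            continue
        masked = x & mask
        y ^= (masked >> i) if i >= 0 else (masked << -i)
    return y

# Toy check: 3x3 binary matrix given as per-column basis vectors.
mat = [0b011, 0b100, 0b110]
for x in range(8):
    assert matvec_reference(mat, x) == matvec_diagonals(mat, x, n_rows=3)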

@@ -388,6 +424,21 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
   return ret;
 }
 
+Value emitPadding(Location loc, RewriterBase &rewriter,
+                  triton::gpu::PaddedSharedEncodingAttr layout,
+                  Value smemOffset) {
+  TritonLLVMOpBuilder b(loc, rewriter);
+
+  Value padOffset = b.i32_val(0);
+  for (auto [interval, padding] :
+       llvm::zip_equal(layout.getIntervals(), layout.getPaddings())) {
+    Value iVal = b.i32_val(llvm::Log2_32(interval));
+    Value pVal = b.i32_val(llvm::Log2_32(padding));
+    padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
+  }
+  return padOffset;
+}
+
 namespace {
 
 Value getSmemVecAddr(const LinearLayout &regLayout,
@@ -478,13 +529,7 @@ Value getSmemVecAddr(const LinearLayout &regLayout,
     if (auto paddedLayout =
             dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc)) {
       // Apply the offset needed for padding.
-      Value padOffset = b.i32_val(0);
-      for (auto [interval, padding] : llvm::zip_equal(
-               paddedLayout.getIntervals(), paddedLayout.getPaddings())) {
-        Value iVal = b.i32_val(llvm::Log2_32(interval));
-        Value pVal = b.i32_val(llvm::Log2_32(padding));
-        padOffset = b.add(padOffset, b.shl(b.ashr(smemOffset, iVal), pVal));
-      }
+      Value padOffset = emitPadding(loc, rewriter, paddedLayout, smemOffset);
       smemOffset = b.add(smemOffset, padOffset);
     }
   } else { // Case 2 -> rank-reduced swizzling
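The arithmetic emitPadding lowers is small enough to state directly: for each (interval, padding) pair of the padded shared encoding, every interval elements of linear offset receive padding extra elements, and since both values are powers of two (implied by the llvm::Log2_32 calls) the division and multiplication become shifts. A minimal Python sketch of that arithmetic; the pad_offset helper is illustrative, not part of the Triton API.

import math

def pad_offset(smem_offset: int, intervals: list[int], paddings: list[int]) -> int:
    # Mirrors the emitted IR: per pair, add (offset >> log2(interval)) << log2(padding).
    pad = 0
    for interval, padding in zip(intervals, paddings):
        pad += (smem_offset >> int(math.log2(interval))) << int(math.log2(padding))
    return pad

# Example: pad 4 elements after every 32 elements; offset 64 has crossed
# two full intervals, so 8 padding elements precede it.
assert pad_offset(64, [32], [4]) == 8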

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 7 additions & 0 deletions
@@ -513,6 +513,13 @@ struct MemDescSubviewOpConversion
             .second;
     }
 
+    if (auto paddedLayout = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
+            srcTy.getEncoding())) {
+      // Apply padding based on the computed offset
+      Value padOffset = emitPadding(loc, rewriter, paddedLayout, offset);
+      offset = b.add(offset, padOffset);
+    }
+
     auto base = smemObj.getBase();
     auto elemPtrTy = base.getType();
     smemObj = SharedMemoryObject(b.gep(elemPtrTy, llvmElemTy, base, offset),

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 1 addition & 1 deletion
@@ -1467,6 +1467,7 @@ void eraseLoopCarriedValues(scf::ForOp &loop, llvm::BitVector indices) {
 namespace mlir::triton {
 void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
                                  Value val) {
+  OpBuilder::InsertionGuard guard(builder);
   SmallVector<Operation *> opsToDelete;
   SmallVector<OpOperand *> operandsToReplace;
 
@@ -1487,7 +1488,6 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,
 
     Operation *user = use.getOwner();
     // `subview(old_op)` is replaced by a new `subview(val)`.
-    OpBuilder::InsertionGuard g(builder);
     builder.setInsertionPoint(user);
     Value newVal;
     if (auto subview = dyn_cast<ttg::MemDescSubviewOp>(user)) {

python/test/unit/language/test_compile_errors.py

Lines changed: 2 additions & 2 deletions
@@ -398,9 +398,9 @@ def test_min_dot_size(dtype):
     error_msg = "Input shapes should have "
     if is_cuda():
         if dtype.primitive_bitwidth == 8:
-            error_msg += "M >= 16, N >= 16 and K >= 32"
+            error_msg += "M >= 16, N >= 8 and K >= 32"
         else:
-            error_msg = "M >= 16, N >= 16 and K >= 16"
+            error_msg = "M >= 16, N >= 8 and K >= 16"
     elif is_hip():
         # hip supports arbitrary sizes
        error_msg = None

python/test/unit/language/test_conversions.py

Lines changed: 2 additions & 2 deletions
@@ -382,8 +382,8 @@ def test_typeconvert_downcast(src_dtype, dst_dtype, rounding, max_repr, device):
         if dst_dtype == 'float8e4nv':
             if not rounding == 'rtne':
                 pytest.skip("float8e4nv downcast tests only supported with RTNE rounding on AMDGPU")
-            if not (is_hip_cdna3() and src_dtype == 'float16' or is_hip_cdna4()):
-                pytest.skip("float8e4nv downcast tests only supported on AMDGPU CDNA3 or on CDNA4 and from float16 with RTNE rounding")
+            if not is_hip_cdna4() and src_dtype == 'bfloat16':
+                pytest.skip("float8e4nv downcast tests from bfloat16 only supported on AMDGPU CDNA4")
 
         if dst_dtype in ('float8e5b16', 'float8e4b8') and rounding == 'rtne' and not is_hip_cdna3():
             pytest.skip(f"{dst_dtype} downcast with RTNE rounding tests only supported on AMDGPU CDNA3")

python/test/unit/language/test_core.py

Lines changed: 43 additions & 0 deletions
@@ -2436,6 +2436,49 @@ def kernel(X, Z, BLOCK: tl.constexpr):
     assert z[0] == 0
 
 
+@pytest.mark.interpreter
+def test_max_min_with_nan(device):
+    # Triton reductions use "NaN-ignore" semantics: if the reduced dimension
+    # contains NaN, it is skipped and the max/min of the remaining values is
+    # returned. This differs from torch.max/torch.min.
+    @triton.jit
+    def max_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        x = tl.load(x_ptr + offsets)
+
+        max_val = tl.max(x, axis=0)
+
+        if tl.program_id(0) == 0:
+            tl.store(y_ptr, max_val)
+
+    @triton.jit
+    def min_kernel(x_ptr, y_ptr, BLOCK_SIZE: tl.constexpr):
+        offsets = tl.arange(0, BLOCK_SIZE)
+        x = tl.load(x_ptr + offsets)
+
+        min_val = tl.min(x, axis=0)
+
+        if tl.program_id(0) == 0:
+            tl.store(y_ptr, min_val)
+
+    BLOCK_SIZE = 64
+    x = torch.rand((1, BLOCK_SIZE), dtype=torch.float32, device=device)
+    # Not the expected output for tl.max
+    x[0, 0] = float('nan')
+    # Expected output for tl.min
+    x[0, 1] = float('-inf')
+    # Expected output for tl.max
+    x[0, 2] = float('inf')
+
+    y = torch.ones(1, device=device)
+
+    max_kernel[(1, )](x, y, BLOCK_SIZE=BLOCK_SIZE)
+    assert y[0] == float('inf')
+
+    min_kernel[(1, )](x, y, BLOCK_SIZE=BLOCK_SIZE)
+    assert y[0] == float('-inf')
+
+
 def get_reduced_dtype(dtype_str, op):
     if op in ('argmin', 'argmax'):
         return 'int32'

python/triton/runtime/interpreter.py

Lines changed: 2 additions & 2 deletions
@@ -934,9 +934,9 @@ def apply_impl(self, input):
         elif self.combine_fn == tl.standard._argmax_combine_tie_break_left:
             return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=np.argmax)
         elif self.combine_fn == tl.standard._elementwise_max:
-            return self.min_max(input[0], val_reduce_op=np.max, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmax, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._elementwise_min:
-            return self.min_max(input[0], val_reduce_op=np.min, idx_reduce_op=None)
+            return self.min_max(input[0], val_reduce_op=np.nanmin, idx_reduce_op=None)
         elif self.combine_fn == tl.standard._sum_combine:
             return self.sum(input[0])
         else:
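This interpreter change mirrors the NaN-ignoring reduction semantics exercised by test_max_min_with_nan above: NumPy's plain max/min propagate NaN, while nanmax/nanmin skip it. A quick NumPy illustration, not part of the diff:

import numpy as np

x = np.array([1.0, np.nan, 3.0], dtype=np.float32)
print(np.max(x))     # nan  -- previous interpreter behavior
print(np.nanmax(x))  # 3.0  -- matches tl.max, which ignores NaN
print(np.min(x))     # nan
print(np.nanmin(x))  # 1.0  -- matches tl.min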

python/triton_kernels/tests/test_matmul.py

Lines changed: 0 additions & 3 deletions
@@ -257,9 +257,6 @@ def test_op(m, n, k, split_k, do_gather, do_scatter, fused_scatter, has_y_gammas
         if split_k > 1:
             pytest.skip("splitK hasn't been fully tested on AMD GPU.")
 
-        if is_hip_cdna3() and ("float8_e4m3fn" in (weight_dtype_str, act_dtype_str)):
-            pytest.skip("float8_e4m3fn hasn't been fully tested on AMD CDNA3 platform.")
-
         if "float8_e4m3fnuz" in (weight_dtype_str, act_dtype_str) and not is_hip_cdna3():
             pytest.skip("float8_e4m3fnuz only tested on AMD CDNA3 Platform")
 
python/triton_kernels/tests/test_mxfp.py

Lines changed: 1 addition & 3 deletions
@@ -22,7 +22,7 @@
     upcast_from_mxfp_torch,
 )
 from triton_kernels.testing import assert_close, assert_equal
-from triton_kernels.target_info import is_hip, is_hip_cdna3
+from triton_kernels.target_info import is_hip
 
 
 def dtype_str_to_torch(dtype_str: str) -> torch.dtype:
@@ -146,8 +146,6 @@ def test_mxfp_casting(
     if is_hip():
         if swizzle_value is not None or swizzle_scale is not None:
             pytest.skip("Other swizzling patterns are not supported by AMD GPU")
-        if quant_dtype == 'float8_e4m3fn' and is_hip_cdna3():
-            pytest.skip("float8_e4m3fn cast hasn't been fully tested on AMD CDNA3")
 
     swizzle_axis = swizzle_axis if (swizzle_value or swizzle_scale) else None
     quant_torch_type = dtype_str_to_torch(quant_dtype)
