Skip to content

Commit 92dd27c

Browse files
Merge commit '54c840b06443fd28357d81acb605ef16ba4e4e1a'
2 parents 3dd760b + 54c840b commit 92dd27c

File tree

29 files changed

+1172
-341
lines changed

29 files changed

+1172
-341
lines changed

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ jobs:
404404
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
405405
fi
406406
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
407+
pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
407408
cd python/test/unit
408409
pytest --capture=tee-sys -rfs -n 16 language runtime \
409410
--ignore=language/test_line_info.py \

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -402,6 +402,7 @@ jobs:
402402
echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1
403403
fi
404404
pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py
405+
pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py
405406
cd python/test/unit
406407
pytest --capture=tee-sys -rfs -n 16 language runtime \
407408
--ignore=language/test_line_info.py \

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1142,7 +1142,8 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
11421142
bool isHopper() const;
11431143

11441144
SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> shape,
1145-
int bitwidth, int opIdx) const;
1145+
int bitwidth, int kWidth,
1146+
int opIdx) const;
11461147
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
11471148

11481149
bool supportReduction() const {

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
2828
"TRITON_DISABLE_LINE_INFO",
2929
"TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
3030
"TRITON_ENABLE_LLVM_DEBUG",
31+
"TRITON_HIP_STREAM_PREFETCH",
3132
"TRITON_LLVM_DEBUG_ONLY",
3233
"USE_IR_LOC",
3334
"NVPTX_ENABLE_DUMP",

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -953,7 +953,7 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
953953
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
954954
if (mma.isAmpere() || mma.isHopper()) {
955955
auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
956-
auto rep = mma.getRepForOperand(shape, bitwidth, idx);
956+
auto rep = mma.getRepForOperand(shape, bitwidth, kWidth, idx);
957957
auto sizePerThread = getSizePerThread();
958958
auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
959959
if (rank == 3)
@@ -2018,14 +2018,18 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
20182018

20192019
SmallVector<int64_t>
20202020
NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
2021-
int opIdx) const {
2021+
int kWidth, int opIdx) const {
20222022
auto rank = shape.size();
20232023
auto warpsPerCTA = getWarpsPerCTA();
20242024

20252025
// {batch, m, n, k}
20262026
// Hopper path never uses the n value, since this method is only invoked
20272027
// for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
2028-
SmallVector<int> shapePerWarp = {1, 16, 8, 4 * 64 / bitwidth};
2028+
// TODO: rep per operand is not accurate for Hopper. It is currently done that
2029+
// way to allow us to get the correct total number of elements. this will be
2030+
// fixed when moving to linear layout.
2031+
SmallVector<int> shapePerWarp = {
2032+
1, 16, 8, isHopper() ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
20292033
int numRepBatch =
20302034
rank == 3
20312035
? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))

lib/Dialect/TritonGPU/Transforms/OptimizeAccumulatorInit.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,7 @@ void setUseAccFlag(Operation *op, Value useAcc) {
3838
}
3939

4040
bool isConstantZeroTensor(Value v) {
41-
auto constOp = v.getDefiningOp<arith::ConstantOp>();
42-
if (!constOp)
43-
return false;
44-
auto splat = mlir::dyn_cast<SplatElementsAttr>(constOp.getValue());
45-
if (!splat)
46-
return false;
47-
return splat.getSplatValue<FloatAttr>().getValue().convertToFloat() == 0.0f;
41+
return (matchPattern(v, m_Zero()) || matchPattern(v, m_AnyZeroFloat()));
4842
}
4943

5044
std::optional<std::pair<Operation *, int>> findZeroInitOp(Value accUse,

python/test/regression/test_cast_matmul.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import triton
1313
import triton.runtime as tr
1414
import triton.language as tl
15-
from triton._internal_testing import is_hip_mi300, is_cuda
15+
from triton._internal_testing import is_hip_mi300, is_cuda, is_hip
1616

1717
input_dtypes = ["float16", "float32", "float64"]
1818
if is_cuda():
@@ -78,19 +78,22 @@ def matmul_kernel(A, B, C, M, N, K, #
7878
tl.store(C, acc, mask=mask)
7979

8080

81-
@pytest.mark.parametrize("M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype",
82-
[(M, K, N, BLOCK_K, w, x, o) #
81+
@pytest.mark.parametrize("M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype",
82+
[(M, K, N, BLOCK_K, BLOCK_M, w, x, o) #
8383
for BLOCK_K in [16, 32] #
84+
for BLOCK_M in [16, 64] #
8485
for (M, K, N) in [(128, 128, 128), (768, 768, 1024)] #
8586
for w in input_dtypes
8687
for x in input_dtypes #
8788
for o in out_dtypes])
88-
def test_cast_matmul(M, K, N, BLOCK_K, w_dtype, x_dtype, out_dtype, device):
89+
def test_cast_matmul(M, K, N, BLOCK_K, BLOCK_M, w_dtype, x_dtype, out_dtype, device):
8990
if x_dtype == w_dtype:
9091
pytest.xfail("skip the same input dtype")
9192
if device == "xpu" and "float64" in (w_dtype,
9293
x_dtype) and not tr.driver.active.get_current_target().arch['has_fp64']:
9394
pytest.xfail("float64 not supported on current xpu hardware")
95+
if is_hip() and BLOCK_M == 64 and w_dtype in ["float8_e5m2", "float8_e4m3fnuz"]:
96+
pytest.skip("skip due to bug on HIP path")
9497
x_dtype: torch.dtype = getattr(torch, x_dtype)
9598
w_dtype: torch.dtype = getattr(torch, w_dtype)
9699

@@ -112,7 +115,7 @@ def init_tensor(dtype, shape):
112115
out_triton = torch.empty((M, N), device=device, dtype=torch_dtype)
113116

114117
# launch kernel
115-
block_m, block_n, block_k = 16, 16, BLOCK_K
118+
block_m, block_n, block_k = BLOCK_M, 16, BLOCK_K
116119
grid = ((triton.cdiv(M, block_m) * triton.cdiv(N, block_n)), 1)
117120

118121
matmul_kernel[grid](

python/triton/runtime/interpreter.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(self, data, dtype):
2121
'''
2222
data: numpy array
2323
dtype: triton type, either pointer_type or scalar_type.
24-
we don't store block_type here because the shape information is already availale in the data field
24+
we don't store block_type here because the shape information is already available in the data field
2525
attr: a dictionary of attributes
2626
'''
2727
self.data = data
@@ -46,24 +46,23 @@ def set_attr(self, key, value):
4646

4747
class BlockPointerHandle:
4848

49-
def __init__(self, base, shape, strides, offsets, tensor_shape, order):
49+
def __init__(self, base, shape, strides, offsets, block_shape, order):
5050
self.base = base
5151
self.shape = shape
5252
self.strides = strides
5353
self.offsets = offsets
54-
self.tensor_shape = tensor_shape
54+
self.block_shape = block_shape
5555
self.order = order
5656

5757
def materialize_pointers(self, boundary_check):
5858
dtype_tt = self.base.get_element_ty()
5959
n_bytes = dtype_tt.primitive_bitwidth // 8
60-
tensor_shape = self.tensor_shape
61-
ptrs = np.broadcast_to(self.base.data, self.tensor_shape)
62-
masks = np.ones(self.tensor_shape, dtype=bool)
63-
for dim in range(len(tensor_shape)):
64-
bcast_dims = [1] * len(tensor_shape)
65-
bcast_dims[dim] = tensor_shape[dim]
66-
off = (self.offsets[dim].data + np.arange(tensor_shape[dim])).reshape(bcast_dims)
60+
ptrs = np.broadcast_to(self.base.data, self.block_shape)
61+
masks = np.ones(self.block_shape, dtype=bool)
62+
for dim in range(len(self.block_shape)):
63+
bcast_dims = [1] * len(self.block_shape)
64+
bcast_dims[dim] = self.block_shape[dim]
65+
off = (self.offsets[dim].data + np.arange(self.block_shape[dim])).reshape(bcast_dims)
6766
ptrs = ptrs + (n_bytes * off * self.strides[dim].data).astype(np.uint64)
6867
if dim in boundary_check:
6968
masks = np.logical_and(masks, off < self.shape[dim].data)
@@ -655,17 +654,17 @@ def create_barrier(self):
655654
# Triton's barrier applies to each program in a grid, so it's a no-op in the interpreter
656655
pass
657656

658-
def create_make_block_ptr(self, base, shape, strides, offsets, tensor_shape, order):
657+
def create_make_block_ptr(self, base, shape, strides, offsets, block_shape, order):
659658
# Create new offsets to avoid modifying the original
660659
new_offsets = [offset.clone() for offset in offsets]
661-
return BlockPointerHandle(base, shape, strides, new_offsets, tensor_shape, order)
660+
return BlockPointerHandle(base, shape, strides, new_offsets, block_shape, order)
662661

663662
def create_advance(self, ptr, offsets):
664663
if len(ptr.offsets) != len(offsets):
665664
raise ValueError("len(ptr.offsets) != len(offsets)")
666665
# Create new offsets to avoid modifying the original
667666
new_offsets = [offset.clone() for offset in ptr.offsets]
668-
ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.tensor_shape, ptr.order)
667+
ret = BlockPointerHandle(ptr.base, ptr.shape, ptr.strides, new_offsets, ptr.block_shape, ptr.order)
669668
for i in range(len(offsets)):
670669
ret.offsets[i].data += offsets[i].data
671670
return ret
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
// RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics
2+
3+
// Invalid size
4+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
5+
tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
6+
// expected-error @+1 {{sizes [256, 2] must be a multiple of shapePerCTATile [256, 16]}}
7+
%1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x2xi32, #blocked1>
8+
tt.return
9+
}
10+
11+
// -----
12+
13+
// Invalid zero source dimension
14+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
15+
tt.func @invalid_size_input(%arg0: tensor<256x0xi32, #blocked1> {tt.divisibility = 16 : i32}) {
16+
// expected-error @+1 {{source tensor dimension size zero at dimension 1}}
17+
%1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x0xi32, #blocked1> to tensor<256x16xi32, #blocked1>
18+
tt.return
19+
}
20+
21+
// -----
22+
23+
// Invalid zero result dimension
24+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
25+
tt.func @invalid_size_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
26+
// expected-error @+1 {{result tensor dimension size zero at dimension 1}}
27+
%1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x0xi32, #blocked1>
28+
tt.return
29+
}
30+
31+
// -----
32+
33+
// Invalid offset, not multiple of shapePerTile
34+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
35+
tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
36+
// expected-error @+1 {{offset [0, 5] must be a multiple of shapePerCTATile [256, 16]}}
37+
%1 = amdgpu.extract_slice %arg0 [0,5] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
38+
tt.return
39+
}
40+
41+
// -----
42+
43+
// Invalid offset, out of bounds for dimension
44+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
45+
tt.func @invalid_offset_input(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
46+
// expected-error @+1 {{invalid offset 128 at dimension 1}}
47+
%1 = amdgpu.extract_slice %arg0 [0,128] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
48+
tt.return
49+
}
50+
51+
// -----
52+
53+
// Invalid result layout
54+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
55+
#blocked2 = #triton_gpu.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
56+
tt.func @invalid_result_layout(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
57+
// expected-error @+1 {{result layout must match source layout}}
58+
%1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked2>
59+
tt.return
60+
}
61+
62+
// -----
63+
64+
// Invalid result element type
65+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
66+
tt.func @invalid_result_element_type(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
67+
// expected-error @+1 {{result element type must match source element type}}
68+
%1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi64, #blocked1>
69+
tt.return
70+
}
71+
72+
// -----
73+
74+
// Invalid result rank
75+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
76+
tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
77+
// expected-error @+1 {{result rank must be equal to source rank}}
78+
%1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x16x2xi32, #blocked1>
79+
tt.return
80+
}
81+
82+
// -----
83+
84+
// Invalid result shape
85+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
86+
tt.func @invalid_result_rank(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}) {
87+
// expected-error @+1 {{result shape cannot be larger than input shape at dimension 1}}
88+
%1 = amdgpu.extract_slice %arg0 [0,0] : tensor<256x128xi32, #blocked1> to tensor<256x256xi32, #blocked1>
89+
tt.return
90+
}
91+
92+
// -----
93+
94+
// Invalid rank
95+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
96+
tt.func @invalid_rank(%arg0: tensor<256x128x2xi32, #blocked1> {tt.divisibility = 16 : i32}) {
97+
// expected-error @+1 {{currently only 2D tensors are supported}}
98+
%1 = amdgpu.extract_slice %arg0 [0,0,0] : tensor<256x128x2xi32, #blocked1> to tensor<256x16x2xi32, #blocked1>
99+
tt.return
100+
}
101+
102+
// -----
103+
104+
// Invalid non static offset
105+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
106+
tt.func @invalid_non_static_offset(%arg0: tensor<256x128xi32, #blocked1> {tt.divisibility = 16 : i32}, %arg1: i32) {
107+
// expected-error @+2 {{expected ']'}}
108+
// expected-error @+1 {{expected integer value}}
109+
%2 = amdgpu.extract_slice %arg0 [%arg1, 0] : tensor<256x128xi32, #blocked1> to tensor<256x16xi32, #blocked1>
110+
tt.return
111+
}

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,3 +293,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
293293
tt.return
294294
}
295295
}
296+
297+
// -----
298+
299+
#mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 128, 16]}>
300+
module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32} {
301+
// CHECK-LABEL: test_fp8_to_fp16_dot_operand
302+
// CHECK-COUNT-16: cvt.rn.f16x2.e5m2x2
303+
tt.func @test_fp8_to_fp16_dot_operand(%arg: tensor<128x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>) {
304+
%r = tt.fp_to_fp %arg : tensor<128x32xf8E5M2, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> -> tensor<128x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
305+
tt.return
306+
}
307+
}

0 commit comments

Comments
 (0)