
Commit 4825a43

Merge commit '16ce143b54eacf465c5a90a6aabdc9c3a723cb99'
2 parents: b96e1c3 + 16ce143

File tree

14 files changed: 532 additions, 158 deletions


CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ include(ExternalProject)
 
 set(CMAKE_INCLUDE_CURRENT_DIR ON)
 
-project(triton CXX)
+project(triton CXX C)
 include(CTest)
 
 list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")

include/triton/Analysis/Utility.h

Lines changed: 5 additions & 0 deletions
@@ -218,6 +218,11 @@ bool isBlockedToDotShortcut(RankedTensorType srcTy, RankedTensorType dstTy);
 bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
                                    RankedTensorType dstTy);
 
+// Check if MFMA layout can be converted to the dot operand
+// layout using warp shuffle.
+bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
+                                       RankedTensorType dstTy);
+
 // TODO: Move utility functions that belong to ConvertLayoutOp to class
 // ConvertLayoutOpHelper in the future
 bool shouldUseDistSmem(Attribute srcLayout, Attribute dstLayout);

lib/Analysis/Utility.cpp

Lines changed: 24 additions & 1 deletion
@@ -11,6 +11,7 @@
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Conversion/MLIRTypes.h"
 #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Utility.h"
@@ -639,6 +640,25 @@ bool matchMmaV3AndDotOperandLayout(RankedTensorType srcTy,
   return ans;
 }
 
+bool matchMFMAAndDotOperandShuffleCase(RankedTensorType srcTy,
+                                       RankedTensorType dstTy) {
+  auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(srcTy.getEncoding());
+  auto dotOperandLayout = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding());
+  if (!mfmaLayout || !dotOperandLayout)
+    return false;
+
+  // Currently supporting 32x32 and 16x16 FP8 MFMA -> dot operand case
+  return dotOperandLayout.getParent() == mfmaLayout &&
+         dotOperandLayout.getOpIdx() == 0 && mfmaLayout.getIsTransposed() &&
+         dotOperandLayout.getKWidth() == 8 &&
+         getContigPerThread(mfmaLayout)[1] == 4 &&
+         ((mfmaLayout.getMDim() == 16 && mfmaLayout.getNDim() == 16) ||
+          (mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32)) &&
+         triton::type::isFloat8(srcTy.getElementType()) &&
+         triton::type::isFloat8(dstTy.getElementType()) &&
+         mfmaLayout.getWarpsPerCTA()[1] == 1;
+}
+
 // We get the smallest submap of srcTy^{-1} * dstTy that is not the identity
 // under kBlock, kWarp or kLane (in that order). The idea here is that if we
 // have a transformation that's the identity on kBlock, we don't need to use
@@ -738,7 +758,10 @@ bool cvtNeedsSharedMemory(RankedTensorType srcTy, RankedTensorType dstTy) {
   return !cvtReordersRegisters(srcTy, dstTy) &&
          !triton::gpu::intel::isDpasToDotShortcut(srcTy, dstTy) &&
         !isBlockedToDotShortcut(srcTy, dstTy) &&
-        !matchMmaV3AndDotOperandLayout(srcTy, dstTy);
+        !matchMmaV3AndDotOperandLayout(srcTy, dstTy) &&
+        // to be removed when generalized warp shuffle conversions
+        // are ready:
+        !matchMFMAAndDotOperandShuffleCase(srcTy, dstTy);
 }
 
 bool atomicNeedsSharedMemory(Value value) {

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 6 additions & 0 deletions
@@ -409,6 +409,12 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
       return failure();
     }
 
+    // The following check can be removed when generalized warp shuffle
+    // conversions are ready:
+    if (matchMFMAAndDotOperandShuffleCase(srcTy, dstTy)) {
+      return failure();
+    }
+
     assert(cvtNeedsSharedMemory(srcTy, dstTy));
 
     SmallVector<Value> inVals =

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 40 additions & 39 deletions
@@ -1054,23 +1054,18 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
     elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
     return elemsPerThread;
   } else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
-    if (mma.isAmpere() || mma.isHopper()) {
-      auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
-      auto rep = mma.getRepForOperand(shape, bitwidth, kWidth, idx);
-      auto sizePerThread = getSizePerThread();
-      auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
-      if (rank == 3)
-        elemsPerThread[0] = rep[0];
-      elemsPerThread[rank - 2] =
-          (idx == 0)
-              ? rep[1] * sizePerThread[rank - 2]
-              : std::max<int>(rep[1] * elemsPerKRep, sizePerThread[rank - 2]);
-      elemsPerThread[rank - 1] =
-          (idx == 0)
-              ? std::max<int>(rep[2] * elemsPerKRep, sizePerThread[rank - 1])
-              : rep[2] * sizePerThread[rank - 1];
-      return elemsPerThread;
+    assert(getCTALayout(*this) ==
+               CTALayoutAttr::getDefault(getContext(), rank) &&
+           "NYI");
+    auto sizePerThread = getSizePerThread();
+    auto threadsPerWarp = getThreadsPerWarp();
+    auto warpsPerCTA = getWarpsPerCTA();
+    SmallVector<unsigned> regs;
+    for (auto [n, nsize, nThread, nWarp] :
+         llvm::zip(shape, sizePerThread, threadsPerWarp, warpsPerCTA)) {
+      regs.push_back(std::max<int64_t>(nsize, n / (nThread * nWarp)));
     }
+    return regs;
   }
 
   if (auto mmaParent = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
@@ -2388,35 +2383,41 @@ NvidiaMmaEncodingAttr::getRepOrderForOperand(int opIdx) const {
 SmallVector<int64_t>
 NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
                                         int kWidth, int opIdx) const {
+  assert(
+      kWidth >= 32 / bitwidth &&
+      "kWidth must be >= 32 / bitwidth for this function to be well-defined");
   auto rank = shape.size();
+  // Broadcast long K
   auto warpsPerCTA = getWarpsPerCTA();
+  auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
+  warpsPerCTA[kDim] = 1;
 
-  // {batch, m, n, k}
-  // Hopper path never uses the n value, since this method is only invoked
-  // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
-  // TODO: rep per operand is not accurate for Hopper. It is currently done that
-  // way to allow us to get the correct total number of elements. this will be
-  // fixed when moving to linear layout.
-  SmallVector<int> shapePerWarp = {
-      1, 16, 8, isHopper() ? 4 * 2 * kWidth : 4 * 64 / bitwidth};
-  int numRepBatch =
-      rank == 3
-          ? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
-          : 1;
-
+  SmallVector<int> tileSize;
+  if (rank == 3) {
+    tileSize.push_back(1);
+  }
   if (opIdx == 0) {
-    return {numRepBatch,
-            std::max<int64_t>(1, /*repM=*/shape[rank - 2] /
-                                     (shapePerWarp[1] * warpsPerCTA[rank - 2])),
-            std::max<int64_t>(1, /*repK=*/shape[rank - 1] / shapePerWarp[3])};
+    // m x k
+    tileSize.push_back(16);
+    tileSize.push_back(4 * 64 / bitwidth);
   } else {
-    assert(opIdx == 1);
-    return {
-        numRepBatch,
-        std::max<int64_t>(1, /*repK=*/shape[rank - 2] / shapePerWarp[3]),
-        std::max<int64_t>(1, /*repN=*/shape[rank - 1] /
-                                 (shapePerWarp[2] * warpsPerCTA[rank - 1]))};
+    // k x n
+    // Hopper path never uses the n value, since this method is only invoked
+    // for in-RF (dotOpEnc) operands, but WGMMA only supports in A to be in RF
+    // so it's fine if the n is incorrect here
+    tileSize.push_back(4 * 64 / bitwidth);
+    tileSize.push_back(8);
+  }
+
+  SmallVector<int64_t> numRep;
+  // Lezcano: This is odd. Why do we always return a vector of size 3?
+  if (rank != 3) {
+    numRep.push_back(1);
+  }
+  for (auto [s, size, warp] : llvm::zip(shape, tileSize, warpsPerCTA)) {
+    numRep.push_back(std::max<int64_t>(1, s / (size * warp)));
   }
+  return numRep;
 }
 
 SmallVector<unsigned>
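Note (not part of the diff): the reworked getRepForOperand above boils down to, per dimension, max(1, shape[d] / (tileSize[d] * warpsPerCTA[d])) after forcing the warp count along K to 1. A small standalone Python sketch of that arithmetic follows, using hypothetical shapes, bitwidth, and warp counts; the leading batch rep that the C++ code always prepends for rank-2 shapes is omitted here.

def rep_for_operand(shape, bitwidth, op_idx, warps_per_cta):
    # Mirror of the "Broadcast long K" step: warps along K do not split reps.
    warps = list(warps_per_cta)
    k_dim = len(shape) - 1 if op_idx == 0 else len(shape) - 2
    warps[k_dim] = 1
    # Per-warp tile: 16 x (4 * 64 / bitwidth) for operand A,
    # (4 * 64 / bitwidth) x 8 for operand B.
    if op_idx == 0:
        tile = [16, 4 * 64 // bitwidth]   # m x k
    else:
        tile = [4 * 64 // bitwidth, 8]    # k x n
    # One rep per tile*warp span along each dimension, at least 1.
    return [max(1, s // (t * w)) for s, t, w in zip(shape, tile, warps)]


# Hypothetical fp16 A operand (bitwidth 16): the per-warp tile is 16 x 16.
print(rep_for_operand([128, 64], 16, 0, [4, 1]))  # -> [2, 4]

The same max(shape / (threads * warps)) pattern appears in the new getElemsPerThread branch above, just with threadsPerWarp folded in instead of a fixed tile.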

python/src/llvm.cc

Lines changed: 0 additions & 2 deletions
@@ -139,8 +139,6 @@ std::string translateLLVMIRToASM(llvm::Module &module,
   {
     llvm::raw_string_ostream stream(result);
     llvm::buffer_ostream pstream(stream);
-    for (llvm::Function &f : module.functions())
-      f.addFnAttr(llvm::Attribute::AlwaysInline);
     llvm::legacy::PassManager pass;
     // emit
     auto fileType = isObject ? llvm::CodeGenFileType::ObjectFile

python/test/unit/hopper/test_experimental_tma.py

Lines changed: 7 additions & 3 deletions
@@ -4,7 +4,7 @@
 import triton
 import triton.language as tl
 from triton.tools.experimental_descriptor import (create_1d_tma_descriptor, create_2d_tma_descriptor)
-from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma
+from triton._internal_testing import dtypes_with_bfloat16, numpy_random, to_triton, requires_tma, supports_tma, tma_skip_msg
 
 from typing import Optional
 
@@ -29,9 +29,11 @@ def unwrap_tensor(t: torch.Tensor | triton.runtime.jit.TensorWrapper):
 tma_dtypes = sorted(set(dtypes_with_bfloat16) - {"int64", "uint64", "float64"})
 
 
-@requires_tma
 @pytest.mark.parametrize("byval_tma", [True, False])
 def test_experimetal_descriptor_load(byval_tma):
+    if not supports_tma(byval_tma):
+        pytest.skip(tma_skip_msg(byval_tma))
+
     device = "cuda"
     SIZE = 128
 
@@ -82,11 +84,13 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])
 
 
-@requires_tma
 @pytest.mark.parametrize("num_stages", [1, 4])
 @pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(32, 32, 32), (128, 64, 64), (128, 128, 64), (128, 256, 64)])
 @pytest.mark.parametrize("byval_tma", [True, False])
 def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tma):
+    if not supports_tma(byval_tma):
+        pytest.skip(tma_skip_msg(byval_tma))
+
     device = "cuda"
     M, N, K = 8192, 8192, 1024
     torch.manual_seed(42)

python/test/unit/language/test_core.py

Lines changed: 5 additions & 20 deletions
@@ -2614,8 +2614,6 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, NUM_PID_N: tl.
 @pytest.mark.parametrize("axis", [0, 1])
 @pytest.mark.parametrize("add_overflow_check", [False, True])
 def test_scan_layouts(M, N, src_layout, axis, add_overflow_check, device, tmp_path: pathlib.Path):
-    if add_overflow_check is True and is_hip():
-        pytest.skip("overflow check disabled on HIP while fixing issues")
 
     overflow_check = """
     %17 = arith.extsi %arg2 : i32 to i64
@@ -2708,8 +2706,6 @@ def test_reduce_layouts(M, N, src_layout, axis, epilogue_kind, dtype_str, add_ov
         pytest.skip("Skipping because tensor shape is smaller than M(f)maLayout instr_shape")
     if is_hip() and isinstance(src_layout, MfmaLayout) and ((M, N) == (128, 128)):
         pytest.skip("Skipping test because it runs out of shared memory")
-    if add_overflow_check is True and is_hip():
-        pytest.skip("overflow check disabled on HIP while fixing issues")
     if reduce_op == "sum" and dtype_str == "float16" and M * N > 1024:
         pytest.xfail("Skipping sum reduction on float16 due to accuracy issues")
 
@@ -5489,21 +5485,11 @@ def test_convertmma2mma(M, N, mma_pair, dtype, device, tmp_path: pathlib.Path):
         pytest.skip("Skip testing MMAv3 on devices with CC < 9")
 
     num_warps = np.cumprod(src_layout.warps_per_cta)[-1]
-    # TODO(Keren): Remove the intermediate layout once we have resolved the redundantDataMask issue for WGMMA
-    warps_per_cta = src_layout.warps_per_cta
-    interm = BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [warps_per_cta[0], warps_per_cta[1]], [0, 1], [1, 1],
-                           [1, 1], [0, 1])
 
     def do_test(src_layout, dst_layout):
         layouts = f"""
     #src = {src_layout}
    #dst = {dst_layout}
-    #interm = {interm}
-    """
-
-        conversion = f"""
-    %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
-    %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
     """
 
         ir = layouts + f"""
@@ -5513,6 +5499,7 @@ def do_test(src_layout, dst_layout):
    %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>>
    %1 = tt.make_range {{end = {N} : i32, start = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>>
    %2 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #src>
+    %3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #dst>
    %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #triton_gpu.slice<{{dim = 1, parent = #src}}>> -> tensor<{M}x1xi32, #src>
    %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #src>
    %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{N}xi32, #triton_gpu.slice<{{dim = 0, parent = #src}}>> -> tensor<1x{N}xi32, #src>
@@ -5521,12 +5508,10 @@ def do_test(src_layout, dst_layout):
    %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #src>
    %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr<f16>, #src>, tensor<{M}x{N}xi32, #src>
    %11 = tt.load %10 : tensor<{M}x{N}x!tt.ptr<f16>, #src>
-    %3 = tt.splat %arg1 : !tt.ptr<f16> -> tensor<{M}x{N}x!tt.ptr<f16>, #interm>
-    """ + conversion + f"""
-    %15 = triton_gpu.convert_layout %12 : tensor<{M}x{N}xi32, #dst> -> tensor<{M}x{N}xi32, #interm>
-    %16 = triton_gpu.convert_layout %13 : tensor<{M}x{N}xf16, #dst> -> tensor<{M}x{N}xf16, #interm>
-    %17 = tt.addptr %3, %15 : tensor<{M}x{N}x!tt.ptr<f16>, #interm>, tensor<{M}x{N}xi32, #interm>
-    tt.store %17, %16 : tensor<{M}x{N}x!tt.ptr<f16>, #interm>
+    %12 = triton_gpu.convert_layout %9 : tensor<{M}x{N}xi32, #src> -> tensor<{M}x{N}xi32, #dst>
+    %13 = triton_gpu.convert_layout %11 : tensor<{M}x{N}xf16, #src> -> tensor<{M}x{N}xf16, #dst>
+    %14 = tt.addptr %3, %12 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>, tensor<{M}x{N}xi32, #dst>
+    tt.store %14, %13 : tensor<{M}x{N}x!tt.ptr<f16>, #dst>
    tt.return
  }}
 }}

python/triton/_internal_testing.py

Lines changed: 15 additions & 3 deletions
@@ -4,6 +4,7 @@
 import torch
 import triton
 import triton.language as tl
+from triton.backends.nvidia.compiler import _path_to_binary
 import pytest
 
 from numpy.random import RandomState
@@ -140,8 +141,19 @@ def to_numpy(x):
     raise ValueError(f"Not a triton-compatible tensor: {x}")
 
 
-def supports_tma():
-    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9
+def supports_tma(byval_only=False):
+    _, cuda_version = _path_to_binary("ptxas")
+    min_cuda_version = (12, 0) if byval_only else (12, 3)
+    cuda_version_tuple = tuple(map(int, cuda_version.split(".")))
+    assert len(cuda_version_tuple) == 2, cuda_version_tuple
+    return is_cuda() and torch.cuda.get_device_capability()[0] >= 9 and cuda_version_tuple >= min_cuda_version
+
+
+def tma_skip_msg(byval_only=False):
+    if byval_only:
+        return "Requires __grid_constant__ TMA support (NVIDIA Hopper or higher, CUDA 12.0 or higher)"
+    else:
+        return "Requires advanced TMA support (NVIDIA Hopper or higher, CUDA 12.3 or higher)"
 
 
-requires_tma = pytest.mark.skipif(not supports_tma(), reason="Requires TMA support (NVIDIA Hopper or higher)")
+requires_tma = pytest.mark.skipif(not supports_tma(), reason=tma_skip_msg())
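Note (not part of the diff): the new supports_tma gate layers a ptxas version check on top of the existing compute-capability check by comparing version tuples. A minimal standalone sketch of just that comparison, assuming ptxas reports a "major.minor" string; the version strings below are hypothetical.

def tma_version_ok(cuda_version: str, byval_only: bool = False) -> bool:
    # __grid_constant__ (byval) TMA needs CUDA >= 12.0; the advanced path needs >= 12.3.
    min_cuda_version = (12, 0) if byval_only else (12, 3)
    version_tuple = tuple(map(int, cuda_version.split(".")))
    return version_tuple >= min_cuda_version


print(tma_version_ok("12.4"))                   # True:  12.4 >= 12.3
print(tma_version_ok("12.1"))                   # False: advanced TMA needs 12.3
print(tma_version_ok("12.1", byval_only=True))  # True:  byval TMA only needs 12.0

The tests above call this gate at runtime (pytest.skip inside the test body) rather than via the requires_tma marker, so the skip decision can depend on the byval_tma parameter.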
