
Commit aa1436d

Merge commit '1df64d1aaf9ecd74124ccb503d5fe1016a8f92cf'
2 parents: 9bda03d + 1df64d1

File tree

20 files changed: +122 −70 lines


bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions

@@ -102,7 +102,11 @@ export_executable_symbols_for_plugins(triton-llvm-opt)
 add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
 target_link_libraries(triton-tensor-layout PRIVATE
   TritonGPUIR
+  TritonNvidiaGPUIR
   ${triton_libs}
+  ${conversion_libs}
+  ${dialect_libs}
+  TritonTestAnalysis
 )

 add_llvm_executable(triton-translate

bin/triton-tensor-layout.cpp

Lines changed: 7 additions & 5 deletions

@@ -1,8 +1,11 @@
+#include "RegisterTritonDialects.h"
+
 #include "mlir/AsmParser/AsmParser.h"
 #include "mlir/AsmParser/AsmParserState.h"
 #include "mlir/IR/MLIRContext.h"

 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorOr.h"
@@ -114,7 +117,7 @@ LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
     return failure();
   }

-  auto printLambda = [&](StringRef name, Attribute attr) {
+  auto printLambda = [&](StringRef name, mlir::Attribute attr) {
     ss << "Print layout attribute: #" << name << " = " << attr << "\n";

     auto rankedTensorTy = RankedTensorType::get(
@@ -155,7 +158,7 @@ LogicalResult printLayoutFromString(MLIRContext *context,
   if (layoutAttrStr.empty())
     return success();

-  Attribute layout = parseAttribute(layoutAttrStr, context);
+  mlir::Attribute layout = parseAttribute(layoutAttrStr, context);
   if (!layout) {
     llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n";
     return failure();
@@ -178,8 +181,7 @@ int main(int argc, char **argv) {
   cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");

   DialectRegistry registry;
-  // Register all dialects that can print tensor layout.
-  registry.insert<triton::gpu::TritonGPUDialect>();
+  registerTritonDialects(registry);

   MLIRContext ctx(registry);
   ctx.loadAllAvailableDialects();
@@ -189,7 +191,7 @@ int main(int argc, char **argv) {
     return 1;
   }

-  Type parsedTy = parseType(TensorStr, &ctx);
+  mlir::Type parsedTy = parseType(TensorStr, &ctx);
   if (!parsedTy) {
     llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr
                  << "\n";

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 24 additions & 1 deletion

@@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout,
 SmallVector<unsigned>
 getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

+// Returns the dimensions of the tensor from minor (fast-varying) to
+// major (slow-varying). For blocked, mma, and dotOperand layouts,
+// though the elements are in registers, the order refers to memory
+// layout of the original tensor in global memory.
+// For shared Layout, the order refers to which dimension of the original tensor
+// is contiguous in shared memory.
+SmallVector<unsigned> getOrder(Attribute layout);
+
+// Returns the dimensions along which warpId's are distributed.
+// warpsPerCTA only tells the warp layout in the CTA, e.g. warpsPerCTA = [2, 4]
+// tells there are 2 warps along dim0 and 4 warps along dim1.
+// warpOrder tells the specific order when distributing warp IDs.
+// E.g. warpOrder = [0, 1] means the warp IDs are distributed as follows
+//   [warp0 warp2 warp4 warp6]
+//   [warp1 warp3 warp5 warp7]
+// Note that in most cases, getWarpOrder and getOrder return the same results.
+// But this is not guaranteed.
 SmallVector<unsigned> getWarpOrder(Attribute layout);

-SmallVector<unsigned> getOrder(Attribute layout);
+// Returns the dimensions along which threadId's are distributed.
+// Similar to warpOrder, threadOrder is necessary to tell the specific thread
+// distribution in the warp.
+// Note that, in most cases, getThreadOrder and getOrder return the same
+// results. But this is not guaranteed. One exception is mfma.transposed layout,
+// in which getOrder returns [1, 0] but getThreadOrder returns [0, 1].
+SmallVector<unsigned> getThreadOrder(Attribute layout);

 CTALayoutAttr getCTALayout(Attribute layout);

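
The new getWarpOrder/getThreadOrder comments describe how warp (and thread) IDs are delinearized along the listed dimensions. A minimal Python sketch, not part of this commit, that reproduces the warp grid from the comment for warpsPerCTA = [2, 4] and warpOrder = [0, 1]; the delinearize helper below is a simplified stand-in, assuming the order lists dimensions from fastest- to slowest-varying:

# Illustrative only: reproduce the warp grid documented above.
def delinearize(linear_id, shape, order):
    # 'order' lists dimensions from fastest- to slowest-varying.
    coords = [0] * len(shape)
    for dim in order:
        coords[dim] = linear_id % shape[dim]
        linear_id //= shape[dim]
    return coords

warps_per_cta = [2, 4]
warp_order = [0, 1]  # dim0 varies fastest
grid = [[None] * warps_per_cta[1] for _ in range(warps_per_cta[0])]
for w in range(8):
    d0, d1 = delinearize(w, warps_per_cta, warp_order)
    grid[d0][d1] = f"warp{w}"
for row in grid:
    print(row)
# ['warp0', 'warp2', 'warp4', 'warp6']
# ['warp1', 'warp3', 'warp5', 'warp7']

The same recipe applies to thread IDs with threadsPerWarp and getThreadOrder, which is what the getOrder -> getThreadOrder substitutions in the files below rely on.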

lib/Analysis/Utility.cpp

Lines changed: 2 additions & 2 deletions

@@ -38,7 +38,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
     return getParentOrder(sliceEncoding.getParent());
   }
-  return getOrder(layout);
+  return getThreadOrder(layout);
 }

 } // namespace
@@ -77,7 +77,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
     threadOffset = threadsPerWarp[sliceLayout.getDim()];
   } else {
     auto threadsPerWarp = getThreadsPerWarp(srcLayout);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     for (unsigned i = 0; i < order.size(); i++) {
       if (order[i] == axis)
         break;

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion

@@ -9,6 +9,7 @@ using namespace mlir::triton;
 using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::getOrder;
+using ::mlir::triton::gpu::getThreadOrder;
 using ::mlir::triton::gpu::getTotalElemsPerThread;

 namespace {
@@ -271,7 +272,7 @@ struct ReduceOpConversion

     auto threadsPerWarp =
         triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
         delinearize(rewriter, loc, laneId, threadsPerWarp, order);
     Value laneIdAxis = multiDimLaneId[axis];

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 12 additions & 9 deletions

@@ -259,14 +259,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     auto rank = distributedLayout.getWarpsPerCTA().size();
     SmallVector<unsigned> order(rank);
     std::iota(order.rbegin(), order.rend(), 0);
-    auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout);
-    if (!mfmaLayout)
-      return order;
-    // For transposed MFMA layouts, we swap M and N dimensions, which is
-    // always the first two in order; as we can have an optional batch
-    // dimension following them.
-    if (mfmaLayout.getIsTransposed())
-      std::swap(order[0], order[1]);
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {
@@ -293,6 +285,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   return {};
 };

+SmallVector<unsigned> getThreadOrder(Attribute layout) {
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getThreadOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
+  return {};
+};
+
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
@@ -1557,7 +1557,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto order = ::getOrder(*this);
+  if (getIsTransposed())
+    std::swap(order[0], order[1]);
+  return order;
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
   unsigned rows, cols;
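
Net effect of this file's change: getOrder for MFMA layouts now always returns the plain reversed-iota order, and the transposed swap lives only in AMDMfmaEncodingAttr::getThreadOrder, matching the new header comment (for rank 2, getOrder gives [1, 0] while getThreadOrder gives [0, 1] when the layout is transposed). A small Python sketch of that split, plain Python rather than the Triton API:

# Illustrative only: reversed-iota order vs. the MFMA thread order after this change.
def get_order(rank):
    return list(range(rank - 1, -1, -1))  # e.g. [1, 0] for rank 2

def mfma_thread_order(rank, is_transposed):
    order = get_order(rank)
    if is_transposed:
        # Swap M and N, which are always the first two entries; an optional
        # batch dimension (rank 3) keeps its position after them.
        order[0], order[1] = order[1], order[0]
    return order

print(get_order(2))                # [1, 0]  -> getOrder
print(mfma_thread_order(2, True))  # [0, 1]  -> getThreadOrder for mfma.transposed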

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 14 additions & 0 deletions

@@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For mfma.transposed layout, the element ownership among threads are
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   } else {
     assert(getMDim() == 16);
     // For mfma with 16x16 output, each of the 64 threads holds 4 elements.
@@ -521,6 +528,13 @@
         {{kRegister, {{0, 1}, {0, 2}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For mfma.transposed layout, the element ownership among threads are
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   }
   if (hasBatchDim) {
     assert(order[2] == 0);
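
For readers new to LinearLayout: the bases listed above combine by XOR, i.e. the output coordinates for an input index are the XOR of the bases selected by that index's set bits (this sketch assumes Triton's F2-linear layout semantics and is not the actual LinearLayout API). In the transposed variant every register and lane basis has its two components swapped, which is the "transposed element ownership" the new comment describes, so the first 16 lanes walk the other output dimension:

# Illustrative only: apply a list of 2-component XOR bases to an input index.
def apply(bases, x):
    out = [0, 0]
    for i, (a, b) in enumerate(bases):
        if (x >> i) & 1:
            out[0] ^= a
            out[1] ^= b
    return out

# 16x16 MFMA tile lane bases, standard vs. mfma.transposed (taken from the diff above).
lane_std = [(1, 0), (2, 0), (4, 0), (8, 0), (0, 4), (0, 8)]
lane_t = [(0, 1), (0, 2), (0, 4), (0, 8), (4, 0), (8, 0)]

print(apply(lane_std, 5))  # [5, 0]: lanes 0-15 advance along the first output dim
print(apply(lane_t, 5))    # [0, 5]: transposed, they advance along the second dim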

python/setup.py

Lines changed: 7 additions & 4 deletions

@@ -429,11 +429,14 @@ def build_extension(self, ext):
             cmake_args += [
                 "-DCMAKE_C_COMPILER=clang",
                 "-DCMAKE_CXX_COMPILER=clang++",
-                "-DCMAKE_LINKER=lld",
-                "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
-                "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
-                "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
             ]
+            if platform.system() != "Darwin":
+                cmake_args += [
+                    "-DCMAKE_LINKER=lld",
+                    "-DCMAKE_EXE_LINKER_FLAGS=-fuse-ld=lld",
+                    "-DCMAKE_MODULE_LINKER_FLAGS=-fuse-ld=lld",
+                    "-DCMAKE_SHARED_LINKER_FLAGS=-fuse-ld=lld",
+                ]

             # Note that asan doesn't work with binaries that use the GPU, so this is
             # only useful for tools like triton-opt that don't run code on the GPU.

python/test/unit/hopper/test_experimental_tma.py

Lines changed: 5 additions & 5 deletions

@@ -57,7 +57,7 @@ def kernel(Z, desc, SIZE: tl.constexpr, BYVAL_TMA: tl.constexpr):
 @triton.jit
 def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
                       M, N, K, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
-                      BYVAL_TMA: tl.constexpr):
+                      BYVAL_TMA: tl.constexpr, dtype: tl.constexpr):
     if not BYVAL_TMA:
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(a_desc_ptr)
         tl.extra.cuda.experimental_tensormap_fenceproxy_acquire(b_desc_ptr)
@@ -72,11 +72,11 @@ def matmul_kernel_tma(a_desc_ptr, b_desc_ptr, c_desc_ptr, #
     offs_k = 0
     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
-        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], tl.float16)
-        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], tl.float16)
+        a = tl._experimental_descriptor_load(a_desc_ptr, [offs_am, offs_k], [BLOCK_SIZE_M, BLOCK_SIZE_K], dtype)
+        b = tl._experimental_descriptor_load(b_desc_ptr, [offs_k, offs_bn], [BLOCK_SIZE_K, BLOCK_SIZE_N], dtype)
         accumulator = tl.dot(a, b, acc=accumulator)
         offs_k += BLOCK_SIZE_K
-    accumulator = accumulator.to(tl.float16)
+    accumulator = accumulator.to(dtype)
     tl._experimental_descriptor_store(c_desc_ptr, accumulator, [offs_am, offs_bn])


@@ -101,7 +101,7 @@ def test_experimental_tma_matmul(num_stages, BLOCK_M, BLOCK_N, BLOCK_K, byval_tm
     desc_c = create_tma_desc_gmem_ptr(C.data_ptr(), [M, N], [BLOCK_M, BLOCK_N], C.element_size())
     kernel = matmul_kernel_tma[(triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), 1,
                                 1)](desc_a, desc_b, desc_c, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, BYVAL_TMA=byval_tma,
-                                    num_warps=8, num_stages=num_stages)
+                                    num_warps=8, num_stages=num_stages, dtype=tl.float16)
     ref_out = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(torch.float16)
     torch.testing.assert_close(ref_out, C, rtol=1e-3, atol=1e-3)
     if BLOCK_M >= 64 and BLOCK_N >= 64:

python/triton/language/core.py

Lines changed: 1 addition & 1 deletion

@@ -1613,7 +1613,7 @@ def _experimental_descriptor_load(desc_pointer, offsets, shape, dtype, _builder=

     This loads a tensor of data based on the descriptor and offsets.
     """
-    type = block_type(dtype, shape)
+    type = block_type(_constexpr_to_value(dtype), shape)
     return semantic.descriptor_load(desc_pointer, offsets, "", "", type, _builder)

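
This one-line fix unwraps a tl.constexpr-wrapped dtype before building the block type, which is what allows the test above to pass dtype as a constexpr kernel parameter instead of hard-coding tl.float16. A minimal hypothetical kernel sketching the now-supported pattern; the 1-D descriptors and names here are assumptions for illustration, not part of the commit:

import triton
import triton.language as tl


@triton.jit
def copy_block(desc_in, desc_out, BLOCK: tl.constexpr, dtype: tl.constexpr):
    # dtype arrives as a constexpr kernel argument; the unwrapping inside
    # _experimental_descriptor_load is what makes this call valid.
    x = tl._experimental_descriptor_load(desc_in, [0], [BLOCK], dtype)
    tl._experimental_descriptor_store(desc_out, x, [0])

# Hypothetical launch: copy_block[(1,)](desc_in, desc_out, BLOCK=128, dtype=tl.float16)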
