Commit 9c52bc3

Merge commit '256ef34ca707a0c9675bafbbad2d89ecca3c8e94'
2 parents: 70a4ddf + 256ef34

File tree

15 files changed: +100 additions, -60 deletions

.github/workflows/integration-tests.yml

Lines changed: 1 addition & 1 deletion
@@ -460,7 +460,7 @@ jobs:
       - name: Install brew dependencies
         run: |
           brew update
-          brew install ccache llvm
+          brew install ccache llvm@19 lld
       - name: Compute cache keys
         id: cache-key
         run: |

.github/workflows/integration-tests.yml.in

Lines changed: 1 addition & 1 deletion
@@ -439,7 +439,7 @@ jobs:
       - name: Install brew dependencies
         run: |
           brew update
-          brew install ccache llvm
+          brew install ccache llvm@19 lld

       - *compute-cache-keys-step
       - *cache-build-dependencies-step

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -102,7 +102,11 @@ export_executable_symbols_for_plugins(triton-llvm-opt)
 add_llvm_executable(triton-tensor-layout triton-tensor-layout.cpp PARTIAL_SOURCES_INTENDED)
 target_link_libraries(triton-tensor-layout PRIVATE
   TritonGPUIR
+  TritonNvidiaGPUIR
   ${triton_libs}
+  ${conversion_libs}
+  ${dialect_libs}
+  TritonTestAnalysis
 )

 add_llvm_executable(triton-translate

bin/triton-tensor-layout.cpp

Lines changed: 7 additions & 5 deletions
@@ -1,8 +1,11 @@
+#include "RegisterTritonDialects.h"
+
 #include "mlir/AsmParser/AsmParser.h"
 #include "mlir/AsmParser/AsmParserState.h"
 #include "mlir/IR/MLIRContext.h"

 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
+#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"

 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorOr.h"

@@ -114,7 +117,7 @@ LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename,
     return failure();
   }

-  auto printLambda = [&](StringRef name, Attribute attr) {
+  auto printLambda = [&](StringRef name, mlir::Attribute attr) {
     ss << "Print layout attribute: #" << name << " = " << attr << "\n";

     auto rankedTensorTy = RankedTensorType::get(

@@ -155,7 +158,7 @@ LogicalResult printLayoutFromString(MLIRContext *context,
   if (layoutAttrStr.empty())
     return success();

-  Attribute layout = parseAttribute(layoutAttrStr, context);
+  mlir::Attribute layout = parseAttribute(layoutAttrStr, context);
   if (!layout) {
     llvm::errs() << "Invalid layout attribute: " << layoutAttrStr << "\n";
     return failure();

@@ -178,8 +181,7 @@ int main(int argc, char **argv) {
   cl::ParseCommandLineOptions(argc, argv, "tensor layout printer\n");

   DialectRegistry registry;
-  // Register all dialects that can print tensor layout.
-  registry.insert<triton::gpu::TritonGPUDialect>();
+  registerTritonDialects(registry);

   MLIRContext ctx(registry);
   ctx.loadAllAvailableDialects();

@@ -189,7 +191,7 @@ int main(int argc, char **argv) {
     return 1;
   }

-  Type parsedTy = parseType(TensorStr, &ctx);
+  mlir::Type parsedTy = parseType(TensorStr, &ctx);
   if (!parsedTy) {
     llvm::errs() << "Fail to parse the tensor type argument: " << TensorStr
                  << "\n";

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 24 additions & 1 deletion
@@ -75,9 +75,32 @@ getThreadsPerWarpWithUniqueData(Attribute layout,
 SmallVector<unsigned>
 getWarpsPerCTAWithUniqueData(Attribute layout, ArrayRef<int64_t> tensorShape);

+// Returns the dimensions of the tensor from minor (fast-varying) to
+// major (slow-varying). For blocked, mma, and dotOperand layouts,
+// the elements live in registers, but the order still refers to the memory
+// layout of the original tensor in global memory.
+// For shared layouts, the order refers to which dimension of the original
+// tensor is contiguous in shared memory.
+SmallVector<unsigned> getOrder(Attribute layout);
+
+// Returns the dimensions along which warp IDs are distributed.
+// warpsPerCTA only gives the warp layout within the CTA; e.g. warpsPerCTA =
+// [2, 4] says there are 2 warps along dim0 and 4 warps along dim1.
+// warpOrder gives the specific order in which warp IDs are distributed.
+// E.g. warpOrder = [0, 1] means the warp IDs are laid out as follows:
+//   [warp0 warp2 warp4 warp6]
+//   [warp1 warp3 warp5 warp7]
+// Note that in most cases getWarpOrder and getOrder return the same result,
+// but this is not guaranteed.
 SmallVector<unsigned> getWarpOrder(Attribute layout);

-SmallVector<unsigned> getOrder(Attribute layout);
+// Returns the dimensions along which thread IDs are distributed.
+// As with warpOrder, threadOrder is needed to pin down the specific thread
+// distribution within the warp.
+// Note that in most cases getThreadOrder and getOrder return the same
+// result, but this is not guaranteed. One exception is the mfma.transposed
+// layout, for which getOrder returns [1, 0] but getThreadOrder returns [0, 1].
+SmallVector<unsigned> getThreadOrder(Attribute layout);

 CTALayoutAttr getCTALayout(Attribute layout);
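
The warp-ID picture in the new getWarpOrder comment can be reproduced with a small delinearization. The sketch below is illustrative (it is not Triton's delinearize helper): it assigns coordinates by walking the dimensions in the given order, most-minor first, which yields the [warp0 warp2 warp4 warp6] / [warp1 warp3 warp5 warp7] arrangement for warpsPerCTA = [2, 4] and warpOrder = [0, 1].

#include <cstdio>
#include <vector>

// Illustrative sketch: map a linear ID to per-dimension coordinates, letting
// the dimension listed first in `order` vary fastest.
static std::vector<unsigned> delinearizeId(unsigned id,
                                           const std::vector<unsigned> &shape,
                                           const std::vector<unsigned> &order) {
  std::vector<unsigned> coords(shape.size());
  for (unsigned dim : order) {
    coords[dim] = id % shape[dim];
    id /= shape[dim];
  }
  return coords;
}

int main() {
  const std::vector<unsigned> warpsPerCTA = {2, 4};
  const std::vector<unsigned> warpOrder = {0, 1}; // dim0 varies fastest
  for (unsigned w = 0; w < 8; ++w) {
    auto c = delinearizeId(w, warpsPerCTA, warpOrder);
    std::printf("warp%u -> (dim0 = %u, dim1 = %u)\n", w, c[0], c[1]);
  }
  // warp0 -> (0,0), warp1 -> (1,0), warp2 -> (0,1), ...: warps 0,2,4,6 fill
  // row 0 and warps 1,3,5,7 fill row 1, matching the comment above.
  return 0;
}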

lib/Analysis/Utility.cpp

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ SmallVector<unsigned> getParentOrder(Attribute layout) {
   if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
     return getParentOrder(sliceEncoding.getParent());
   }
-  return getOrder(layout);
+  return getThreadOrder(layout);
 }

 } // namespace

@@ -77,7 +77,7 @@ unsigned ReduceOpHelper::getThreadOffsetOnReductionAxis() {
     threadOffset = threadsPerWarp[sliceLayout.getDim()];
   } else {
     auto threadsPerWarp = getThreadsPerWarp(srcLayout);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     for (unsigned i = 0; i < order.size(); i++) {
       if (order[i] == axis)
         break;
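
Why the switch to getThreadOrder matters here (a sketch, not part of the diff, under the assumption that getThreadOffsetOnReductionAxis multiplies threadsPerWarp over the dimensions preceding the reduction axis): the lane-ID stride between neighbouring threads along the axis is that product taken in the thread order, so a layout whose memory order differs from its thread order, such as transposed mfma, would get the wrong stride from getOrder.

#include <vector>

// Illustrative sketch, not the Triton implementation: stride (in lane IDs)
// between neighbouring threads along `axis`, given the per-warp thread shape
// and the order in which thread IDs are laid out (most-minor dimension first).
static unsigned threadOffsetOnAxis(const std::vector<unsigned> &threadsPerWarp,
                                   const std::vector<unsigned> &threadOrder,
                                   unsigned axis) {
  unsigned offset = 1;
  for (unsigned dim : threadOrder) {
    if (dim == axis)
      break;
    offset *= threadsPerWarp[dim];
  }
  return offset;
}

int main() {
  // Illustrative numbers: threadsPerWarp = {2, 32}, reduction axis = 0.
  // Thread order {1, 0} gives stride 32; {0, 1} would give stride 1.
  return threadOffsetOnAxis({2, 32}, {1, 0}, /*axis=*/0) == 32 ? 0 : 1;
}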

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
@@ -9,6 +9,7 @@ using namespace mlir::triton;
 using ::mlir::LLVM::delinearize;
 using ::mlir::LLVM::linearize;
 using ::mlir::triton::gpu::getOrder;
+using ::mlir::triton::gpu::getThreadOrder;
 using ::mlir::triton::gpu::getTotalElemsPerThread;

 namespace {

@@ -271,7 +272,7 @@ struct ReduceOpConversion

     auto threadsPerWarp =
         triton::gpu::getThreadsPerWarpWithUniqueData(srcLayout, srcShape);
-    auto order = getOrder(srcLayout);
+    auto order = getThreadOrder(srcLayout);
     SmallVector<Value> multiDimLaneId =
         delinearize(rewriter, loc, laneId, threadsPerWarp, order);
     Value laneIdAxis = multiDimLaneId[axis];

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 12 additions & 9 deletions
@@ -259,14 +259,6 @@ SmallVector<unsigned> getOrder(Attribute layout) {
     auto rank = distributedLayout.getWarpsPerCTA().size();
     SmallVector<unsigned> order(rank);
     std::iota(order.rbegin(), order.rend(), 0);
-    auto mfmaLayout = dyn_cast<AMDMfmaEncodingAttr>(layout);
-    if (!mfmaLayout)
-      return order;
-    // For transposed MFMA layouts, we swap M and N dimensions, which is
-    // always the first two in order; as we can have an optional batch
-    // dimension following them.
-    if (mfmaLayout.getIsTransposed())
-      std::swap(order[0], order[1]);
     return order;
   }
   if (auto dotLayout = dyn_cast<DotOperandEncodingAttr>(layout)) {

@@ -293,6 +285,14 @@ SmallVector<unsigned> getOrder(Attribute layout) {
   return {};
 };

+SmallVector<unsigned> getThreadOrder(Attribute layout) {
+  if (auto distributedLayout = mlir::dyn_cast<DistributedEncodingTrait>(layout))
+    return distributedLayout.getThreadOrder();
+  else
+    llvm::report_fatal_error("Unimplemented usage of getThreadOrder");
+  return {};
+};
+
 CTALayoutAttr getCTALayout(Attribute layout) {
   if (auto distributedLayout =
           mlir::dyn_cast<DistributedEncodingTrait>(layout)) {

@@ -1557,7 +1557,10 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getWarpOrder() const {
   return ::getWarpOrder(*this);
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadOrder() const {
-  return ::getOrder(*this);
+  auto order = ::getOrder(*this);
+  if (getIsTransposed())
+    std::swap(order[0], order[1]);
+  return order;
 }
 SmallVector<unsigned> AMDMfmaEncodingAttr::getThreadsPerWarp() const {
   unsigned rows, cols;
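
Net effect of the Dialect.cpp hunks for AMD MFMA encodings: the M/N swap for transposed layouts no longer happens in the generic getOrder, which now always reports the tensor's memory order, and instead lives in AMDMfmaEncodingAttr::getThreadOrder. A tiny standalone sketch of the resulting rank-2 behaviour (illustrative stand-ins, not the Triton classes):

#include <algorithm>
#include <cassert>
#include <vector>

// Default order [1, 0]: the minor dimension (dim1) varies fastest in memory.
static std::vector<unsigned> mfmaOrder() { return {1, 0}; }

// Only the thread order is swapped for transposed MFMA layouts now.
static std::vector<unsigned> mfmaThreadOrder(bool isTransposed) {
  auto order = mfmaOrder();
  if (isTransposed)
    std::swap(order[0], order[1]);
  return order;
}

int main() {
  assert(mfmaOrder() == (std::vector<unsigned>{1, 0}));
  assert(mfmaThreadOrder(false) == (std::vector<unsigned>{1, 0}));
  // Transposed mfma: getOrder stays [1, 0] while getThreadOrder becomes
  // [0, 1], the exception called out in the new Dialect.h comment.
  assert(mfmaThreadOrder(true) == (std::vector<unsigned>{0, 1}));
  return 0;
}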

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 14 additions & 0 deletions
@@ -507,6 +507,13 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
         {{kRegister, {{0, 1}, {0, 2}, {0, 8}, /*gap*/ {0, 16}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, {16, 0}, /*gap*/ {0, 4}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}, {8, 0}, /*gap*/ {16, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, /*gap*/ {4, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   } else {
     assert(getMDim() == 16);
     // For mfma with 16x16 output, each of the 64 threads holds 4 elements.

@@ -521,6 +528,13 @@
         {{kRegister, {{0, 1}, {0, 2}}},
          {kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}, /*gap*/ {0, 4}, {0, 8}}}},
         {outDimNames[order[0]], outDimNames[order[1]]});
+    // For the mfma.transposed layout, element ownership among threads is
+    // "transposed" within each warp.
+    if (getIsTransposed())
+      tileLayout = LinearLayout(
+          {{kRegister, {{1, 0}, {2, 0}}},
+           {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, /*gap*/ {4, 0}, {8, 0}}}},
+          {outDimNames[order[0]], outDimNames[order[1]]});
   }
   if (hasBatchDim) {
     assert(order[2] == 0);
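
How to read the new basis lists (a sketch of the general linear-layout evaluation rule, not part of the commit): bit i of a register or lane index selects the i-th basis vector, and the selected vectors combine by XOR, so the transposed variants above simply swap the two coordinates of every basis vector. Below, the coordinate pairs are treated abstractly, written the same way as in the diff:

#include <array>
#include <cstdio>
#include <vector>

// Every set bit of `index` selects the corresponding basis vector; the
// selected vectors are combined with XOR. (Triton's LinearLayout generalises
// this to several named input and output dimensions.)
static std::array<unsigned, 2>
applyBases(const std::vector<std::array<unsigned, 2>> &bases, unsigned index) {
  std::array<unsigned, 2> out = {0, 0};
  for (unsigned bit = 0; bit < bases.size(); ++bit)
    if (index & (1u << bit)) {
      out[0] ^= bases[bit][0];
      out[1] ^= bases[bit][1];
    }
  return out;
}

int main() {
  // kLane bases of the 16x16 tile from this diff, non-transposed ...
  const std::vector<std::array<unsigned, 2>> lane = {
      {1, 0}, {2, 0}, {4, 0}, {8, 0}, {0, 4}, {0, 8}};
  // ... and transposed: every basis vector has its coordinates swapped.
  const std::vector<std::array<unsigned, 2>> laneT = {
      {0, 1}, {0, 2}, {0, 4}, {0, 8}, {4, 0}, {8, 0}};

  auto a = applyBases(lane, 5);  // bits 0 and 2 set: {1,0} ^ {4,0} = {5, 0}
  auto b = applyBases(laneT, 5); // same lane, transposed coords:    {0, 5}
  std::printf("lane 5 owns (%u, %u) normally, (%u, %u) when transposed\n",
              a[0], a[1], b[0], b[1]);
  return 0;
}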

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 13 additions & 6 deletions
@@ -119,6 +119,10 @@ struct LoadStoreConversionBase {
     return axisAnalysisPass.getMaskAlignment(mask);
   }

+  unsigned getPtrAlignment(Value ptr) const {
+    return axisAnalysisPass.getPtrAlignment(ptr);
+  }
+
 protected:
   const AMD::TargetInfo &targetInfo;
   ModuleAxisInfoAnalysis &axisAnalysisPass;

@@ -193,7 +197,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
     // vectorized iteration through all the pointer/mask/other elements
     const int valueElemNBits =
         std::max(8u, valueElemTy.getIntOrFloatBitWidth());
+    const size_t valueElemNBytes = valueElemNBits / 8;
     const int numVecs = numElems / vec;
+    int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes;

     auto cacheMod = op.getCache();
     SmallVector<Value> loadedVals;

@@ -230,8 +236,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
         falseVal = v;
       }

-      auto loadVal =
-          llLoad(rewriter, loc, ptr, vecTy, pred, falseVal, cacheMod);
+      Value loadVal = llLoad(rewriter, loc, ptr, vecTy, pred, falseVal,
+                             ptrAlignmentBytes, cacheMod);
       for (size_t ii = 0; ii < vec; ++ii) {
         Value vecIdx = createIndexAttrConstant(
             rewriter, loc, this->getTypeConverter()->getIndexType(), ii % vec);

@@ -294,9 +300,10 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
       vec = std::min(vec, maskAlign);
     }

-    const size_t dtsize =
-        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
-    const size_t valueElemNBits = dtsize * 8;
+    const size_t valueElemNBits =
+        std::max<int>(8, valueElemTy.getIntOrFloatBitWidth());
+    const size_t valueElemNBytes = valueElemNBits / 8;
+    int64_t ptrAlignmentBytes = getPtrAlignment(ptr) * valueElemNBytes;

     auto cacheMod = op.getCache();
     const int numVecs = elemsPerThread / vec;

@@ -328,7 +335,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
            rewriter, loc, this->getTypeConverter()->getIndexType(), s);
        storeVal = insert_element(vecTy, storeVal, otherElem, indexVal);
      }
-     llStore(rewriter, loc, ptr, storeVal, pred, cacheMod);
+     llStore(rewriter, loc, ptr, storeVal, pred, ptrAlignmentBytes, cacheMod);
    } // end vec
    rewriter.eraseOp(op);
    return success();
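
The alignment now passed to llLoad/llStore is a plain unit conversion (worked through with illustrative numbers, not values from the commit): axis analysis reports pointer alignment in elements, and multiplying by the element size in bytes gives the byte alignment the emitted LLVM load/store can assume.

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
  // Illustrative values: an f16 tensor whose pointers are aligned to
  // 8 elements (the real figures come from ModuleAxisInfoAnalysis).
  const unsigned elemBitWidth = 16;
  const unsigned valueElemNBits = std::max(8u, elemBitWidth); // sub-byte types round up
  const std::size_t valueElemNBytes = valueElemNBits / 8;
  const unsigned ptrAlignmentElems = 8;
  const std::int64_t ptrAlignmentBytes = ptrAlignmentElems * valueElemNBytes;
  std::printf("%lld-byte aligned vector accesses\n",
              static_cast<long long>(ptrAlignmentBytes)); // prints 16
  return 0;
}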
