Skip to content

Commit ffc2601

Browse files
Merge commit '91302ea36ddc0b61629da975732b2753ef270764'
2 parents 3339986 + 91302ea commit ffc2601

File tree

19 files changed

+684
-691
lines changed

19 files changed

+684
-691
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,5 +253,5 @@ Supported Platforms:
253253
Supported Hardware:
254254

255255
- NVIDIA GPUs (Compute Capability 8.0+)
256-
- AMD GPUs (ROCm 5.2+)
256+
- AMD GPUs (ROCm 6.2+)
257257
- Under development: CPUs

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ using namespace mlir::triton;
5353
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
5454
#define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
5555
#define lshr(...) rewriter.create<LLVM::LShrOp>(loc, __VA_ARGS__)
56+
#define ashr(...) rewriter.create<LLVM::AShrOp>(loc, __VA_ARGS__)
5657
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
5758
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
5859
#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -762,8 +762,6 @@ for
762762
];
763763

764764
let extraClassDeclaration = extraDistributedDeclaration # [{
765-
SliceEncodingAttr squeeze(int axis);
766-
767765
SmallVector<unsigned> getContigPerThread() {
768766
// Block encoding is dense stride layout. The elements per thread are contiguous.
769767
return getSizePerThread();

include/triton/Dialect/TritonNvidiaGPU/Transforms/TMAUtilities.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,6 @@ mlir::LogicalResult createTMADesc(mlir::Value tmaPtr,
5252
loc, builder.getI64Type(), builder.getI64IntegerAttr(elemSize));
5353
Value globalStride = builder.template create<arith::MulIOp>(
5454
loc, op.getStrides()[0], elemSizeVal);
55-
// TODO: Workaround for ptxas bug, remove when we update ptxas
56-
Value four = builder.template create<arith::ConstantOp>(
57-
loc, builder.getI64Type(), builder.getI64IntegerAttr(4));
58-
globalStride =
59-
builder.template create<arith::ShRSIOp>(loc, globalStride, four);
6055

6156
int elemTypeEnum;
6257
switch (elemSize) {

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ LogicalResult TransOp::inferReturnTypes(
221221
Attribute retEncoding;
222222
if (argEncoding) {
223223
Dialect &dialect = argEncoding.getDialect();
224-
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
224+
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
225225
if (inferLayoutInterface
226226
->inferTransOpEncoding(argEncoding, order, retEncoding)
227227
.failed()) {
@@ -250,7 +250,7 @@ DotOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
250250
if (aEnc) {
251251
assert(bEnc && retEnc);
252252
Dialect &dialect = retEnc.getDialect();
253-
auto interface = dyn_cast<DialectInferLayoutInterface>(&dialect);
253+
auto interface = cast<DialectInferLayoutInterface>(&dialect);
254254
if (interface->inferDotOpEncoding(aEnc, 0, retEnc, location).failed())
255255
return failure();
256256
if (interface->inferDotOpEncoding(bEnc, 1, retEnc, location).failed())
@@ -331,8 +331,7 @@ inferReduceReturnShape(RankedTensorType argTy, Type retEltTy, int axis,
331331
Attribute retEncoding;
332332
if (argEncoding) {
333333
Dialect &dialect = argEncoding.getDialect();
334-
auto inferLayoutInterface =
335-
dyn_cast<DialectInferLayoutInterface>(&dialect);
334+
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
336335
if (inferLayoutInterface
337336
->inferReduceOpEncoding(argEncoding, axis, retEncoding)
338337
.failed()) {
@@ -565,7 +564,7 @@ LogicalResult ExpandDimsOp::inferReturnTypes(
565564
Attribute retEncoding;
566565
if (argEncoding) {
567566
Dialect &dialect = argEncoding.getDialect();
568-
auto inferLayoutInterface = dyn_cast<DialectInferLayoutInterface>(&dialect);
567+
auto inferLayoutInterface = cast<DialectInferLayoutInterface>(&dialect);
569568
if (inferLayoutInterface
570569
->inferExpandDimsOpEncoding(argEncoding, axis, retEncoding, loc)
571570
.failed())
@@ -604,7 +603,7 @@ LogicalResult ExpandDimsOp::canonicalize(ExpandDimsOp op,
604603
// Infer the encoding of the new expand op, if encodings are present.
605604
Attribute newExpandEnc;
606605
if (auto srcEnc = srcTy.getEncoding()) {
607-
if (dyn_cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
606+
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
608607
->inferExpandDimsOpEncoding(srcEnc, op.getAxis(), newExpandEnc,
609608
op.getLoc())
610609
.failed()) {
@@ -975,7 +974,6 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
975974
assert(isa<RankedTensorType>(operands[1].getType()));
976975

977976
Value lhs = operands[0];
978-
Value rhs = operands[1];
979977
auto srcTy = cast<RankedTensorType>(lhs.getType());
980978

981979
SmallVector<int64_t> retShape(srcTy.getShape());
@@ -984,7 +982,7 @@ JoinOp::inferReturnTypes(MLIRContext *context, std::optional<Location> location,
984982
Attribute srcEnc = srcTy.getEncoding();
985983
Attribute retEnc;
986984
if (srcEnc) {
987-
if (dyn_cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
985+
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
988986
->inferJoinOpEncoding(srcEnc, retEnc, location)
989987
.failed()) {
990988
return failure();
@@ -1017,7 +1015,7 @@ LogicalResult SplitOp::inferReturnTypes(
10171015
Attribute srcEnc = srcTy.getEncoding();
10181016
Attribute retEnc;
10191017
if (srcEnc) {
1020-
if (dyn_cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
1018+
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
10211019
->inferSplitOpEncoding(srcEnc, retEnc, location)
10221020
.failed()) {
10231021
return failure();

0 commit comments

Comments
 (0)