Skip to content

Commit 539f2ed

Browse files
Merge commit '28396a76e099d659c9e2f91a49f077c0065a3c1d'
2 parents 766cab6 + 28396a7 commit 539f2ed

File tree

35 files changed

+168
-221
lines changed

35 files changed

+168
-221
lines changed

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
100100
PatternBenefit benefit);
101101

102102
void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter,
103-
RewritePatternSet &patterns, int numWarps,
103+
RewritePatternSet &patterns,
104104
const TargetInfoBase &targetInfo,
105105
PatternBenefit benefit);
106106

include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,6 @@ template <typename T> class OperationPass;
1212

1313
namespace triton {
1414

15-
constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
16-
constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
17-
constexpr static char AttrTargetName[] = "ttg.target";
18-
19-
constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
20-
2115
// Create the pass with numWarps passed from cl::opt.
2216
std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();
2317

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ class DialectVerifyTensorLayoutInterface
9191
DialectVerifyTensorLayoutInterface(Dialect *dialect) : Base(dialect) {}
9292

9393
virtual LogicalResult
94-
verifyTensorLayout(Attribute layout, RankedTensorType type, ModuleOp module,
94+
verifyTensorLayout(Attribute layout, RankedTensorType type, Operation *op,
9595
function_ref<InFlightDiagnostic()> emitError) const = 0;
9696
};
9797

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1118,7 +1118,11 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
11181118
}];
11191119
}
11201120

1121-
def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> {
1121+
def FuncOp : TT_Op<"func", [
1122+
AffineScope, AutomaticAllocationScope, CallableOpInterface,
1123+
FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface,
1124+
HasParent<"ModuleOp">
1125+
]> {
11221126
let summary = "An operation with a name containing a single `SSACFG` region";
11231127
let description = [{
11241128
Operations within the function cannot implicitly capture values defined

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,13 @@ template <> struct hash<CacheKey> {
3939

4040
namespace mlir::triton::gpu {
4141

42+
constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
43+
constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
44+
constexpr static char AttrTargetName[] = "ttg.target";
45+
constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
46+
47+
int lookupNumWarps(Operation *op);
48+
4249
class LinearLayoutCache {
4350
public:
4451
std::optional<LinearLayout> get(const CacheKey &key) {

include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td

Lines changed: 3 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,32 +20,13 @@ def TritonGPU_Dialect : Dialect {
2020
];
2121

2222
let extraClassDeclaration = [{
23-
static std::string getNumWarpsAttrName() { return "ttg.num-warps"; }
24-
static int getNumWarps(ModuleOp mod) {
25-
if (!mod->hasAttr("ttg.num-warps"))
26-
llvm::report_fatal_error(
27-
"TritonGPU module should contain a ttg.num-warps attribute");
28-
return cast<IntegerAttr>(mod->getAttr("ttg.num-warps")).getInt();
29-
}
30-
static int getNumCTAs(ModuleOp mod) {
31-
if (!mod->hasAttr("ttg.num-ctas"))
32-
return 1;
33-
return cast<IntegerAttr>(mod->getAttr("ttg.num-ctas")).getInt();
34-
}
3523
void registerTypes();
3624

37-
static std::string getThreadsPerWarpAttrName() { return "ttg.threads-per-warp"; }
38-
39-
static int getThreadsPerWarp(ModuleOp mod) {
40-
Attribute threadsPerWarp = mod->getDiscardableAttr("ttg.threads-per-warp");
41-
if(!threadsPerWarp) {
42-
return 32;
43-
}
44-
return cast<IntegerAttr>(threadsPerWarp).getInt();
45-
}
46-
4725
LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
4826

27+
static int getNumCTAs(ModuleOp mod);
28+
static int getThreadsPerWarp(ModuleOp mod);
29+
4930
private:
5031
LinearLayoutCache llCache;
5132
}];

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@
4242
#define GET_OP_CLASSES
4343
#include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc"
4444

45-
namespace mlir {
46-
namespace triton {
47-
namespace nvidia_gpu {
45+
namespace mlir::triton::nvidia_gpu {
4846

4947
struct TensorMemory : public SideEffects::Resource::Base<TensorMemory> {
5048
StringRef getName() final { return "<TensorMemory>"; }
@@ -63,12 +61,10 @@ Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
6361
ArrayRef<int64_t> shape, unsigned numWarps,
6462
triton::gpu::CTALayoutAttr ctaLayout);
6563

66-
bool isDistributedLayoutTMemCompatible(ModuleOp mod,
64+
bool isDistributedLayoutTMemCompatible(Operation *op,
6765
RankedTensorType tensorType,
6866
gpu::MemDescType memType);
6967

70-
} // namespace nvidia_gpu
71-
} // namespace triton
72-
} // namespace mlir
68+
} // namespace mlir::triton::nvidia_gpu
7369

7470
#endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,6 @@ def TritonNvidiaGPU_Dialect : Dialect {
4141
"mlir::gpu::GPUDialect",
4242
];
4343

44-
let extraClassDeclaration = [{
45-
static std::string getNumWarpsAttrName() { return "ttg.num-warps"; }
46-
static int getNumWarps(ModuleOp mod) {
47-
if(!mod->hasAttr("ttg.num-warps"))
48-
llvm::report_fatal_error(
49-
"TritonGPU module should contain a ttg.num-warps attribute");
50-
return cast<IntegerAttr>(mod->getAttr("ttg.num-warps")).getInt();
51-
}
52-
static int getNumCTAs(ModuleOp mod) {
53-
if(!mod->hasAttr("ttg.num-ctas"))
54-
llvm::report_fatal_error(
55-
"TritonGPU module should contain a ttg.num-ctas attribute");
56-
return cast<IntegerAttr>(mod->getAttr("ttg.num-ctas")).getInt();
57-
}
58-
}];
5944
let useDefaultAttributePrinterParser = 1;
6045
let usePropertiesForAttributes = 1;
6146
}

lib/Analysis/Utility.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -302,8 +302,7 @@ bool ScanLoweringHelper::isSupported() {
302302
}
303303

304304
unsigned ScanLoweringHelper::getScratchSizeInElems() {
305-
auto mod = scanOp->getParentOfType<ModuleOp>();
306-
unsigned numWarps = TritonGPUDialect::getNumWarps(mod);
305+
unsigned numWarps = lookupNumWarps(scanOp);
307306
unsigned numNonAxisElementsPerWarp =
308307
getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
309308
unsigned numElements = numWarps * numNonAxisElementsPerWarp *
@@ -726,8 +725,7 @@ bool supportMMA(triton::DotOp op, int version) {
726725
auto retType = op.getType();
727726
auto retShapePerCTA = getShapePerCTA(retType);
728727
auto rank = retShapePerCTA.size();
729-
auto mod = op->getParentOfType<ModuleOp>();
730-
int numWarps = TritonGPUDialect::getNumWarps(mod);
728+
int numWarps = lookupNumWarps(op);
731729
if (aElemTy.isInteger() || bElemTy.isInteger() ||
732730
retType.getElementType().isInteger())
733731
return false;
@@ -749,8 +747,7 @@ bool supportMMA(triton::DotOp op, int version) {
749747
return false;
750748
auto retShapePerCTA = getShapePerCTA(retType);
751749
auto rank = retShapePerCTA.size();
752-
auto mod = op->getParentOfType<ModuleOp>();
753-
int numWarps = TritonGPUDialect::getNumWarps(mod);
750+
int numWarps = lookupNumWarps(op);
754751
// TODO(Keren): for now, fallback to MMAv2 if handling batch matmul.
755752
if (rank == 3)
756753
return false;

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ namespace mlir::triton::gpu {
2020

2121
void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
2222
ShortcutFn shortcutFn) {
23-
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
24-
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module);
25-
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
23+
MLIRContext *ctx = module.getContext();
24+
int numCTAs = TritonGPUDialect::getNumCTAs(module);
25+
int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
2626

2727
module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
2828
OpBuilder builder(cvtOp);
@@ -31,28 +31,32 @@ void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
3131
auto srcMma = dyn_cast<MmaEncodingTrait>(srcType.getEncoding());
3232
auto dstDotOp =
3333
dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
34-
if (srcMma && dstDotOp && !shortcutFn(srcType, dstType)) {
35-
auto tmpType = RankedTensorType::get(
36-
dstType.getShape(), dstType.getElementType(),
37-
triton::gpu::BlockedEncodingAttr::get(
38-
module.getContext(), srcType.getShape(), getSizePerThread(srcMma),
39-
getOrder(srcMma), numWarps, threadsPerWarp, numCTAs));
40-
auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
41-
cvtOp.getLoc(), tmpType, cvtOp.getSrc());
42-
addAttrs(tmp, cvtOp->getAttrs());
43-
auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
44-
cvtOp.getLoc(), dstType, tmp);
45-
addAttrs(newConvert, cvtOp->getAttrs());
46-
cvtOp.replaceAllUsesWith(newConvert.getResult());
47-
cvtOp.erase();
48-
}
34+
if (!srcMma || !dstDotOp || shortcutFn(srcType, dstType))
35+
return;
36+
37+
int numWarps = lookupNumWarps(cvtOp);
38+
auto enc = BlockedEncodingAttr::get(
39+
ctx, srcType.getShape(), getSizePerThread(srcMma), getOrder(srcMma),
40+
numWarps, threadsPerWarp, numCTAs);
41+
auto tmpType = RankedTensorType::get(dstType.getShape(),
42+
dstType.getElementType(), enc);
43+
44+
auto tmp = builder.create<ConvertLayoutOp>(cvtOp.getLoc(), tmpType,
45+
cvtOp.getSrc());
46+
addAttrs(tmp, cvtOp->getAttrs());
47+
auto newConvert =
48+
builder.create<ConvertLayoutOp>(cvtOp.getLoc(), dstType, tmp);
49+
addAttrs(newConvert, cvtOp->getAttrs());
50+
51+
cvtOp.replaceAllUsesWith(newConvert.getResult());
52+
cvtOp.erase();
4953
});
5054
}
5155

5256
void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
53-
int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
5457
int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module);
5558
int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
59+
5660
module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
5761
OpBuilder builder(cvtOp);
5862
auto srcType = cast<RankedTensorType>(cvtOp.getSrc().getType());

0 commit comments

Comments
 (0)