Commit 28396a7

[TritonGPU] Refactor numWarps lookups to be from the op (NFC) (#5891)
This is in preparation for warp specialization, which will turn the number of warps into a scoped property of regions. This PR just rearranges the API for looking up the number of warps. In the next PR, the `"ttg.num-warps"` attribute will be moved to `tt.func`.
1 parent 27c8363 · commit 28396a7

36 files changed: +187 -240 lines
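
To make the API change concrete, here is the before/after shape of a typical call site, reconstructed from the lib/Analysis/Utility.cpp hunks below (a minimal sketch using identifiers from the diff, not code taken verbatim from the commit):

#include "mlir/IR/BuiltinOps.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton::gpu;

// Before: every caller climbed to the enclosing ModuleOp and read the
// module-wide attribute through the dialect helper.
static int numWarpsBefore(Operation *op) {
  auto mod = op->getParentOfType<ModuleOp>();
  return TritonGPUDialect::getNumWarps(mod); // helper removed by this commit
}

// After: callers ask relative to the op itself, so the answer can later
// become region-scoped for warp specialization.
static int numWarpsAfter(Operation *op) { return lookupNumWarps(op); }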

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
                                  PatternBenefit benefit);
 
 void populateFuncOpConversionPattern(LLVMTypeConverter &typeConverter,
-                                     RewritePatternSet &patterns, int numWarps,
+                                     RewritePatternSet &patterns,
                                      const TargetInfoBase &targetInfo,
                                      PatternBenefit benefit);

include/triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h

Lines changed: 0 additions & 6 deletions
@@ -12,12 +12,6 @@ template <typename T> class OperationPass;
 
 namespace triton {
 
-constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
-constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
-constexpr static char AttrTargetName[] = "ttg.target";
-
-constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
-
 // Create the pass with numWarps passed from cl::opt.
 std::unique_ptr<OperationPass<ModuleOp>> createConvertTritonToTritonGPUPass();

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ class DialectVerifyTensorLayoutInterface
   DialectVerifyTensorLayoutInterface(Dialect *dialect) : Base(dialect) {}
 
   virtual LogicalResult
-  verifyTensorLayout(Attribute layout, RankedTensorType type, ModuleOp module,
+  verifyTensorLayout(Attribute layout, RankedTensorType type, Operation *op,
                      function_ref<InFlightDiagnostic()> emitError) const = 0;
 };

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 5 additions & 1 deletion
@@ -1122,7 +1122,11 @@ def CallOp : TT_Op<"call", [CallOpInterface, /*MemRefsNormalizable, */DeclareOpI
   }];
 }
 
-def FuncOp : TT_Op<"func", [AffineScope, AutomaticAllocationScope, CallableOpInterface, FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface]> {
+def FuncOp : TT_Op<"func", [
+  AffineScope, AutomaticAllocationScope, CallableOpInterface,
+  FunctionOpInterface, IsolatedFromAbove, OpAsmOpInterface,
+  HasParent<"ModuleOp">
+]> {
   let summary = "An operation with a name containing a single `SSACFG` region";
   let description = [{
     Operations within the function cannot implicitly capture values defined

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 7 additions & 0 deletions
@@ -39,6 +39,13 @@ template <> struct hash<CacheKey> {
 
 namespace mlir::triton::gpu {
 
+constexpr static char AttrNumWarpsName[] = "ttg.num-warps";
+constexpr static char AttrNumCTAsName[] = "ttg.num-ctas";
+constexpr static char AttrTargetName[] = "ttg.target";
+constexpr static char AttrNumThreadsPerWarp[] = "ttg.threads-per-warp";
+
+int lookupNumWarps(Operation *op);
+
 class LinearLayoutCache {
 public:
   std::optional<LinearLayout> get(const CacheKey &key) {
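
The commit only declares lookupNumWarps here. Since the "ttg.num-warps" attribute is still module-scoped at this point (the commit message says it moves to tt.func in the follow-up PR), a plausible definition would simply climb to the enclosing module. This is an assumption-labeled sketch, not the commit's actual body:

int mlir::triton::gpu::lookupNumWarps(Operation *op) {
  // Assumed behavior for this PR: the attribute still lives on the module,
  // so walk up to it (op may itself be the module).
  ModuleOp mod = isa<ModuleOp>(op) ? cast<ModuleOp>(op)
                                   : op->getParentOfType<ModuleOp>();
  auto attr = mod->getAttrOfType<IntegerAttr>(AttrNumWarpsName);
  if (!attr)
    llvm::report_fatal_error(
        "TritonGPU module should contain a ttg.num-warps attribute");
  return attr.getInt();
}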

include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td

Lines changed: 3 additions & 22 deletions
@@ -20,32 +20,13 @@ def TritonGPU_Dialect : Dialect {
   ];
 
   let extraClassDeclaration = [{
-    static std::string getNumWarpsAttrName() { return "ttg.num-warps"; }
-    static int getNumWarps(ModuleOp mod) {
-      if (!mod->hasAttr("ttg.num-warps"))
-        llvm::report_fatal_error(
-            "TritonGPU module should contain a ttg.num-warps attribute");
-      return cast<IntegerAttr>(mod->getAttr("ttg.num-warps")).getInt();
-    }
-    static int getNumCTAs(ModuleOp mod) {
-      if (!mod->hasAttr("ttg.num-ctas"))
-        return 1;
-      return cast<IntegerAttr>(mod->getAttr("ttg.num-ctas")).getInt();
-    }
     void registerTypes();
 
-    static std::string getThreadsPerWarpAttrName() { return "ttg.threads-per-warp"; }
-
-    static int getThreadsPerWarp(ModuleOp mod) {
-      Attribute threadsPerWarp = mod->getDiscardableAttr("ttg.threads-per-warp");
-      if(!threadsPerWarp) {
-        return 32;
-      }
-      return cast<IntegerAttr>(threadsPerWarp).getInt();
-    }
-
     LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);
 
+    static int getNumCTAs(ModuleOp mod);
+    static int getThreadsPerWarp(ModuleOp mod);
+
   private:
     LinearLayoutCache llCache;
   }];
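
getNumCTAs and getThreadsPerWarp become out-of-line declarations; their definitions presumably move into the dialect's .cpp and keep the defaults the removed inline bodies had (1 CTA and 32 threads per warp when the attribute is absent). A sketch under that assumption:

int TritonGPUDialect::getNumCTAs(ModuleOp mod) {
  // Assumed to match the removed inline body: default to 1 CTA if unset.
  if (auto attr = mod->getAttrOfType<IntegerAttr>(AttrNumCTAsName))
    return attr.getInt();
  return 1;
}

int TritonGPUDialect::getThreadsPerWarp(ModuleOp mod) {
  // Assumed to match the removed inline body: default to 32 if unset.
  if (auto attr = mod->getAttrOfType<IntegerAttr>(AttrNumThreadsPerWarp))
    return attr.getInt();
  return 32;
}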

include/triton/Dialect/TritonNvidiaGPU/IR/Dialect.h

Lines changed: 3 additions & 7 deletions
@@ -42,9 +42,7 @@
 #define GET_OP_CLASSES
 #include "triton/Dialect/TritonNvidiaGPU/IR/Ops.h.inc"
 
-namespace mlir {
-namespace triton {
-namespace nvidia_gpu {
+namespace mlir::triton::nvidia_gpu {
 
 struct TensorMemory : public SideEffects::Resource::Base<TensorMemory> {
   StringRef getName() final { return "<TensorMemory>"; }
@@ -63,12 +61,10 @@ Attribute getTmemCompatibleLayout(unsigned M, unsigned N,
                                   ArrayRef<int64_t> shape, unsigned numWarps,
                                   triton::gpu::CTALayoutAttr ctaLayout);
 
-bool isDistributedLayoutTMemCompatible(ModuleOp mod,
+bool isDistributedLayoutTMemCompatible(Operation *op,
                                        RankedTensorType tensorType,
                                        gpu::MemDescType memType);
 
-} // namespace nvidia_gpu
-} // namespace triton
-} // namespace mlir
+} // namespace mlir::triton::nvidia_gpu
 
 #endif // TRITON_DIALECT_TRITONNVIDIAGPU_IR_DIALECT_H_

include/triton/Dialect/TritonNvidiaGPU/IR/TritonNvidiaGPUDialect.td

Lines changed: 0 additions & 15 deletions
@@ -41,21 +41,6 @@ def TritonNvidiaGPU_Dialect : Dialect {
     "mlir::gpu::GPUDialect",
   ];
 
-  let extraClassDeclaration = [{
-    static std::string getNumWarpsAttrName() { return "ttg.num-warps"; }
-    static int getNumWarps(ModuleOp mod) {
-      if(!mod->hasAttr("ttg.num-warps"))
-        llvm::report_fatal_error(
-            "TritonGPU module should contain a ttg.num-warps attribute");
-      return cast<IntegerAttr>(mod->getAttr("ttg.num-warps")).getInt();
-    }
-    static int getNumCTAs(ModuleOp mod) {
-      if(!mod->hasAttr("ttg.num-ctas"))
-        llvm::report_fatal_error(
-            "TritonGPU module should contain a ttg.num-ctas attribute");
-      return cast<IntegerAttr>(mod->getAttr("ttg.num-ctas")).getInt();
-    }
-  }];
   let useDefaultAttributePrinterParser = 1;
   let usePropertiesForAttributes = 1;
 }

lib/Analysis/Utility.cpp

Lines changed: 3 additions & 6 deletions
@@ -296,8 +296,7 @@ bool ScanLoweringHelper::isSupported() {
 }
 
 unsigned ScanLoweringHelper::getScratchSizeInElems() {
-  auto mod = scanOp->getParentOfType<ModuleOp>();
-  unsigned numWarps = TritonGPUDialect::getNumWarps(mod);
+  unsigned numWarps = lookupNumWarps(scanOp);
   unsigned numNonAxisElementsPerWarp =
       getNonAxisNumThreadsPerWarp() * getNonAxisNumElementsPerThread();
   unsigned numElements = numWarps * numNonAxisElementsPerWarp *
@@ -720,8 +719,7 @@ bool supportMMA(triton::DotOp op, int version) {
   auto retType = op.getType();
   auto retShapePerCTA = getShapePerCTA(retType);
   auto rank = retShapePerCTA.size();
-  auto mod = op->getParentOfType<ModuleOp>();
-  int numWarps = TritonGPUDialect::getNumWarps(mod);
+  int numWarps = lookupNumWarps(op);
   if (aElemTy.isInteger() || bElemTy.isInteger() ||
       retType.getElementType().isInteger())
     return false;
@@ -743,8 +741,7 @@ bool supportMMA(triton::DotOp op, int version) {
     return false;
   auto retShapePerCTA = getShapePerCTA(retType);
   auto rank = retShapePerCTA.size();
-  auto mod = op->getParentOfType<ModuleOp>();
-  int numWarps = TritonGPUDialect::getNumWarps(mod);
+  int numWarps = lookupNumWarps(op);
   // TODO(Keren): for now, fallback to MMAv2 if handling batch matmul.
   if (rank == 3)
     return false;

lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp

Lines changed: 23 additions & 19 deletions
@@ -20,9 +20,9 @@ namespace mlir::triton::gpu {
 
 void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
                                               ShortcutFn shortcutFn) {
-  int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
-  int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module);
-  int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
+  MLIRContext *ctx = module.getContext();
+  int numCTAs = TritonGPUDialect::getNumCTAs(module);
+  int threadsPerWarp = TritonGPUDialect::getThreadsPerWarp(module);
 
   module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
     OpBuilder builder(cvtOp);
@@ -31,28 +31,32 @@ void decomposeTensorCoreToDotLayoutConversion(ModuleOp module,
     auto srcMma = dyn_cast<MmaEncodingTrait>(srcType.getEncoding());
     auto dstDotOp =
         dyn_cast<triton::gpu::DotOperandEncodingAttr>(dstType.getEncoding());
-    if (srcMma && dstDotOp && !shortcutFn(srcType, dstType)) {
-      auto tmpType = RankedTensorType::get(
-          dstType.getShape(), dstType.getElementType(),
-          triton::gpu::BlockedEncodingAttr::get(
-              module.getContext(), srcType.getShape(), getSizePerThread(srcMma),
-              getOrder(srcMma), numWarps, threadsPerWarp, numCTAs));
-      auto tmp = builder.create<triton::gpu::ConvertLayoutOp>(
-          cvtOp.getLoc(), tmpType, cvtOp.getSrc());
-      addAttrs(tmp, cvtOp->getAttrs());
-      auto newConvert = builder.create<triton::gpu::ConvertLayoutOp>(
-          cvtOp.getLoc(), dstType, tmp);
-      addAttrs(newConvert, cvtOp->getAttrs());
-      cvtOp.replaceAllUsesWith(newConvert.getResult());
-      cvtOp.erase();
-    }
+    if (!srcMma || !dstDotOp || shortcutFn(srcType, dstType))
+      return;
+
+    int numWarps = lookupNumWarps(cvtOp);
+    auto enc = BlockedEncodingAttr::get(
+        ctx, srcType.getShape(), getSizePerThread(srcMma), getOrder(srcMma),
+        numWarps, threadsPerWarp, numCTAs);
+    auto tmpType = RankedTensorType::get(dstType.getShape(),
+                                         dstType.getElementType(), enc);
+
+    auto tmp = builder.create<ConvertLayoutOp>(cvtOp.getLoc(), tmpType,
+                                               cvtOp.getSrc());
+    addAttrs(tmp, cvtOp->getAttrs());
+    auto newConvert =
+        builder.create<ConvertLayoutOp>(cvtOp.getLoc(), dstType, tmp);
+    addAttrs(newConvert, cvtOp->getAttrs());
+
+    cvtOp.replaceAllUsesWith(newConvert.getResult());
+    cvtOp.erase();
   });
 }
 
 void decomposeBlockedToDotLayoutConversion(ModuleOp module) {
-  int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(module);
   int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(module);
   int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module);
+
   module.walk([&](triton::gpu::ConvertLayoutOp cvtOp) -> void {
     OpBuilder builder(cvtOp);
     auto srcType = cast<RankedTensorType>(cvtOp.getSrc().getType());