
Commit 6b1642e

[Intel] Refactor numWarps lookups to be from the op
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 539f2ed · commit 6b1642e

14 files changed: +19 -29 lines changed
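Every hunk below makes the same mechanical substitution: numWarps is no longer read from a module-level attribute via TritonGPUDialect::getNumWarps(mod) but is resolved relative to an operation with lookupNumWarps(op). A minimal sketch of the pattern, assuming the Triton C++ API used in the hunks (the helper name getNumWarpsFor and the generic op parameter are illustrative, not part of the commit):

#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;

// Illustrative only: contrasts the old module-scoped lookup with the
// op-scoped lookup this commit switches to.
static int getNumWarpsFor(Operation *op) {
  // Old pattern: fetch the num-warps attribute from the enclosing module.
  //   auto mod = op->getParentOfType<ModuleOp>();
  //   int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
  // New pattern: resolve the value starting from the op itself;
  // lookupNumWarps walks outward from op, so ops nested in a region that
  // carries its own warp count presumably see the nearest value.
  return triton::gpu::lookupNumWarps(op);
}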

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 1 addition & 2 deletions
@@ -1062,8 +1062,7 @@ class DecomposeScaledBlocked

     RankedTensorType oldRetType = dotOp.getType();
     auto retShapePerCTA = getShapePerCTA(oldRetType);
-    auto mod = dotOp->getParentOfType<mlir::ModuleOp>();
-    int numWarps = TritonGPUDialect::getNumWarps(mod);
+    int numWarps = lookupNumWarps(dotOp);
     auto CTALayout = getCTALayout(oldRetType.getEncoding());

     auto instrShape = mmaVersionToInstrShape(

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 1 addition & 1 deletion
@@ -562,7 +562,7 @@ class ScaledBlockedToMFMA final : public OpRewritePattern<triton::DotScaledOp> {
     auto moduleOp = dotOp->getParentOfType<ModuleOp>();

     ttg::CTALayoutAttr ctaLayout = ttg::getCTALayout(oldRetType.getEncoding());
-    int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp);
+    int numWarps = ttg::lookupNumWarps(dotOp);
     int numThreads = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp);

     // Choose a suitable MFMA instruction for this scaled dot op.

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 1 addition & 2 deletions
@@ -61,8 +61,7 @@ DPASAnalysis::canUseDPAS(FunctionOpInterface funcOp) const {
   // Verify whether the module has the correct number of threads per warp.
   // Note: if the module doesn't then return 'Result::Maybe' to allow the caller
   // to set warp size.
-  Attribute threadsPerWarpAttr =
-      mod->getDiscardableAttr(TritonGPUDialect::getThreadsPerWarpAttrName());
+  Attribute threadsPerWarpAttr = mod->getDiscardableAttr(AttrNumThreadsPerWarp);
   if (!threadsPerWarpAttr)
     return Result::Maybe;


third_party/intel/lib/TritonAnnotateModule/TritonAnnotateModule.cpp

Lines changed: 0 additions & 3 deletions
@@ -48,9 +48,6 @@ struct TritonAnnotateModule
   void setThreadsPerWarp(ModuleOp &mod,
                          const DPASAnalysis &dpasAnalysis) const {
     Builder builder(mod);
-    const std::string &AttrNumThreadsPerWarp =
-        TritonGPUDialect::getThreadsPerWarpAttrName();
-
     mod.walk([&](FunctionOpInterface funcOp) {
       // FIXME: DPAS lowering only implemented for 16 threads per warp, i.e.,
       // DPAS is not used for devices like ATS.
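The AttrNumThreadsPerWarp name that DPAS.cpp starts using above pairs with the local definition deleted here; the shared definition is presumably hoisted into one of the changed files not shown on this page. One plausible form, mirroring the deleted local (location and linkage are assumptions):

// Assumed shared definition; mirrors the string previously built locally
// in TritonAnnotateModule::setThreadsPerWarp.
const std::string AttrNumThreadsPerWarp =
    TritonGPUDialect::getThreadsPerWarpAttrName();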

third_party/intel/lib/TritonIntelGPUToLLVM/HistogramOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -170,7 +170,7 @@ struct HistogramOpConversion
     assert((numThreadsPerWarp == 16 || numThreadsPerWarp == 32 ||
             numThreadsPerWarp == 64) &&
            "Only supports 16, 32 or 64 threads per warp");
-    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+    int numWarps = triton::gpu::lookupNumWarps(op);
     // Pad out the bins so that we have at least one bin per thread within a
     // warp.
     numBins = std::max(numBins, numThreadsPerWarp);

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -366,7 +366,7 @@ struct PrefetchOpConversion
       std::swap(tensorShape[0], tensorShape[1]);
     }

-    unsigned numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+    unsigned numWarps = triton::gpu::lookupNumWarps(op);

     SmallVector<unsigned, 2> shapePerWarp =
         get2DPrefetchShapePerWarp(tensorType);

third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -289,7 +289,7 @@ struct ReduceOpConversion

     auto mod = op.getOperation()->getParentOfType<ModuleOp>();
     unsigned numLanes = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
-    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+    int numWarps = triton::gpu::lookupNumWarps(op.getOperation());
     int numThreads = numLanes * numWarps;

     Value threadId = getThreadId(rewriter, loc);

third_party/intel/lib/TritonIntelGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 1 addition & 1 deletion
@@ -101,7 +101,7 @@ struct ConvertTritonGPUToLLVM
     TritonIntelGPUToLLVMTypeConverter typeConverter(context, option, targetInfo,
                                                     isAdvancedPathEnabled);
     TritonLLVMConversionTarget convTarget(*context);
-    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);
+    int numWarps = triton::gpu::lookupNumWarps(&*mod.getOps().begin());
     int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
     int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
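ConvertTritonGPUToLLVM is a module pass, so there is no single op to anchor the lookup on; the diff anchors on the first op in the module body. A short sketch of what that expression does, assuming MLIR's ModuleOp API (the anchor variable is illustrative):

// &*mod.getOps().begin() takes the first top-level operation in the module;
// lookupNumWarps then walks outward from that op and ends at the
// module-level attribute, matching the old module-scoped behavior.
Operation *anchor = &*mod.getOps().begin();
int numWarps = triton::gpu::lookupNumWarps(anchor);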

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 4 additions & 4 deletions
@@ -105,14 +105,14 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {

     // Create DPAS encoding for the given number of warps
     ArrayRef<int64_t> retShape = oldRetType.getShape();
-    ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
-    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+    unsigned numWarps = ttg::lookupNumWarps(funcOp);

     TensorValue a = dotOp.getA();
     TensorValue b = dotOp.getB();
     auto oldAType = cast<RankedTensorType>(a.getType());
     auto oldBType = cast<RankedTensorType>(b.getType());

+    ModuleOp mod = funcOp->getParentOfType<ModuleOp>();
     auto dpasCap = ttgi::DpasEncodingAttr::getDPASCapability(mod);
     Type elemType = oldAType.getElementType();
     unsigned opsPerChan = ttgi::DpasEncodingAttr::getOpsPerChannel(elemType);

@@ -295,7 +295,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
     assert(opDesc.scale && "Expecting valid operand & scale");

     MLIRContext *ctx = opDesc.op.getContext();
-    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+    unsigned numWarps = ttg::lookupNumWarps(&*rewriter.getInsertionPoint());
     unsigned warpSize = ttg::TritonGPUDialect::getThreadsPerWarp(mod);
     unsigned opsPerChannel = dpasEnc.getOpsPerChannel();
     unsigned rank = retType.getRank();

@@ -372,7 +372,7 @@ class DecomposeScaledBlocked : public OpRewritePattern<tt::DotScaledOp> {
         aScale ? b.getType().getElementType() : a.getType().getElementType();
     unsigned opsPerChan =
         ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType);
-    unsigned numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
+    unsigned numWarps = ttg::lookupNumWarps(scaledDotOp);
     SmallVector<unsigned> warpsPerTile = {numWarps, 1};

     ArrayRef<int64_t> retShape = scaledDotOp.getType().getShape();

third_party/intel/lib/TritonIntelGPUTransforms/Coalesce.cpp

Lines changed: 1 addition & 1 deletion
@@ -372,7 +372,7 @@ struct CoalescePass
     if (!refTensorType || !refTensorType.getEncoding())
       return;

-    int numWarps = ttg::TritonGPUDialect::getNumWarps(moduleOp);
+    int numWarps = ttg::lookupNumWarps(curr);
     int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(moduleOp);
     setCoalescedEncoding(axisInfoAnalysis, curr, numWarps, threadsPerWarp,
                          layoutMap);
