Commit ae3d625

Fix performance regression for gemm-preop-exp
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 0215a16 commit ae3d625

2 files changed: +22 −22 lines changed

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 6 additions & 3 deletions
@@ -102,7 +102,9 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       return;
 
     const bool isRowMajor = fastChangeDim == rank - 1;
-    if (auto dotLayout = getDotLayout(loadOp)) {
+    std::optional<ttg::DotOperandEncodingAttr> dotLayout =
+        getDotLayout(loadOp);
+    if (dotLayout) {
       // Check if the load is being used by a tt.dot operation, and if so is
       // this the first operand and is it a transposed row major matrix. If
       // so, skip the block ptr attribute as performance is worse than if we
@@ -163,8 +165,9 @@ struct TritonIntelGPUMaterializeBlockPointerPass
         allUserHaveIdenticalLayout(users)) {
       Attribute firstUserLayout =
           cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
-      return llvm::dyn_cast_if_present<ttg::DotOperandEncodingAttr>(
-          firstUserLayout);
+      if (isa<ttg::DotOperandEncodingAttr>(firstUserLayout))
+        return dyn_cast<ttg::DotOperandEncodingAttr>(firstUserLayout);
+      return std::nullopt;
     }
 
     return std::nullopt;
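
The change above replaces the condition-scoped "if (auto dotLayout = getDotLayout(loadOp))" with an explicitly named std::optional, so the possibly-empty result of the lookup is visible at the call site. Below is a minimal, self-contained sketch of that pattern; DotOperandLayout and usedByDot are hypothetical stand-ins for ttg::DotOperandEncodingAttr and the real use-chain analysis, not the Triton code itself.

#include <iostream>
#include <optional>

// Hypothetical stand-in for ttg::DotOperandEncodingAttr.
struct DotOperandLayout {
  int opIdx; // which dot operand (0 = A, 1 = B)
};

// Mirrors the shape of getDotLayout in the diff: either a layout is
// derived from the load's users, or std::nullopt is returned.
std::optional<DotOperandLayout> getDotLayout(bool usedByDot) {
  if (usedByDot)
    return DotOperandLayout{/*opIdx=*/0};
  return std::nullopt;
}

int main() {
  std::optional<DotOperandLayout> dotLayout = getDotLayout(/*usedByDot=*/true);
  if (dotLayout) {
    // Dereference only after the explicit emptiness check.
    std::cout << "dot operand layout, opIdx = " << dotLayout->opIdx << "\n";
  } else {
    std::cout << "no dot layout\n";
  }
}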

third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp

Lines changed: 16 additions & 19 deletions
@@ -10,6 +10,8 @@
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "tritonintelgpu-pipeline"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 using namespace mlir;
 namespace tt = mlir::triton;
@@ -55,30 +57,25 @@ static ttg::DotOperandEncodingAttr getDotEncodingFromUser(Operation *user) {
   if (!tensorType)
     return nullptr;
 
-  if (isa<ttg::SharedEncodingAttr>(tensorType.getEncoding()))
-    return allTransitiveUsesHaveDotEncoding(res);
-
-  return llvm::dyn_cast_or_null<ttg::DotOperandEncodingAttr>(
-      tensorType.getEncoding());
+  Attribute layout = tensorType.getEncoding();
+  return isa<ttg::SharedEncodingAttr, ttg::BlockedEncodingAttr>(layout)
+             ? allTransitiveUsesHaveDotEncoding(res)
+             : llvm::dyn_cast_or_null<ttg::DotOperandEncodingAttr>(layout);
 }
 
 /// If all the transitive uses of the given value are used by a convert to the
 /// same dot operand encoding, return the encoding. Otherwise return nullptr.
 static ttg::DotOperandEncodingAttr allTransitiveUsesHaveDotEncoding(Value val) {
   ttg::DotOperandEncodingAttr attr{nullptr};
-  LLVM_DEBUG(llvm::dbgs() << "Checking users of " << val << "\n");
+  LDBG("Checking users of " << val);
   for (Operation *user : val.getUsers()) {
-    ttg::DotOperandEncodingAttr dotAttr;
-    if (isa<triton::DotOp>(user)) {
-      auto tensorType = cast<RankedTensorType>(val.getType());
-      dotAttr = dyn_cast<ttg::DotOperandEncodingAttr>(tensorType.getEncoding());
-    } else {
-      dotAttr = getDotEncodingFromUser(user);
-    }
+    ttg::DotOperandEncodingAttr dotAttr =
+        isa<triton::DotOp>(user)
+            ? dyn_cast<ttg::DotOperandEncodingAttr>(
+                  cast<RankedTensorType>(val.getType()).getEncoding())
+            : getDotEncodingFromUser(user);
     if (!dotAttr || (attr != nullptr && attr != dotAttr)) {
-      LLVM_DEBUG({
-        llvm::dbgs() << "no dot attribute found for user: " << user << "\n";
-      });
+      LDBG("no dot attribute found for user: " << *user);
       return nullptr;
     }
     attr = dotAttr;
@@ -292,14 +289,14 @@ bool ttgi::preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages,
   SmallVector<LoadDotOperand> loads;
   collectOpsToPipeline(forOp, loads, supportRegularPtr);
   if (loads.empty()) {
-    LLVM_DEBUG(llvm::dbgs() << "No loads to pipeline\n");
+    LDBG("No loads to pipeline");
     return false;
   }
 
   LLVM_DEBUG({
-    llvm::dbgs() << "Loads to pipeline:\n";
+    DBGS() << "Loads to pipeline:\n";
     for (const LoadDotOperand &load : loads)
-      llvm::dbgs() << " " << *load.load << "\n";
+      DBGS() << " " << *load.load << "\n";
   });
 
   // 2. Create the prefetching operations for the loads collected.
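
The new DBGS()/LDBG(X) macros factor the repeated "[DEBUG_TYPE]: " prefix and trailing newline out of every LLVM_DEBUG call site. Below is a minimal sketch of the same convention, with std::cerr standing in for llvm::dbgs() and a hypothetical MY_DEBUG macro standing in for LLVM_DEBUG, so the example compiles without LLVM headers:

#include <iostream>

#define DEBUG_TYPE "tritonintelgpu-pipeline"

// Stand-in for LLVM_DEBUG: compiled out when NDEBUG is defined.
#ifndef NDEBUG
#define MY_DEBUG(X) do { X; } while (false)
#else
#define MY_DEBUG(X) do { } while (false)
#endif

// Every debug line gets a greppable "[<pass name>]: " prefix, and LDBG
// appends the newline, matching the macros added in the diff.
#define DBGS() (std::cerr << "[" DEBUG_TYPE "]: ")
#define LDBG(X) MY_DEBUG(DBGS() << X << "\n")

int main() {
  int numLoads = 3;
  LDBG("No loads to pipeline");           // prints: [tritonintelgpu-pipeline]: No loads to pipeline
  LDBG("loads collected: " << numLoads);  // stream-style arguments compose as with llvm::dbgs()
}

Call sites shrink from LLVM_DEBUG(llvm::dbgs() << ... << "\n") to LDBG(...), and the per-pass prefix makes this pass's output easy to filter in a combined debug log.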
