Skip to content

Commit f1909a5

Browse files
committed
Address code review comments
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 0d82630 commit f1909a5

File tree

2 files changed

+8
-9
lines changed

2 files changed

+8
-9
lines changed

third_party/intel/include/Analysis/AxisInfo.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,10 @@ class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis {
4242
}
4343
}
4444

45-
AxisInfo *getAxisInfo(Value value) const {
45+
AxisInfo *getAxisInfo(Value value) {
4646
auto funcOp =
4747
value.getParentRegion()->getParentOfType<FunctionOpInterface>();
48-
auto *axisInfoMap =
49-
const_cast<ModuleAxisInfoAnalysis *>(this)->getFuncData(funcOp);
48+
auto *axisInfoMap = getFuncData(funcOp);
5049
if (!axisInfoMap) {
5150
return nullptr;
5251
}
@@ -57,9 +56,9 @@ class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis {
5756
return &(it->second);
5857
}
5958

60-
unsigned getPtrContiguity(Value ptr) const;
61-
unsigned getPtrAlignment(Value ptr) const;
62-
unsigned getMaskAlignment(Value mask) const;
59+
unsigned getPtrContiguity(Value ptr);
60+
unsigned getPtrAlignment(Value ptr);
61+
unsigned getMaskAlignment(Value mask);
6362

6463
private:
6564
void initialize(FunctionOpInterface funcOp);

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1159,7 +1159,7 @@ void AxisInfoAnalysis::visitForOpInductionVar(
11591159

11601160
} // anonymous namespace
11611161

1162-
unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) const {
1162+
unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) {
11631163
auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
11641164
if (!tensorTy)
11651165
return 1;
@@ -1181,7 +1181,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrContiguity(Value ptr) const {
11811181
return contiguity;
11821182
}
11831183

1184-
unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) const {
1184+
unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) {
11851185
auto tensorTy = ttgi::getRankedTensorType(ptr.getType());
11861186
if (!tensorTy)
11871187
return 1;
@@ -1211,7 +1211,7 @@ unsigned ModuleAxisInfoAnalysis::getPtrAlignment(Value ptr) const {
12111211
return alignment;
12121212
}
12131213

1214-
unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) const {
1214+
unsigned ModuleAxisInfoAnalysis::getMaskAlignment(Value mask) {
12151215
auto tensorTy = ttgi::getRankedTensorType(mask.getType());
12161216
if (!tensorTy)
12171217
return 1;

0 commit comments

Comments
 (0)