Skip to content

Commit 8b6a9d2

Browse files
committed
Axisinfo decoupling
1 parent 0359ed0 commit 8b6a9d2

File tree

2 files changed

+9
-31
lines changed

2 files changed

+9
-31
lines changed
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#ifndef ILUVATAR_TRITON_ANALYSIS_AXISINFO_H
22
#define ILUVATAR_TRITON_ANALYSIS_AXISINFO_H
33

4-
#define FLAGTREE_SPEC_CorexFlag
5-
#define FLAGTREE_SPEC_AxisInfo_getCorexFlag
4+
#define FLAGTREE_SPEC_AxisInfo_CorexFlag
65
#define FLAGTREE_SPEC_AxisInfo_initPessimisticStateFromFunc_ARG AxisInfo::DimVectorT *
76

87
#endif

third_party/iluvatar/include/triton/Analysis/AxisInfo.h

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class AxisInfo {
2727
typedef SmallVector<int64_t> DimVectorT;
2828

2929
public:
30-
#ifndef __ILUVATAR__
30+
#ifndef FLAGTREE_SPEC_AxisInfo_CorexFlag
3131
AxisInfo() : AxisInfo({}, {}, {}) {}
3232

3333
AxisInfo(DimVectorT contiguity, DimVectorT divisibility, DimVectorT constancy)
@@ -127,7 +127,7 @@ class AxisInfo {
127127
int64_t getConstancy(size_t dim) const { return constancy[dim]; }
128128
const DimVectorT &getConstancy() const { return constancy; }
129129

130-
#ifdef FLAGTREE_SPEC_AxisInfo_getCorexFlag
130+
#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag
131131
// corexFlag is used to determine whether special instructions can be used to
132132
// accelerate data loading.
133133
int64_t getCorexFlag(size_t dim) const { return corexFlag[dim]; }
@@ -151,42 +151,20 @@ class AxisInfo {
151151
DimVectorT *divisibility, DimVectorT *constancy);
152152
#endif
153153

154-
#ifndef __ILUVATAR__
155154
bool operator==(const AxisInfo &other) const {
156155
return contiguity == other.contiguity &&
157156
divisibility == other.divisibility && constancy == other.constancy &&
157+
#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag
158+
corexFlag == other.corexFlag &&
159+
#endif
158160
constantValue == other.constantValue;
159161
}
160-
#else
161-
bool operator==(const AxisInfo &other) const {
162-
return contiguity == other.contiguity &&
163-
divisibility == other.divisibility && constancy == other.constancy &&
164-
corexFlag == other.corexFlag && constantValue == other.constantValue;
165-
}
166-
#endif
167162

168163
static AxisInfo getPessimisticValueState(Value value);
169164

170165
// The gcd of both arguments for each dimension
171166
static AxisInfo join(const AxisInfo &lhs, const AxisInfo &rhs);
172167

173-
#ifndef __ILUVATAR__
174-
void print(raw_ostream &os) const {
175-
auto print = [&](StringRef name, DimVectorT vec) {
176-
os << name << " = [";
177-
llvm::interleaveComma(vec, os);
178-
os << "]";
179-
};
180-
print("contiguity", contiguity);
181-
print(", divisibility", divisibility);
182-
print(", constancy", constancy);
183-
os << ", constant_value = ";
184-
if (constantValue)
185-
os << *constantValue;
186-
else
187-
os << "<none>";
188-
}
189-
#else
190168
void print(raw_ostream &os) const {
191169
auto print = [&](StringRef name, DimVectorT vec) {
192170
os << name << " = [";
@@ -196,14 +174,15 @@ class AxisInfo {
196174
print("contiguity", contiguity);
197175
print(", divisibility", divisibility);
198176
print(", constancy", constancy);
177+
#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag
199178
print(", corexflag", corexFlag);
179+
#endif
200180
os << ", constant_value = ";
201181
if (constantValue)
202182
os << *constantValue;
203183
else
204184
os << "<none>";
205185
}
206-
#endif
207186

208187
private:
209188
DimVectorT contiguity;
@@ -212,7 +191,7 @@ class AxisInfo {
212191

213192
// The constant value of the lattice if we can infer it.
214193
std::optional<int64_t> constantValue;
215-
#ifdef FLAGTREE_SPEC_CorexFlag
194+
#ifdef FLAGTREE_SPEC_AxisInfo_CorexFlag
216195
DimVectorT corexFlag;
217196
#endif
218197
};

0 commit comments

Comments
 (0)