Commit 5b94131

Remove RewriteTensorPointer from the optimization pipeline (#2584)
Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent eff7b30 commit 5b94131

5 files changed: 78 additions & 67 deletions

third_party/intel/backend/compiler.py

Lines changed: 2 additions & 1 deletion
@@ -235,7 +235,8 @@ def make_ttgir(mod, metadata, opt, properties):
     intel.passes.ttgpuir.add_accelerate_matmul(pm)
     intel.passes.ttgpuir.add_remove_layout_conversions(pm)
     intel.passes.ttgpuir.add_materialize_block_pointer(pm)
-    intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
+    if os.getenv("TRITON_INTEL_REWRITE_TENSOR_POINTER", "0") == "1":
+        intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
     intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)

     intel.passes.ttgpuir.add_coalesce(pm)
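
Note: with this change the RewriteTensorPointer pass no longer runs by default; exporting TRITON_INTEL_REWRITE_TENSOR_POINTER=1 restores the previous behavior. As a minimal sketch of the same opt-in gating in C++ (the helper name and the commented pass-manager call are hypothetical, not the actual Triton API):

#include <cstdlib>
#include <cstring>

// True only when the named environment flag is set to exactly "1",
// mirroring the os.getenv(..., "0") == "1" check added above.
static bool isEnvFlagEnabled(const char *name) {
  const char *value = std::getenv(name);
  return value != nullptr && std::strcmp(value, "1") == 0;
}

// Hypothetical usage at pass-pipeline construction time:
//   if (isEnvFlagEnabled("TRITON_INTEL_REWRITE_TENSOR_POINTER"))
//     pm.addPass(createRewriteTensorPointerPass());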

third_party/intel/include/Analysis/AxisInfo.h

Lines changed: 0 additions & 1 deletion
@@ -12,7 +12,6 @@ namespace mlir::triton::intel {
 // axis info based on the axis info of all the callers. In the future, we can
 // perform optimization using function cloning so that each call site will have
 // unique axis info.
-
 class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis {
 public:
   explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)

third_party/intel/lib/Analysis/AxisInfo.cpp

Lines changed: 14 additions & 3 deletions
@@ -1030,13 +1030,24 @@ class MakeTensorPtrOpAxisInfoVisitor final
           strideInfo[dim].getConstantValue() == 1 ? blkShape[dim] : 1);
       divisibility.push_back(
           contiguity[dim] > 1
-              ? std::min(ptrDivisibility,
-                         strideInfo[dim == 0 ? 1 : 0].getDivisibility()[0])
+              ? std::min(
+                    ptrDivisibility,
+                    (rank == 2 ? strideInfo[dim == 0 ? 1 : 0] : strideInfo[dim])
+                        .getDivisibility()[0])
               : 1);
       constancy.push_back(1);
     }

-    return AxisInfo(contiguity, divisibility, constancy);
+    auto axisInfo = AxisInfo(contiguity, divisibility, constancy);
+
+    LLVM_DEBUG({
+      std::string axisStr;
+      llvm::raw_string_ostream os(axisStr);
+      axisInfo.print(os);
+      LDBG("-- " << axisStr);
+    });
+
+    return axisInfo;
   }
 };
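
The key fix above is the stride index used for divisibility: only rank-2 tensor pointers read the other dimension's stride, while any other rank now reads its own (previously the swap was applied unconditionally). A standalone sketch of that selection, assuming simplified integer types in place of the real AxisInfo lattices:

#include <cassert>

// For rank-2 block pointers, the contiguous dim's divisibility comes
// from the other dim's stride; for any other rank, from its own stride.
// Mirrors the ternary added in the diff above.
static int strideIndexFor(int dim, int rank) {
  return rank == 2 ? (dim == 0 ? 1 : 0) : dim;
}

int main() {
  assert(strideIndexFor(0, 2) == 1); // rank 2: dims swapped
  assert(strideIndexFor(1, 2) == 0);
  assert(strideIndexFor(0, 3) == 0); // other ranks: same dim
  return 0;
}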

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 61 additions & 61 deletions
@@ -161,29 +161,33 @@ getWarpsPerCTA(const ArrayRef<int64_t> tensorShape,
 
 // Contains some helper functions for both Load and Store conversions.
 struct LoadStoreConversionBase {
-  explicit LoadStoreConversionBase(const triton::intel::TargetInfo &targetInfo,
-                                   ModuleAxisInfoAnalysis &axisAnalysisPass)
+  explicit LoadStoreConversionBase(
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass)
       : targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {}
 
   unsigned getContiguity(Value ptr) const {
-    auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
-    if (!tensorTy)
-      return 1;
-    return axisAnalysisPass.getPtrContiguity(ptr);
+    return const_cast<triton::intel::ModuleAxisInfoAnalysis &>(axisAnalysisPass)
+        .getPtrContiguity(ptr);
   }
 
   unsigned getVectorSize(Value ptr) const {
-    auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
+    auto tensorTy = getRankedTensorType(ptr.getType());
     if (!tensorTy)
      return 1;
-    auto contiguity = getContiguity(ptr);
-    auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy);
+
+    unsigned contiguity = getContiguity(ptr);
+    unsigned pointeeBitWidth =
+        isTensorPointerType(ptr.getType())
+            ? tensorTy.getElementType().getIntOrFloatBitWidth()
+            : triton::getPointeeBitWidth(tensorTy);
     // The maximum vector size is 128 bits.
     return std::min<unsigned>(128 / pointeeBitWidth, contiguity);
   }
 
   unsigned getMaskAlignment(Value mask) const {
-    return axisAnalysisPass.getMaskAlignment(mask);
+    return const_cast<triton::intel::ModuleAxisInfoAnalysis &>(axisAnalysisPass)
+        .getMaskAlignment(mask);
   }
 
   std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
@@ -289,7 +293,7 @@ struct LoadStoreConversionBase {
   }
 
 protected:
-  ModuleAxisInfoAnalysis &axisAnalysisPass;
+  const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass;
   const triton::intel::TargetInfo &targetInfo;
 };
 
@@ -299,10 +303,11 @@ struct PrefetchOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::gpu::intel::PrefetchOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  PrefetchOpConversion(TritonGPUToLLVMTypeConverter &converter,
-                       const triton::intel::TargetInfo &targetInfo,
-                       ModuleAxisInfoAnalysis &axisAnalysisPass,
-                       PatternBenefit benefit)
+  PrefetchOpConversion(
+      TritonGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::PrefetchOp>(
             converter, benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -475,10 +480,11 @@ struct LoadOpConversion
 
   using ValueTable = std::map<std::pair<int, int>, Value>;
 
-  LoadOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
-                   const triton::intel::TargetInfo &targetInfo,
-                   ModuleAxisInfoAnalysis &axisAnalysisPass,
-                   PatternBenefit benefit)
+  LoadOpConversion(
+      TritonIntelGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
 
@@ -824,37 +830,32 @@ struct LoadOpConversion
     Location loc = op->getLoc();
     auto typeConverter = getTypeConverter();
     MLIRContext *ctx = rewriter.getContext();
+    Value ptr = op.getPtr();
+    Value mask = op.getMask();
+    Value llMask = adaptor.getMask();
 
     // Determine the vectorization size
     Type valueElemTy =
         typeConverter->convertType(getElementTypeOrSelf(op.getType()));
     unsigned numElems = getTotalElemsPerThread(op.getType());
-    unsigned vec = 1;
+    unsigned vec = getVectorSize(ptr);
+    if (llMask)
+      vec = std::min<size_t>(vec, getMaskAlignment(mask));
 
     SmallVector<Value> ptrElems, maskElems, otherElems;
     bool otherIsSplatConstInt = false;
     int64_t splatVal = 0;
 
-    if (isTensorPointerType(op.getPtr().getType())) {
-      // TODO: (johnlu) set the vector size > 1; Need to prove the memory is
-      // contiguous on the fast changing dim when fallback to gather load.
+    if (isTensorPointerType(ptr.getType())) {
+      // fallback to gather load.
       auto tensorType = cast<RankedTensorType>(op.getType());
       std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
           loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
          op.getBoundaryCheck(), op.getPadding());
     } else {
-      // original values
-      Value ptr = op.getPtr();
       Value other = op.getOther();
-      Value mask = op.getMask();
-
-      // adaptor values
       Value llPtr = adaptor.getPtr();
-      Value llMask = adaptor.getMask();
       Value llOther = adaptor.getOther();
-      vec = getVectorSize(ptr);
-      if (llMask)
-        vec = std::min<size_t>(vec, getMaskAlignment(mask));
 
       // Get the LLVM values for pointers
       ptrElems = unpackLLElements(loc, llPtr, rewriter);
@@ -987,10 +988,11 @@ struct StoreOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  StoreOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
-                    const triton::intel::TargetInfo &targetInfo,
-                    ModuleAxisInfoAnalysis &axisAnalysisPass,
-                    PatternBenefit benefit)
+  StoreOpConversion(
+      TritonIntelGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
 
@@ -1128,14 +1130,20 @@ struct StoreOpConversion
       return success();
 
     Location loc = op->getLoc();
+    auto *typeConverter = getTypeConverter();
     MLIRContext *ctx = rewriter.getContext();
     Value ptr = op.getPtr();
-    Value value = op.getValue();
-    Type valueTy = value.getType();
+    Value mask = op.getMask();
+    Value llMask = adaptor.getMask();
+
+    // Determine the vectorization size
+    Type valueTy = op.getValue().getType();
     Type valueElemTy =
         typeConverter->convertType(getElementTypeOrSelf(valueTy));
     SmallVector<Value> ptrElems, maskElems;
-    unsigned vec = 1;
+    unsigned vec = getVectorSize(ptr);
+    if (llMask)
+      vec = std::min<size_t>(vec, getMaskAlignment(mask));
 
     if (isTensorPointerType(ptr.getType())) {
       // fallback to scatter store.
@@ -1146,20 +1154,9 @@ struct StoreOpConversion
           op.getBoundaryCheck());
     } else {
       Value llPtr = adaptor.getPtr();
-      Value llMask = adaptor.getMask();
-
-      vec = getVectorSize(ptr);
-
       ptrElems = unpackLLElements(loc, llPtr, rewriter);
-
-      // Determine the vectorization size
-      if (llMask) {
-        Value mask = op.getMask();
+      if (llMask)
         maskElems = unpackLLElements(loc, llMask, rewriter);
-
-        unsigned maskAlign = getMaskAlignment(mask);
-        vec = std::min(vec, maskAlign);
-      }
     }
 
     Value llValue = adaptor.getValue();
@@ -1168,7 +1165,7 @@ struct StoreOpConversion
     assert(!maskElems.size() ||
            valueElems.size() == maskElems.size() && "Mask size mismatch");
 
-    Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
+    mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
     const size_t dtsize =
         std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
     const size_t valueElemNBits = dtsize * 8;
@@ -1247,10 +1244,11 @@ struct AtomicCASOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  AtomicCASOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
-                        const triton::intel::TargetInfo &targetInfo,
-                        ModuleAxisInfoAnalysis &axisAnalysisPass,
-                        PatternBenefit benefit)
+  AtomicCASOpConversion(
+      TritonIntelGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(converter,
                                                              benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -1364,10 +1362,11 @@ struct AtomicRMWOpConversion
   using ConvertTritonGPUOpToLLVMPattern<
       triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
 
-  AtomicRMWOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
-                        const triton::intel::TargetInfo &targetInfo,
-                        ModuleAxisInfoAnalysis &axisAnalysisPass,
-                        PatternBenefit benefit)
+  AtomicRMWOpConversion(
+      TritonIntelGPUToLLVMTypeConverter &converter,
+      const triton::intel::TargetInfo &targetInfo,
+      const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
+      PatternBenefit benefit)
       : ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
                                                              benefit),
         LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -1627,7 +1626,8 @@ struct AtomicRMWOpConversion
 void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
     TritonIntelGPUToLLVMTypeConverter &typeConverter,
     const TargetInfo &targetInfo, RewritePatternSet &patterns,
-    ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
+    const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
+    PatternBenefit benefit) {
   patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
                StoreOpConversion, PrefetchOpConversion>(
       typeConverter, targetInfo, axisInfoAnalysis, benefit);
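
For reference, the vector width that getVectorSize computes is capped at a 128-bit access; a self-contained sketch of that arithmetic with assumed example inputs:

#include <algorithm>
#include <cassert>

// Number of contiguous elements that fit in one 128-bit access,
// mirroring the std::min in getVectorSize above.
static unsigned vectorSize(unsigned pointeeBitWidth, unsigned contiguity) {
  return std::min<unsigned>(128u / pointeeBitWidth, contiguity);
}

int main() {
  assert(vectorSize(16, 16) == 8); // f16 with 16 contiguous elems -> 8 wide
  assert(vectorSize(32, 2) == 2);  // f32 with low contiguity -> contiguity wins
  return 0;
}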

third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
 void populateLoadStoreOpToLLVMPatterns(
     TritonIntelGPUToLLVMTypeConverter &typeConverter,
     const TargetInfo &targetInfo, RewritePatternSet &patterns,
-    ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit);
+    const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit);
 
 void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
                                     RewritePatternSet &patterns,
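
The const-qualified analysis reference above pairs with the const_cast calls in LoadStoreOpToLLVM.cpp: the analysis is stored const in the conversion patterns, but its query methods are non-const, so callers cast constness away. A minimal illustration of that pattern (toy types, not Triton code):

#include <cassert>

struct Analysis {
  // Non-const query, standing in for getPtrContiguity/getMaskAlignment.
  unsigned query(int key) { return key > 0 ? 4u : 1u; }
};

struct Pattern {
  const Analysis &analysis; // held const, as in LoadStoreConversionBase

  unsigned lookup(int key) const {
    // Cast away constness to call the non-const query, as the diff does.
    return const_cast<Analysis &>(analysis).query(key);
  }
};

int main() {
  Analysis a;
  Pattern p{a};
  assert(p.lookup(1) == 4u);
  return 0;
}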
