Merged
Changes from all commits (45 commits)
c7fe682
Improve axis analysis to handle tt.make_tensor_ptr
etiotto Oct 9, 2024
ad3888f
Merge branch 'main' into etiotto/axis_analysis_make_tensor_ptr
etiotto Oct 9, 2024
a7a9b06
Merge branch 'main' into etiotto/axis_analysis_make_tensor_ptr
etiotto Oct 10, 2024
6bddd5f
Merge branch 'main' into etiotto/axis_analysis_make_tensor_ptr
etiotto Oct 10, 2024
4ad4f1a
Merge branch 'main' into etiotto/axis_analysis_make_tensor_ptr
etiotto Oct 10, 2024
4dc1cf1
WIP: Coalescing for block ptrs
etiotto Oct 16, 2024
fa53ced
Fix pre_commit
etiotto Oct 16, 2024
049ddb8
Merge branch 'main' into etiotto/coalesce_for_block_ptr
etiotto Oct 17, 2024
041e2da
Merge branch 'main' into etiotto/coalesce_for_block_ptr
etiotto Oct 17, 2024
5a6cf81
Fix functional problem and add lit test
etiotto Oct 17, 2024
2546665
Fix pre_commit
etiotto Oct 17, 2024
4d5dc49
Reenable rewrite tensor ptr
etiotto Oct 17, 2024
c3fdbba
Fix test_core regression
etiotto Oct 18, 2024
d9de8e7
Fix tutorial assertion
etiotto Oct 18, 2024
949256e
Refactor
etiotto Oct 18, 2024
754ec70
Cleanup
etiotto Oct 18, 2024
469407b
Cleanup
etiotto Oct 18, 2024
9f4f98d
Extend axis info analysis to more block ptrs
etiotto Oct 21, 2024
a40844b
Merge branch 'main' into etiotto/coalesce_for_block_ptr
etiotto Oct 21, 2024
bb9b4c3
Address code review comments
etiotto Oct 22, 2024
8d9a158
Remove unrelated change
etiotto Oct 22, 2024
6529f04
Remove unrelated change
etiotto Oct 22, 2024
0aa334b
Remove unrelated change
etiotto Oct 22, 2024
547d6fa
Fix pre_commit
etiotto Oct 22, 2024
6566f6c
Merge branch 'main' into etiotto/coalesce_for_block_ptr
etiotto Oct 23, 2024
2f97c1a
Address code review comments
etiotto Oct 23, 2024
95f5832
Fix pre_commit
etiotto Oct 23, 2024
0887245
Merge branch 'main' into etiottoremove_layout_conv
etiotto Oct 24, 2024
3636bef
Make isExpensiveLoadOrStore consider blocked pointers load and stores
etiotto Oct 24, 2024
db2193e
Make isExpensiveLoadOrStore consider blocked pointers load and stores
etiotto Oct 25, 2024
eeda8e9
Merge branch 'main' into etiottoremove_layout_conv
etiotto Oct 25, 2024
7c9a0f9
MaterializeBlockPointer fix for GEMM with 1st operand transposed
etiotto Oct 25, 2024
cbc630b
MaterializeBlockPointer fix for GEMM with 1st operand transposed
etiotto Oct 25, 2024
0215a16
Fix unit tests
etiotto Oct 28, 2024
ae3d625
Fix performance regression for gemm-preop-exp
etiotto Oct 28, 2024
22b7ec9
Reduce PR footprint
etiotto Oct 28, 2024
4991020
Remove RewriteTensorPointer from the optimization pipeline
etiotto Oct 28, 2024
9521870
Disable address payload opt experiment
etiotto Oct 30, 2024
a96efb5
Merge branch 'main' into etiotto.remove_rewrite_tensor_ptr
etiotto Oct 31, 2024
00f8432
Fix test_block_pointer.py:test_block_copy
etiotto Oct 31, 2024
a21d58d
Merge branch 'main' into etiotto.remove_rewrite_tensor_ptr
etiotto Nov 1, 2024
17f5b25
Address code review comments
etiotto Nov 1, 2024
0b21a82
Address code review comments
etiotto Nov 1, 2024
2d22907
Add vectorization support for store as well
etiotto Nov 1, 2024
c96c236
Merge branch 'main' into etiotto.remove_rewrite_tensor_ptr
etiotto Nov 4, 2024
3 changes: 2 additions & 1 deletion third_party/intel/backend/compiler.py
@@ -235,7 +235,8 @@ def make_ttgir(mod, metadata, opt, properties):
intel.passes.ttgpuir.add_accelerate_matmul(pm)
intel.passes.ttgpuir.add_remove_layout_conversions(pm)
intel.passes.ttgpuir.add_materialize_block_pointer(pm)
intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
if os.getenv("TRITON_INTEL_REWRITE_TENSOR_POINTER", "0") == "1":
intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)

intel.passes.ttgpuir.add_coalesce(pm)
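With this change, RewriteTensorPointer no longer runs unconditionally: it is gated behind the TRITON_INTEL_REWRITE_TENSOR_POINTER environment variable, so block (tensor) pointers now normally survive into the later passes, which the rest of this PR teaches to handle them. Setting TRITON_INTEL_REWRITE_TENSOR_POINTER=1 before launching a kernel (e.g. `TRITON_INTEL_REWRITE_TENSOR_POINTER=1 python my_kernel.py`, a hypothetical invocation) restores the old rewrite-to-plain-pointers behavior.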
1 change: 0 additions & 1 deletion third_party/intel/include/Analysis/AxisInfo.h
@@ -12,7 +12,6 @@ namespace mlir::triton::intel {
// axis info based on the axis info of all the callers. In the future, we can
// perform optimization using function cloning so that each call site will have
// unique axis info.

class ModuleAxisInfoAnalysis : public triton::ModuleAxisInfoAnalysis {
public:
explicit ModuleAxisInfoAnalysis(ModuleOp moduleOp)
17 changes: 14 additions & 3 deletions third_party/intel/lib/Analysis/AxisInfo.cpp
@@ -1030,13 +1030,24 @@ class MakeTensorPtrOpAxisInfoVisitor final
strideInfo[dim].getConstantValue() == 1 ? blkShape[dim] : 1);
divisibility.push_back(
contiguity[dim] > 1
? std::min(ptrDivisibility,
strideInfo[dim == 0 ? 1 : 0].getDivisibility()[0])
? std::min(
ptrDivisibility,
(rank == 2 ? strideInfo[dim == 0 ? 1 : 0] : strideInfo[dim])
.getDivisibility()[0])
: 1);
constancy.push_back(1);
}

return AxisInfo(contiguity, divisibility, constancy);
auto axisInfo = AxisInfo(contiguity, divisibility, constancy);

LLVM_DEBUG({
std::string axisStr;
llvm::raw_string_ostream os(axisStr);
axisInfo.print(os);
LDBG("-- " << axisStr);
});

return axisInfo;
}
};

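A minimal sketch (plain C++, not the Triton API; the names below are illustrative stand-ins for the analysis lattice) of the per-dimension rule the visitor applies above: a dimension is contiguous only when its stride is the constant 1, and a contiguous dimension's divisibility is bounded by the base pointer's divisibility and, for rank-2 block pointers, by the divisibility of the other dimension's stride (the row pitch).

#include <algorithm>
#include <cstdint>
#include <vector>

struct DimAxisInfo {
  int64_t contiguity;   // run of consecutive addresses along this dim
  int64_t divisibility; // largest power-of-two divisor of the addresses
  int64_t constancy;    // run of identical values (always 1 for pointers)
};

std::vector<DimAxisInfo> makeTensorPtrAxisInfo(
    const std::vector<int64_t> &blkShape,    // block shape of the tensor ptr
    const std::vector<bool> &strideIsOne,    // stride == constant 1, per dim
    const std::vector<int64_t> &strideDivis, // divisibility of each stride
    int64_t ptrDivisibility) {               // divisibility of the base ptr
  const size_t rank = blkShape.size();
  std::vector<DimAxisInfo> info(rank);
  for (size_t dim = 0; dim < rank; ++dim) {
    // A dim is contiguous only when its stride is exactly the constant 1.
    int64_t contig = strideIsOne[dim] ? blkShape[dim] : 1;
    // Rank-2 case: bound a contiguous dim's divisibility by the *other*
    // dim's stride (the row pitch); otherwise use this dim's own stride.
    size_t strideDim = (rank == 2) ? (dim == 0 ? 1 : 0) : dim;
    int64_t divis =
        contig > 1 ? std::min(ptrDivisibility, strideDivis[strideDim]) : 1;
    info[dim] = {contig, divis, /*constancy=*/1};
  }
  return info;
}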
122 changes: 61 additions & 61 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -161,29 +161,33 @@ getWarpsPerCTA(const ArrayRef<int64_t> tensorShape,

// Contains some helper functions for both Load and Store conversions.
struct LoadStoreConversionBase {
explicit LoadStoreConversionBase(const triton::intel::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass)
explicit LoadStoreConversionBase(
const triton::intel::TargetInfo &targetInfo,
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass)
: targetInfo(targetInfo), axisAnalysisPass(axisAnalysisPass) {}

unsigned getContiguity(Value ptr) const {
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
if (!tensorTy)
return 1;
return axisAnalysisPass.getPtrContiguity(ptr);
return const_cast<triton::intel::ModuleAxisInfoAnalysis &>(axisAnalysisPass)
.getPtrContiguity(ptr);
}

unsigned getVectorSize(Value ptr) const {
auto tensorTy = dyn_cast<RankedTensorType>(ptr.getType());
auto tensorTy = getRankedTensorType(ptr.getType());
if (!tensorTy)
return 1;
auto contiguity = getContiguity(ptr);
auto pointeeBitWidth = triton::getPointeeBitWidth(tensorTy);

unsigned contiguity = getContiguity(ptr);
unsigned pointeeBitWidth =
isTensorPointerType(ptr.getType())
? tensorTy.getElementType().getIntOrFloatBitWidth()
: triton::getPointeeBitWidth(tensorTy);
// The maximum vector size is 128 bits.
return std::min<unsigned>(128 / pointeeBitWidth, contiguity);
}

unsigned getMaskAlignment(Value mask) const {
return axisAnalysisPass.getMaskAlignment(mask);
return const_cast<triton::intel::ModuleAxisInfoAnalysis &>(axisAnalysisPass)
.getMaskAlignment(mask);
}

std::tuple<SmallVector<Value>, SmallVector<Value>, SmallVector<Value>>
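The 128-bit cap in getVectorSize above is easy to see with concrete numbers; a small standalone sketch (assumed values, not the Triton API):

#include <algorithm>
#include <cstdio>

// The widest load/store message is 128 bits; a thread can also only
// vectorize across the consecutive elements it actually owns (contiguity).
unsigned vectorSize(unsigned pointeeBitWidth, unsigned contiguity) {
  return std::min(128u / pointeeBitWidth, contiguity);
}

int main() {
  std::printf("%u\n", vectorSize(16, 8)); // fp16, contiguity 8 -> vec 8
  std::printf("%u\n", vectorSize(32, 8)); // fp32: 128/32 = 4 caps vec at 4
  std::printf("%u\n", vectorSize(16, 2)); // contiguity 2 caps vec at 2
}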
@@ -289,7 +293,7 @@ struct LoadStoreConversionBase {
}

protected:
ModuleAxisInfoAnalysis &axisAnalysisPass;
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass;
const triton::intel::TargetInfo &targetInfo;
};

@@ -299,10 +303,11 @@ struct PrefetchOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::gpu::intel::PrefetchOp>::ConvertTritonGPUOpToLLVMPattern;

PrefetchOpConversion(TritonGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
PrefetchOpConversion(
TritonGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::PrefetchOp>(
converter, benefit),
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -475,10 +480,11 @@ struct LoadOpConversion

using ValueTable = std::map<std::pair<int, int>, Value>;

LoadOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
LoadOpConversion(
TritonIntelGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}

@@ -824,37 +830,32 @@ struct LoadOpConversion
Location loc = op->getLoc();
auto typeConverter = getTypeConverter();
MLIRContext *ctx = rewriter.getContext();
Value ptr = op.getPtr();
Value mask = op.getMask();
Value llMask = adaptor.getMask();

// Determine the vectorization size
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(op.getType()));
unsigned numElems = getTotalElemsPerThread(op.getType());
unsigned vec = 1;
unsigned vec = getVectorSize(ptr);
if (llMask)
vec = std::min<size_t>(vec, getMaskAlignment(mask));

SmallVector<Value> ptrElems, maskElems, otherElems;
bool otherIsSplatConstInt = false;
int64_t splatVal = 0;

if (isTensorPointerType(op.getPtr().getType())) {
// TODO: (johnlu) set the vector size > 1; Need to prove the memory is
// contiguous on the fast changing dim when fallback to gather load.
if (isTensorPointerType(ptr.getType())) {
// fallback to gather load.
auto tensorType = cast<RankedTensorType>(op.getType());
std::tie(ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr(
loc, adaptor.getPtr(), tensorType, valueElemTy, rewriter,
op.getBoundaryCheck(), op.getPadding());
} else {
// original values
Value ptr = op.getPtr();
Value other = op.getOther();
Value mask = op.getMask();

// adaptor values
Value llPtr = adaptor.getPtr();
Value llMask = adaptor.getMask();
Value llOther = adaptor.getOther();
vec = getVectorSize(ptr);
if (llMask)
vec = std::min<size_t>(vec, getMaskAlignment(mask));

// Get the LLVM values for pointers
ptrElems = unpackLLElements(loc, llPtr, rewriter);
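The restructuring above (mirrored in the StoreOp conversion below) hoists the vector-width computation out of the non-block-pointer branch: vec = getVectorSize(ptr) now runs for block pointers too, instead of pinning them to vec = 1, and a mask, when present, clamps the width to its alignment. A standalone sketch of that clamp (simplified signature, not the Triton API):

#include <algorithm>
#include <cstddef>

// A masked access can only be vectorized across consecutive lanes that are
// guaranteed to share the same mask value, so the mask's alignment bounds
// the achievable vector width.
size_t clampVecByMask(size_t vec, bool hasMask, size_t maskAlignment) {
  return hasMask ? std::min(vec, maskAlignment) : vec;
}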
@@ -987,10 +988,11 @@ struct StoreOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;

StoreOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
StoreOpConversion(
TritonIntelGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}

@@ -1128,14 +1130,20 @@ struct StoreOpConversion
return success();

Location loc = op->getLoc();
auto *typeConverter = getTypeConverter();
MLIRContext *ctx = rewriter.getContext();
Value ptr = op.getPtr();
Value value = op.getValue();
Type valueTy = value.getType();
Value mask = op.getMask();
Value llMask = adaptor.getMask();

// Determine the vectorization size
Type valueTy = op.getValue().getType();
Type valueElemTy =
typeConverter->convertType(getElementTypeOrSelf(valueTy));
SmallVector<Value> ptrElems, maskElems;
unsigned vec = 1;
unsigned vec = getVectorSize(ptr);
if (llMask)
vec = std::min<size_t>(vec, getMaskAlignment(mask));

if (isTensorPointerType(ptr.getType())) {
// fallback to scatter store.
@@ -1146,20 +1154,9 @@ struct StoreOpConversion
op.getBoundaryCheck());
} else {
Value llPtr = adaptor.getPtr();
Value llMask = adaptor.getMask();

vec = getVectorSize(ptr);

ptrElems = unpackLLElements(loc, llPtr, rewriter);

// Determine the vectorization size
if (llMask) {
Value mask = op.getMask();
if (llMask)
maskElems = unpackLLElements(loc, llMask, rewriter);

unsigned maskAlign = getMaskAlignment(mask);
vec = std::min(vec, maskAlign);
}
}

Value llValue = adaptor.getValue();
Expand All @@ -1168,7 +1165,7 @@ struct StoreOpConversion
assert(!maskElems.size() ||
valueElems.size() == maskElems.size() && "Mask size mismatch");

Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
const size_t dtsize =
std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
const size_t valueElemNBits = dtsize * 8;
@@ -1247,10 +1244,11 @@ struct AtomicCASOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;

AtomicCASOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
AtomicCASOpConversion(
TritonIntelGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(converter,
benefit),
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -1364,10 +1362,11 @@ struct AtomicRMWOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;

AtomicRMWOpConversion(TritonIntelGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
AtomicRMWOpConversion(
TritonIntelGPUToLLVMTypeConverter &converter,
const triton::intel::TargetInfo &targetInfo,
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
PatternBenefit benefit)
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
benefit),
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
@@ -1627,7 +1626,8 @@ struct AtomicRMWOpConversion
void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
TritonIntelGPUToLLVMTypeConverter &typeConverter,
const TargetInfo &targetInfo, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit) {
const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
PatternBenefit benefit) {
patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, LoadOpConversion,
StoreOpConversion, PrefetchOpConversion>(
typeConverter, targetInfo, axisInfoAnalysis, benefit);
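Throughout this file the analysis is now threaded as a const triton::intel::ModuleAxisInfoAnalysis & and queried through a const_cast, because the upstream query methods (getPtrContiguity, getMaskAlignment) are non-const. A minimal sketch of the pattern (simplified types, not the Triton classes):

struct Analysis {
  // Non-const in the real interface, which is what forces the cast below.
  unsigned getPtrContiguity(int /*ptr*/) { return 4; }
};

struct ConversionBase {
  explicit ConversionBase(const Analysis &a) : analysis(a) {}
  unsigned contiguity(int ptr) const {
    // The upstream query method is non-const, so constness is cast away
    // at the call site.
    return const_cast<Analysis &>(analysis).getPtrContiguity(ptr);
  }
  const Analysis &analysis;
};

Making the query methods const upstream would be the cleaner fix; the const_cast keeps this PR's footprint local to the Intel backend.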
@@ -53,7 +53,7 @@ void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
void populateLoadStoreOpToLLVMPatterns(
TritonIntelGPUToLLVMTypeConverter &typeConverter,
const TargetInfo &targetInfo, RewritePatternSet &patterns,
ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit);
const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit);

void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,