Skip to content

Commit b1fd14a

Browse files
Move convertTritonPointerType to third_party/intel (#3349)
The conversion of pointers to tensors is removed upstream in commit 5c2e2bc, but it is required by Intel for handling blocked pointer, so this PR moves the conversion to `third_party/intel`. Closes #3228. Signed-off-by: Whitney Tsang <[email protected]>
1 parent fa0f511 commit b1fd14a

File tree

3 files changed

+29
-28
lines changed

3 files changed

+29
-28
lines changed

include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
1818
const TargetInfoBase &targetInfo,
1919
const DataLayoutAnalysis *analysis = nullptr);
2020

21-
Type convertTritonPointerType(triton::PointerType type);
2221
Type convertTritonTensorType(RankedTensorType type,
2322
const TargetInfoBase &targetInfo);
2423
Type convertMemDescType(triton::gpu::MemDescType type,

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
1515
MLIRContext *ctx, LowerToLLVMOptions &options,
1616
const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis)
1717
: LLVMTypeConverter(ctx, options, analysis) {
18-
addConversion([&](triton::PointerType type) -> std::optional<Type> {
19-
return convertTritonPointerType(type);
18+
addConversion([ctx](triton::PointerType type) -> std::optional<Type> {
19+
return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace());
2020
});
2121
addConversion([ctx](TensorDescType type) -> std::optional<Type> {
2222
return LLVM::LLVMPointerType::get(ctx, 1);
@@ -35,31 +35,6 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
3535
mlir::Float8E5M2Type, mlir::Float8E5M2FNUZType>();
3636
}
3737

38-
Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
39-
triton::PointerType type) {
40-
auto ctx = type.getContext();
41-
auto pointeeType = type.getPointeeType();
42-
if (isa<RankedTensorType>(pointeeType)) {
43-
auto rankedTensorType = cast<RankedTensorType>(pointeeType);
44-
// struct { offset0, offset1, shape0, shape1, stride0,
45-
// stride1, base_ptr};
46-
auto eleType = rankedTensorType.getElementType();
47-
auto shape = rankedTensorType.getShape();
48-
SmallVector<Type, 4> types;
49-
// offsets
50-
for (size_t i = 0; i < shape.size(); ++i)
51-
types.push_back(IntegerType::get(ctx, 32));
52-
// shapes, strides
53-
for (size_t i = 0; i < 2 * shape.size(); ++i)
54-
types.push_back(IntegerType::get(ctx, 64));
55-
56-
types.push_back(LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()));
57-
58-
return LLVM::LLVMStructType::getLiteral(ctx, types);
59-
}
60-
return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace());
61-
}
62-
6338
Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
6439
RankedTensorType type, const TargetInfoBase &targetInfo) {
6540
auto ctx = type.getContext();

third_party/intel/lib/TritonIntelGPUToLLVM/TypeConverter.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,38 @@
99
#include "intel/include/TritonIntelGPUToLLVM/TypeConverter.h"
1010
#include "triton/Tools/Sys/GetEnv.hpp"
1111

12+
static Type convertTritonPointerType(triton::PointerType type) {
13+
auto ctx = type.getContext();
14+
auto pointeeType = type.getPointeeType();
15+
if (isa<RankedTensorType>(pointeeType)) {
16+
auto rankedTensorType = cast<RankedTensorType>(pointeeType);
17+
// struct { offset0, offset1, shape0, shape1, stride0,
18+
// stride1, base_ptr};
19+
auto eleType = rankedTensorType.getElementType();
20+
auto shape = rankedTensorType.getShape();
21+
SmallVector<Type, 4> types;
22+
// offsets
23+
for (size_t i = 0; i < shape.size(); ++i)
24+
types.push_back(IntegerType::get(ctx, 32));
25+
// shapes, strides
26+
for (size_t i = 0; i < 2 * shape.size(); ++i)
27+
types.push_back(IntegerType::get(ctx, 64));
28+
29+
types.push_back(LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()));
30+
31+
return LLVM::LLVMStructType::getLiteral(ctx, types);
32+
}
33+
return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace());
34+
}
35+
1236
TritonIntelGPUToLLVMTypeConverter::TritonIntelGPUToLLVMTypeConverter(
1337
MLIRContext *ctx, LowerToLLVMOptions &option,
1438
const TargetInfoBase &targetInfo, bool isAdvancedPathEnabled,
1539
const DataLayoutAnalysis *analysis)
1640
: TritonGPUToLLVMTypeConverter(ctx, option, targetInfo, analysis) {
41+
addConversion([&](triton::PointerType type) -> std::optional<Type> {
42+
return convertTritonPointerType(type);
43+
});
1744
// Augment/overwrite type conversions required for the Intel conversion
1845
// passes.
1946
if (isAdvancedPathEnabled) {

0 commit comments

Comments
 (0)