Skip to content

Commit bea6f7c

Browse files
Revert partial [BACKEND][NFC] Clean up TritonGPU to LLVM type conversion (#5647)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 37f35e9 commit bea6f7c

File tree

2 files changed

+28
-2
lines changed

2 files changed

+28
-2
lines changed

include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
1818
const TargetInfoBase &targetInfo,
1919
const DataLayoutAnalysis *analysis = nullptr);
2020

21+
Type convertTritonPointerType(triton::PointerType type);
2122
Type convertTritonTensorType(RankedTensorType type,
2223
const TargetInfoBase &targetInfo);
2324
Type convertMemDescType(triton::gpu::MemDescType type,

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
1616
MLIRContext *ctx, LowerToLLVMOptions &options,
1717
const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis)
1818
: LLVMTypeConverter(ctx, options, analysis) {
19-
addConversion([ctx](triton::PointerType type) -> std::optional<Type> {
20-
return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace());
19+
addConversion([&](triton::PointerType type) -> std::optional<Type> {
20+
return convertTritonPointerType(type);
2121
});
2222
addConversion([ctx](TensorDescType type) -> std::optional<Type> {
2323
return LLVM::LLVMPointerType::get(ctx, 1);
@@ -36,6 +36,31 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
3636
mlir::Float8E5M2Type, mlir::Float8E5M2FNUZType>();
3737
}
3838

39+
Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
40+
triton::PointerType type) {
41+
auto ctx = type.getContext();
42+
auto pointeeType = type.getPointeeType();
43+
if (isa<RankedTensorType>(pointeeType)) {
44+
auto rankedTensorType = cast<RankedTensorType>(pointeeType);
45+
// struct { offset0, offset1, shape0, shape1, stride0,
46+
// stride1, base_ptr};
47+
auto eleType = rankedTensorType.getElementType();
48+
auto shape = rankedTensorType.getShape();
49+
SmallVector<Type, 4> types;
50+
// offsets
51+
for (size_t i = 0; i < shape.size(); ++i)
52+
types.push_back(IntegerType::get(ctx, 32));
53+
// shapes, strides
54+
for (size_t i = 0; i < 2 * shape.size(); ++i)
55+
types.push_back(IntegerType::get(ctx, 64));
56+
57+
types.push_back(LLVM::LLVMPointerType::get(ctx, type.getAddressSpace()));
58+
59+
return LLVM::LLVMStructType::getLiteral(ctx, types);
60+
}
61+
return LLVM::LLVMPointerType::get(ctx, type.getAddressSpace());
62+
}
63+
3964
Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
4065
RankedTensorType type, const TargetInfoBase &targetInfo) {
4166
auto ctx = type.getContext();

0 commit comments

Comments
 (0)