@@ -16,8 +16,8 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
1616 MLIRContext *ctx, LowerToLLVMOptions &options,
1717 const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis)
1818 : LLVMTypeConverter(ctx, options, analysis) {
19- addConversion ([ctx ](triton::PointerType type) -> std::optional<Type> {
20- return LLVM::LLVMPointerType::get (ctx, type. getAddressSpace () );
19+ addConversion ([& ](triton::PointerType type) -> std::optional<Type> {
20+ return convertTritonPointerType ( type);
2121 });
2222 addConversion ([ctx](TensorDescType type) -> std::optional<Type> {
2323 return LLVM::LLVMPointerType::get (ctx, 1 );
@@ -36,6 +36,31 @@ TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
3636 mlir::Float8E5M2Type, mlir::Float8E5M2FNUZType>();
3737}
3838
39+ Type TritonGPUToLLVMTypeConverter::convertTritonPointerType (
40+ triton::PointerType type) {
41+ auto ctx = type.getContext ();
42+ auto pointeeType = type.getPointeeType ();
43+ if (isa<RankedTensorType>(pointeeType)) {
44+ auto rankedTensorType = cast<RankedTensorType>(pointeeType);
45+ // struct { offset0, offset1, shape0, shape1, stride0,
46+ // stride1, base_ptr};
47+ auto eleType = rankedTensorType.getElementType ();
48+ auto shape = rankedTensorType.getShape ();
49+ SmallVector<Type, 4 > types;
50+ // offsets
51+ for (size_t i = 0 ; i < shape.size (); ++i)
52+ types.push_back (IntegerType::get (ctx, 32 ));
53+ // shapes, strides
54+ for (size_t i = 0 ; i < 2 * shape.size (); ++i)
55+ types.push_back (IntegerType::get (ctx, 64 ));
56+
57+ types.push_back (LLVM::LLVMPointerType::get (ctx, type.getAddressSpace ()));
58+
59+ return LLVM::LLVMStructType::getLiteral (ctx, types);
60+ }
61+ return LLVM::LLVMPointerType::get (ctx, type.getAddressSpace ());
62+ }
63+
3964Type TritonGPUToLLVMTypeConverter::convertTritonTensorType (
4065 RankedTensorType type, const TargetInfoBase &targetInfo) {
4166 auto ctx = type.getContext ();
0 commit comments