88#include " mlir/IR/BuiltinTypes.h"
99#include " mlir/IR/PatternMatch.h"
1010#include " mlir/IR/Value.h"
11+ #include " llvm/Support/Debug.h"
12+ #include " llvm/Support/raw_ostream.h"
1113
1214#include " mlir/Transforms/DialectConversion.h"
1315#include " triton-shared/Conversion/TPtrToLLVM/TPtrToLLVM.h"
@@ -24,24 +26,6 @@ static bool isOneToOneCast(UnrealizedConversionCastOp op) {
2426 return (op.getInputs ().size () == 1 && op->getNumResults () == 1 );
2527}
2628
27- Type convertMemRefType (MemRefType type) {
28- auto ctx = type.getContext ();
29- auto rank = type.getShape ().size ();
30- auto i64Ty = IntegerType::get (ctx, 64 );
31- SmallVector<Type, 5 > types;
32-
33- // struct { ptr base_ptr, ptr aligned_ptr, i64 offset,
34- // array<rank x i64> sizes, array<rank x i64> strides }
35- types.push_back (LLVM::LLVMPointerType::get (ctx)); // base pointer
36- types.push_back (LLVM::LLVMPointerType::get (ctx)); // aligned pointer
37- types.push_back (i64Ty); // offset
38- types.push_back (LLVM::LLVMArrayType::get (ctx, i64Ty, rank)); // sizes
39- types.push_back (LLVM::LLVMArrayType::get (ctx, i64Ty, rank)); // strides
40-
41- Type ret = LLVM::LLVMStructType::getLiteral (ctx, types);
42- LDBG (" From MemrefConverter convertMemRefType: " << type << " -> " << ret);
43- return ret;
44- }
4529
4630// PtrAddOp -> llvm.getelementptr conversion
4731struct PtrAddConverter : OpConversionPattern<tptr::PtrAddOp> {
@@ -109,7 +93,7 @@ struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
10993 LogicalResult
11094 matchAndRewrite (tptr::ToMemrefOp op, OpAdaptor adaptor,
11195 ConversionPatternRewriter &rewriter) const override {
112- LDBG (" matchAndRewrite: to_memref " << op);
96+ LDBG (" matchAndRewrite: to_memref (before) " << op);
11397
11498 auto input = adaptor.getArg ();
11599
@@ -125,15 +109,15 @@ struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
125109 }
126110 }
127111
128- Type targetType = convertMemRefType (op.getType ());
112+ Type targetType = getTypeConverter ()->convertType (cast<MemRefType>(op.getType ()));
113+ LDBG (" matchAndRewrite: to_memref (typeconverted) " << cast<MemRefType>(op.getType ()) << " -> " << targetType);
129114 if (!targetType) {
130115 return rewriter.notifyMatchFailure (op, " failed to convert memref type" );
131116 }
132117
133118 auto loc = op.getLoc ();
134119 auto i64Ty = rewriter.getIntegerType (64 );
135- auto memrefType = cast<MemRefType>(op.getType ());
136- auto shape = memrefType.getShape ();
120+ auto shape = cast<MemRefType>(op.getType ()).getShape ();
137121 auto rank = shape.size ();
138122
139123 Value result = rewriter.create <LLVM::UndefOp>(loc, targetType);
@@ -167,7 +151,7 @@ struct ToMemrefConverter : OpConversionPattern<tptr::ToMemrefOp> {
167151 }
168152
169153 rewriter.replaceOp (op, result);
170- LDBG (" matchAndRewrite: to_memref done " );
154+ LDBG (" matchAndRewrite: to_memref (after) -> " << result );
171155 return success ();
172156 }
173157};
@@ -179,16 +163,17 @@ struct FromMemrefConverter : OpConversionPattern<tptr::FromMemrefOp> {
179163 LogicalResult
180164 matchAndRewrite (tptr::FromMemrefOp op, OpAdaptor adaptor,
181165 ConversionPatternRewriter &rewriter) const override {
182- LDBG (" matchAndRewrite: from_memref " << op);
166+ LDBG (" matchAndRewrite: from_memref (before) " << op);
183167
184168 Value input = adaptor.getInput ();
185169 if (isa<MemRefType>(input.getType ())) {
186- input =
187- rewriter
188- .create <UnrealizedConversionCastOp>(
189- op.getLoc (),
190- convertMemRefType (cast<MemRefType>(input.getType ())), input)
191- .getResult (0 );
170+ input = rewriter
171+ .create <UnrealizedConversionCastOp>(
172+ op.getLoc (),
173+ getTypeConverter ()->convertType (
174+ cast<MemRefType>(input.getType ())),
175+ input)
176+ .getResult (0 );
192177 }
193178
194179 // Extract base_ptr (index 0)
@@ -197,7 +182,7 @@ struct FromMemrefConverter : OpConversionPattern<tptr::FromMemrefOp> {
197182 op.getLoc (), resultType, input, rewriter.getDenseI64ArrayAttr ({0 }));
198183
199184 rewriter.replaceOp (op, extractOp);
200- LDBG (" matchAndRewrite: from_memref done " );
185+ LDBG (" matchAndRewrite: from_memref (after) -> " << extractOp );
201186 return success ();
202187 }
203188};
@@ -225,16 +210,14 @@ struct UnrealizedCastConverter
225210 }
226211
227212 if (isa<ptr::PtrType>(outputType) ||
228- (isa<LLVM::LLVMPointerType>(inputType) &&
229- isa<ptr::PtrType>(outputType)) ||
230- (isa<LLVM::LLVMPointerType>(inputType) &&
231- isa<MemRefType>(outputType))) {
213+ (isa<LLVM::LLVMPointerType>(inputType) && isa<MemRefType>(outputType))) {
214+ LDBG (" UnrealizedCast (reject): unsafe pointer conversion " << op);
232215 return rewriter.notifyMatchFailure (op, " unsafe pointer conversion" );
233216 }
234217
235218 if ((isa<LLVM::LLVMStructType>(inputType) && isa<MemRefType>(outputType)) ||
236219 (isa<MemRefType>(inputType) && isa<LLVM::LLVMStructType>(outputType))) {
237- LDBG (" matchAndRewrite: replace with input: " << op << " -> " << input);
220+ LDBG (" matchAndRewrite: UnrealizedCast (after) " << op << " -> " << input);
238221 rewriter.replaceOp (op, input);
239222 return success ();
240223 }
@@ -276,7 +259,7 @@ struct ConvertControlFlowOp : OpConversionPattern<cf::CondBranchOp> {
276259 LogicalResult
277260 matchAndRewrite (cf::CondBranchOp op, OpAdaptor adaptor,
278261 ConversionPatternRewriter &rewriter) const override {
279- LDBG (" matchAndRewrite: cond_branch " << op);
262+ LDBG (" matchAndRewrite: cond_branch (before) " << op);
280263
281264 if (failed (legalizeBlockArguments (*op.getTrueDest (), op, rewriter,
282265 *getTypeConverter ())) ||
@@ -285,12 +268,12 @@ struct ConvertControlFlowOp : OpConversionPattern<cf::CondBranchOp> {
285268 return failure ();
286269 }
287270
288- rewriter.replaceOpWithNewOp <cf::CondBranchOp>(
271+ auto newOp = rewriter.replaceOpWithNewOp <cf::CondBranchOp>(
289272 op, adaptor.getCondition (), op.getTrueDest (),
290273 adaptor.getTrueDestOperands (), op.getFalseDest (),
291274 adaptor.getFalseDestOperands ());
292275
293- LDBG (" matchAndRewrite: cond_branch done " );
276+ LDBG (" matchAndRewrite: cond_branch (after) -> " << newOp );
294277 return success ();
295278 }
296279};
@@ -302,16 +285,17 @@ struct ConvertBranchOp : OpConversionPattern<cf::BranchOp> {
302285 LogicalResult
303286 matchAndRewrite (cf::BranchOp op, OpAdaptor adaptor,
304287 ConversionPatternRewriter &rewriter) const override {
305- LDBG (" matchAndRewrite: cf.br " << op);
288+ LDBG (" matchAndRewrite: cf.br (before) " << op);
306289
307290 if (failed (legalizeBlockArguments (*op.getDest (), op, rewriter,
308291 *getTypeConverter ()))) {
309292 return failure ();
310293 }
311294
312- rewriter.replaceOpWithNewOp <cf::BranchOp>(op, op.getDest (),
313- adaptor.getDestOperands ());
314- LDBG (" matchAndRewrite: cf.br done" );
295+ auto newOp =
296+ rewriter.replaceOpWithNewOp <cf::BranchOp>(op, op.getDest (),
297+ adaptor.getDestOperands ());
298+ LDBG (" matchAndRewrite: cf.br (after) -> " << newOp);
315299 return success ();
316300 }
317301};
@@ -324,7 +308,7 @@ struct MemRefAllocConverter : OpConversionPattern<memref::AllocOp> {
324308 LogicalResult
325309 matchAndRewrite (memref::AllocOp op, OpAdaptor adaptor,
326310 ConversionPatternRewriter &rewriter) const override {
327- LDBG (" matchAndRewrite: memref.alloc " << op);
311+ LDBG (" matchAndRewrite: memref.alloc (before) " << op);
328312
329313 auto oldMemRefType = op.getType ();
330314 auto elementType = oldMemRefType.getElementType ();
@@ -395,7 +379,7 @@ struct MemRefAllocConverter : OpConversionPattern<memref::AllocOp> {
395379 }
396380
397381 rewriter.replaceOp (op, result);
398- LDBG (" matchAndRewrite: memref.alloc done -> LLVM struct " );
382+ LDBG (" matchAndRewrite: memref.alloc (after) -> " << result );
399383 return success ();
400384 }
401385};
@@ -407,7 +391,7 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
407391 LogicalResult
408392 matchAndRewrite (memref::StoreOp op, OpAdaptor adaptor,
409393 ConversionPatternRewriter &rewriter) const override {
410- LDBG (" matchAndRewrite: memref.store " << op);
394+ LDBG (" matchAndRewrite: memref.store (before) " << op);
411395
412396 auto memrefType = op.getMemRef ().getType ();
413397 if (auto memrefTy = dyn_cast<MemRefType>(memrefType)) {
@@ -473,10 +457,10 @@ struct MemRefStoreConverter : OpConversionPattern<memref::StoreOp> {
473457 rewriter.create <LLVM::GEPOp>(loc, ptrTy, ptrTy, basePtr, linearIndex);
474458
475459 // Store the value
476- rewriter.create <LLVM::StoreOp>(loc, adaptor.getValue (), elementPtr);
460+ auto storeOp =
461+ rewriter.create <LLVM::StoreOp>(loc, adaptor.getValue (), elementPtr);
477462 rewriter.eraseOp (op);
478-
479- LDBG (" matchAndRewrite: memref.store done -> LLVM GEP + store" );
463+ LDBG (" matchAndRewrite: memref.store (after) -> " << storeOp);
480464 return success ();
481465 }
482466
@@ -491,7 +475,7 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
491475 LogicalResult
492476 matchAndRewrite (memref::LoadOp op, OpAdaptor adaptor,
493477 ConversionPatternRewriter &rewriter) const override {
494- LDBG (" matchAndRewrite: memref.load " << op);
478+ LDBG (" matchAndRewrite: memref.load (before) " << op);
495479
496480 auto memrefType = op.getMemRef ().getType ();
497481 if (auto memrefTy = dyn_cast<MemRefType>(memrefType)) {
@@ -567,7 +551,7 @@ struct MemRefLoadConverter : OpConversionPattern<memref::LoadOp> {
567551 rewriter.create <LLVM::LoadOp>(loc, newResultType, elementPtr);
568552 rewriter.replaceOp (op, loadedValue);
569553
570- LDBG (" matchAndRewrite: memref.load done -> LLVM GEP + load " );
554+ LDBG (" matchAndRewrite: memref.load (after) -> " << loadedValue );
571555 return success ();
572556 }
573557
@@ -591,14 +575,19 @@ struct TypeOffsetConverter : OpConversionPattern<tptr::TypeOffsetOp> {
591575 LogicalResult
592576 matchAndRewrite (tptr::TypeOffsetOp op, OpAdaptor adaptor,
593577 ConversionPatternRewriter &rewriter) const override {
594- LDBG (" matchAndRewrite: type_offset " << op);
578+ LDBG (" matchAndRewrite: type_offset (before) " << op);
595579
596580 auto size = getTypeSize (op);
581+ if (size.isScalable ()) {
582+ return rewriter.notifyMatchFailure (op, " scalable type size unsupported" );
583+ }
584+ auto fixedSize = static_cast <int64_t >(size.getFixedValue ());
597585 auto constOp = rewriter.create <LLVM::ConstantOp>(
598- op.getLoc (), op.getType (), rewriter.getIntegerAttr (op.getType (), size));
586+ op.getLoc (), op.getType (),
587+ rewriter.getIntegerAttr (op.getType (), fixedSize));
599588
600589 rewriter.replaceOp (op, constOp);
601- LDBG (" matchAndRewrite: type_offset done -> " << size );
590+ LDBG (" matchAndRewrite: type_offset (after) -> " << constOp );
602591 return success ();
603592 }
604593};
@@ -613,4 +602,4 @@ void populateTPtrToLLVMConversionPatterns(RewritePatternSet &patterns,
613602}
614603
615604} // namespace tptr
616- } // namespace mlir
605+ } // namespace mlir
0 commit comments