@@ -62,49 +62,8 @@ Value ConvertToLLVMPattern::getStridedElementPtr(
6262 ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
6363 Value memRefDesc, ValueRange indices,
6464 LLVM::GEPNoWrapFlags noWrapFlags) const {
65-
66- auto [strides, offset] = type.getStridesAndOffset ();
67-
68- MemRefDescriptor memRefDescriptor (memRefDesc);
69- // Use a canonical representation of the start address so that later
70- // optimizations have a longer sequence of instructions to CSE.
71- // If we don't do that we would sprinkle the memref.offset in various
72- // position of the different address computations.
73- Value base =
74- memRefDescriptor.bufferPtr (rewriter, loc, *getTypeConverter (), type);
75-
76- LLVM::IntegerOverflowFlags intOverflowFlags =
77- LLVM::IntegerOverflowFlags::none;
78- if (LLVM::bitEnumContainsAny (noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
79- intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
80- }
81- if (LLVM::bitEnumContainsAny (noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
82- intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
83- }
84-
85- Type indexType = getIndexType ();
86- Value index;
87- for (int i = 0 , e = indices.size (); i < e; ++i) {
88- Value increment = indices[i];
89- if (strides[i] != 1 ) { // Skip if stride is 1.
90- Value stride =
91- ShapedType::isDynamic (strides[i])
92- ? memRefDescriptor.stride (rewriter, loc, i)
93- : createIndexAttrConstant (rewriter, loc, indexType, strides[i]);
94- increment = rewriter.create <LLVM::MulOp>(loc, increment, stride,
95- intOverflowFlags);
96- }
97- index = index ? rewriter.create <LLVM::AddOp>(loc, index, increment,
98- intOverflowFlags)
99- : increment;
100- }
101-
102- Type elementPtrType = memRefDescriptor.getElementPtrType ();
103- return index ? rewriter.create <LLVM::GEPOp>(
104- loc, elementPtrType,
105- getTypeConverter ()->convertType (type.getElementType ()),
106- base, index, noWrapFlags)
107- : base;
65+ return LLVM::getStridedElementPtr (rewriter, loc, *getTypeConverter (), type,
66+ memRefDesc, indices, noWrapFlags);
10867}
10968
11069// Check if the MemRefType `type` is supported by the lowering. We currently
@@ -524,3 +483,52 @@ Value mlir::LLVM::composeValue(OpBuilder &builder, Location loc, ValueRange src,
524483
525484 return res;
526485}
486+
487+ Value mlir::LLVM::getStridedElementPtr (OpBuilder &builder, Location loc,
488+ const LLVMTypeConverter &converter,
489+ MemRefType type, Value memRefDesc,
490+ ValueRange indices,
491+ LLVM::GEPNoWrapFlags noWrapFlags) {
492+ auto [strides, offset] = type.getStridesAndOffset ();
493+
494+ MemRefDescriptor memRefDescriptor (memRefDesc);
495+ // Use a canonical representation of the start address so that later
496+ // optimizations have a longer sequence of instructions to CSE.
497+ // If we don't do that we would sprinkle the memref.offset in various
498+ // position of the different address computations.
499+ Value base = memRefDescriptor.bufferPtr (builder, loc, converter, type);
500+
501+ LLVM::IntegerOverflowFlags intOverflowFlags =
502+ LLVM::IntegerOverflowFlags::none;
503+ if (LLVM::bitEnumContainsAny (noWrapFlags, LLVM::GEPNoWrapFlags::nusw)) {
504+ intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nsw;
505+ }
506+ if (LLVM::bitEnumContainsAny (noWrapFlags, LLVM::GEPNoWrapFlags::nuw)) {
507+ intOverflowFlags = intOverflowFlags | LLVM::IntegerOverflowFlags::nuw;
508+ }
509+
510+ Type indexType = converter.getIndexType ();
511+ Value index;
512+ for (int i = 0 , e = indices.size (); i < e; ++i) {
513+ Value increment = indices[i];
514+ if (strides[i] != 1 ) { // Skip if stride is 1.
515+ Value stride =
516+ ShapedType::isDynamic (strides[i])
517+ ? memRefDescriptor.stride (builder, loc, i)
518+ : builder.create <LLVM::ConstantOp>(
519+ loc, indexType, builder.getIndexAttr (strides[i]));
520+ increment =
521+ builder.create <LLVM::MulOp>(loc, increment, stride, intOverflowFlags);
522+ }
523+ index = index ? builder.create <LLVM::AddOp>(loc, index, increment,
524+ intOverflowFlags)
525+ : increment;
526+ }
527+
528+ Type elementPtrType = memRefDescriptor.getElementPtrType ();
529+ return index ? builder.create <LLVM::GEPOp>(
530+ loc, elementPtrType,
531+ converter.convertType (type.getElementType ()), base, index,
532+ noWrapFlags)
533+ : base;
534+ }
0 commit comments