|
45 | 45 | #include "mlir/IR/ValueRange.h" |
46 | 46 | #include "mlir/Pass/Pass.h" |
47 | 47 | #include "mlir/Pass/PassManager.h" |
| 48 | +#include "mlir/Support/LLVM.h" |
48 | 49 | #include "mlir/Support/LogicalResult.h" |
49 | 50 | #include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h" |
50 | 51 | #include "mlir/Target/LLVMIR/Dialect/LLVMIR/LLVMToLLVMIRTranslation.h" |
@@ -118,6 +119,17 @@ mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics, |
118 | 119 | return convertToReferenceType(shape, elementType); |
119 | 120 | } |
120 | 121 |
|
| 122 | +// Compute the identity stride for the default layout of a memref |
| 123 | +static llvm::SmallVector<std::int64_t> identityStrides(mlir::MemRefType t) { |
| 124 | + llvm::SmallVector<std::int64_t> strides(t.getShape().size()); |
| 125 | + if (!strides.empty()) |
| 126 | + strides.back() = 1; |
| 127 | + // To replace by range algorithms with an exclusive scan... |
| 128 | + for (auto i = strides.size(); i > 1; --i) |
| 129 | + strides[i - 2] = t.getShape()[i - 1] * strides[i - 1]; |
| 130 | + return strides; |
| 131 | +} |
| 132 | + |
121 | 133 | class CIRReturnLowering : public mlir::OpConversionPattern<cir::ReturnOp> { |
122 | 134 | public: |
123 | 135 | using OpConversionPattern<cir::ReturnOp>::OpConversionPattern; |
@@ -171,7 +183,7 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> { |
171 | 183 | }; |
172 | 184 |
|
173 | 185 | /// Emits the value from memory as expected by its users. Should be called when |
174 | | -/// the memory represetnation of a CIR type is not equal to its scalar |
| 186 | +/// the memory representation of a CIR type is not equal to its scalar |
175 | 187 | /// representation. |
176 | 188 | static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter, |
177 | 189 | cir::LoadOp op, mlir::Value value) { |
@@ -1176,7 +1188,8 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> { |
1176 | 1188 | case CIR::array_to_ptrdecay: { |
1177 | 1189 | auto newDstType = mlir::cast<mlir::MemRefType>(convertTy(dstType)); |
1178 | 1190 | rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>( |
1179 | | - op, newDstType, src, 0, std::nullopt, std::nullopt); |
| 1191 | + op, newDstType, src, 0, newDstType.getShape(), |
| 1192 | + identityStrides(newDstType)); |
1180 | 1193 | return mlir::success(); |
1181 | 1194 | } |
1182 | 1195 | case CIR::bitcast: { |
@@ -1334,7 +1347,7 @@ class CIRPtrStrideOpLowering |
1334 | 1347 | // memref.reinterpret_cast (%base, %stride) |
1335 | 1348 | // |
1336 | 1349 | // MemRef Dialect doesn't have GEP-like operation. memref.reinterpret_cast |
1337 | | - // only been used to propogate %base and %stride to memref.load/store and |
| 1350 | + // only been used to propagate %base and %stride to memref.load/store and |
1338 | 1351 | // should be erased after the conversion. |
1339 | 1352 | mlir::LogicalResult |
1340 | 1353 | matchAndRewrite(cir::PtrStrideOp op, OpAdaptor adaptor, |
|
0 commit comments