18 | 18 | #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" |
19 | 19 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
20 | 20 | #include "mlir/Dialect/MemRef/IR/MemRef.h" |
| 21 | +#include "mlir/Dialect/Utils/ReshapeOpsUtils.h" |
21 | 22 | #include "mlir/IR/BuiltinDialect.h" |
22 | 23 | #include "mlir/IR/BuiltinTypes.h" |
23 | 24 | #include "mlir/IR/ImplicitLocOpBuilder.h" |
@@ -397,26 +398,26 @@ static void populateFlattenMemRefsLegality(ConversionTarget &target) { |
397 | 398 | } |
398 | 399 |
399 | 400 | // Materializes a multidimensional memory to unidimensional memory by using a |
400 | | -// memref.subview operation. |
| 401 | +// memref.collapse_shape operation. |
401 | 402 | // TODO: This is also possible for dynamically shaped memories. |
402 | | -static Value materializeSubViewFlattening(OpBuilder &builder, MemRefType type, |
403 | | - ValueRange inputs, Location loc) { |
| 403 | +static Value materializeCollapseShapeFlattening(OpBuilder &builder, |
| 404 | + MemRefType type, |
| 405 | + ValueRange inputs, |
| 406 | + Location loc) { |
404 | 407 | assert(type.hasStaticShape() && |
405 | 408 |          "Can only flatten memrefs with static shape (for now...).");
406 | 409 | MemRefType sourceType = cast<MemRefType>(inputs[0].getType()); |
407 | 410 | int64_t memSize = sourceType.getNumElements(); |
408 | | - unsigned dims = sourceType.getShape().size(); |
| 411 | + ArrayRef<int64_t> sourceShape = sourceType.getShape(); |
| 412 | + ArrayRef<int64_t> targetShape = ArrayRef<int64_t>(memSize); |
409 | 413 |
410 | | - // Build offset, sizes and strides |
411 | | - SmallVector<OpFoldResult> sizes(dims, builder.getIndexAttr(0)); |
412 | | - SmallVector<OpFoldResult> offsets(dims, builder.getIndexAttr(1)); |
413 | | - offsets[offsets.size() - 1] = builder.getIndexAttr(memSize); |
414 | | - SmallVector<OpFoldResult> strides(dims, builder.getIndexAttr(1)); |
| 414 | + // Build ReassociationIndices to collapse completely to 1D MemRef. |
| 415 | + auto indices = getReassociationIndicesForCollapse(sourceShape, targetShape); |
| 416 | + assert(indices.has_value() && "expected a valid collapse"); |
415 | 417 |
416 | 418 |   // Generate the collapse op; the appropriate 1-D return type is inferred:
417 | | - MemRefType outType = MemRefType::get({memSize}, type.getElementType()); |
418 | | - return builder.create<memref::SubViewOp>(loc, outType, inputs[0], sizes, |
419 | | - offsets, strides); |
| 419 | + return builder.create<memref::CollapseShapeOp>(loc, inputs[0], |
| 420 | + indices.value()); |
420 | 421 | } |
421 | 422 |
422 | 423 | static void populateTypeConversionPatterns(TypeConverter &typeConverter) { |
@@ -489,7 +490,7 @@ struct FlattenMemRefCallsPass |
489 | 490 |
490 | 491 | // Add a target materializer to handle memory flattening through |
491 | 492 |   // memref.collapse_shape operations.
492 | | - typeConverter.addTargetMaterialization(materializeSubViewFlattening); |
| 493 | + typeConverter.addTargetMaterialization(materializeCollapseShapeFlattening); |
493 | 494 |
494 | 495 | if (applyPartialConversion(getOperation(), target, std::move(patterns)) |
495 | 496 | .failed()) { |
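For context on the utility this patch now leans on, here is a minimal standalone sketch (not part of the patch; the 2x3x4 shape and the `main` driver are assumed for illustration) of what `getReassociationIndicesForCollapse` returns for a fully static collapse to 1-D, plus the kind of IR the new target materialization would then build:

```cpp
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"
#include <cstdint>

// Standalone sketch: reassociation indices for collapsing an assumed,
// statically shaped 2x3x4 memref into its 24-element 1-D form.
int main() {
  llvm::SmallVector<int64_t> sourceShape = {2, 3, 4};
  llvm::SmallVector<int64_t> targetShape = {2 * 3 * 4};

  // For a full collapse to 1-D this yields a single group {0, 1, 2}:
  // every source dimension folds into result dimension 0.
  auto indices =
      mlir::getReassociationIndicesForCollapse(sourceShape, targetShape);
  if (!indices)
    return 1;

  for (const auto &group : *indices) {
    for (int64_t dim : group)
      llvm::outs() << dim << ' ';
    llvm::outs() << '\n';
  }

  // With that reassociation, the materializer would build IR along the
  // lines of (illustrative, not taken from a test):
  //   %flat = memref.collapse_shape %mem [[0, 1, 2]]
  //       : memref<2x3x4xi32> into memref<24xi32>
  return 0;
}
```

Besides being shorter, the collapse-based materializer avoids hand-building offset/size/stride vectors; in the removed subview-based code those vectors appear to have had their size and offset initializers swapped.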