Skip to content

Commit d5cda23

Browse files
committed
working
1 parent 79d2edf commit d5cda23

File tree

3 files changed

+156
-260
lines changed

3 files changed

+156
-260
lines changed

mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@
1414
#ifndef MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
1515
#define MLIR_DIALECT_MEMREF_TRANSFORMS_TRANSFORMS_H
1616

17+
#include "mlir/IR/OpDefinition.h"
1718
#include "mlir/Support/LLVM.h"
19+
#include "mlir/Support/LogicalResult.h"
20+
#include "llvm/ADT/ArrayRef.h"
21+
#include "llvm/ADT/SmallVector.h"
1822
#include "llvm/ADT/STLFunctionalExtras.h"
1923

2024
namespace mlir {
25+
class Location;
2126
class OpBuilder;
2227
class RewritePatternSet;
2328
class RewriterBase;
@@ -33,7 +38,9 @@ class NarrowTypeEmulationConverter;
3338
namespace memref {
3439
class AllocOp;
3540
class AllocaOp;
41+
class CollapseShapeOp;
3642
class DeallocOp;
43+
class ExpandShapeOp;
3744

3845
//===----------------------------------------------------------------------===//
3946
// Patterns
@@ -213,6 +220,98 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
213220
memref::AllocaOp allocToAlloca(
214221
RewriterBase &rewriter, memref::AllocOp alloc,
215222
function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
223+
224+
/// Compute the expanded sizes of the given \p expandShape for the
225+
/// \p groupId-th reassociation group.
226+
/// \p origSizes hold the sizes of the source shape as values.
227+
/// This is used to compute the new sizes in cases of dynamic shapes.
228+
///
229+
/// sizes#i =
230+
/// baseSizes#groupId / product(expandShapeSizes#j,
231+
/// for j in group excluding reassIdx#i)
232+
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
233+
///
234+
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
235+
///
236+
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
237+
/// this is not possible because this function uses the Affine dialect and the
238+
/// MemRef dialect cannot depend on the Affine dialect.
239+
SmallVector<OpFoldResult>
240+
getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
241+
ArrayRef<OpFoldResult> origSizes, unsigned groupId);
242+
243+
/// Compute the expanded strides of the given \p expandShape for the
244+
/// \p groupId-th reassociation group.
245+
/// \p origStrides and \p origSizes hold respectively the strides and sizes
246+
/// of the source shape as values.
247+
/// This is used to compute the strides in cases of dynamic shapes and/or
248+
/// dynamic stride for this reassociation group.
249+
///
250+
/// strides#i =
251+
/// origStrides#reassDim * product(expandShapeSizes#j, for j in
252+
/// reassIdx#i+1..reassIdx#i+group.size-1)
253+
///
254+
/// Where reassIdx#i is the reassociation index for at index i in \p groupId
255+
/// and expandShapeSizes#j is either:
256+
/// - The constant size at dimension j, derived directly from the result type of
257+
/// the expand_shape op, or
258+
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
259+
/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
260+
/// element.)
261+
///
262+
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
263+
///
264+
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
265+
/// this is not possible because this function uses the Affine dialect and the
266+
/// MemRef dialect cannot depend on the Affine dialect.
267+
SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
268+
OpBuilder &builder,
269+
ArrayRef<OpFoldResult> origSizes,
270+
ArrayRef<OpFoldResult> origStrides,
271+
unsigned groupId);
272+
273+
/// Produce an OpFoldResult object with \p builder at \p loc representing
274+
/// `prod(valueOrConstant#i, for i in {indices})`,
275+
/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
276+
/// values[i] otherwise.
277+
///
278+
/// \pre for all index in indices: index < values.size()
279+
/// \pre for all index in indices: index < maybeConstants.size()
280+
OpFoldResult
281+
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
282+
ArrayRef<int64_t> maybeConstants,
283+
ArrayRef<OpFoldResult> values,
284+
llvm::function_ref<bool(int64_t)> isDynamic);
285+
286+
/// Compute the collapsed size of the given \p collapseShape for the
287+
/// \p groupId-th reassociation group.
288+
/// \p origSizes hold the sizes of the source shape as values.
289+
/// This is used to compute the new sizes in cases of dynamic shapes.
290+
///
291+
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
292+
/// this is not possible because this function uses the Affine dialect and the
293+
/// MemRef dialect cannot depend on the Affine dialect.
294+
SmallVector<OpFoldResult>
295+
getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
296+
ArrayRef<OpFoldResult> origSizes, unsigned groupId);
297+
298+
/// Compute the collapsed stride of the given \p collpaseShape for the
299+
/// \p groupId-th reassociation group.
300+
/// \p origStrides and \p origSizes hold respectively the strides and sizes
301+
/// of the source shape as values.
302+
/// This is used to compute the strides in cases of dynamic shapes and/or
303+
/// dynamic stride for this reassociation group.
304+
///
305+
/// Conceptually this helper function returns the stride of the inner most
306+
/// dimension of that group in the original shape.
307+
///
308+
/// \post result.size() == 1, in other words, each group collapse to one
309+
/// dimension.
310+
SmallVector<OpFoldResult>
311+
getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
312+
ArrayRef<OpFoldResult> origSizes,
313+
ArrayRef<OpFoldResult> origStrides, unsigned groupId);
314+
216315
} // namespace memref
217316
} // namespace mlir
218317

mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp

Lines changed: 18 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2525
#include "llvm/ADT/STLExtras.h"
2626
#include "llvm/ADT/SmallBitVector.h"
27+
#include "mlir/IR/OpDefinition.h"
2728
#include <optional>
2829

2930
namespace mlir {
@@ -35,6 +36,7 @@ namespace memref {
3536

3637
using namespace mlir;
3738
using namespace mlir::affine;
39+
using namespace mlir::memref;
3840

3941
namespace {
4042

@@ -250,23 +252,12 @@ struct ExtractStridedMetadataOpSubviewFolder
250252
}
251253
};
252254

253-
/// Compute the expanded sizes of the given \p expandShape for the
254-
/// \p groupId-th reassociation group.
255-
/// \p origSizes hold the sizes of the source shape as values.
256-
/// This is used to compute the new sizes in cases of dynamic shapes.
257-
///
258-
/// sizes#i =
259-
/// baseSizes#groupId / product(expandShapeSizes#j,
260-
/// for j in group excluding reassIdx#i)
261-
/// Where reassIdx#i is the reassociation index at index i in \p groupId.
262-
///
263-
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
264-
///
265-
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
266-
/// this is not possible because this function uses the Affine dialect and the
267-
/// MemRef dialect cannot depend on the Affine dialect.
268-
static SmallVector<OpFoldResult>
269-
getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
255+
} // namespace
256+
257+
namespace mlir {
258+
namespace memref {
259+
SmallVector<OpFoldResult>
260+
getExpandedSizes(ExpandShapeOp expandShape, OpBuilder &builder,
270261
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
271262
SmallVector<int64_t, 2> reassocGroup =
272263
expandShape.getReassociationIndices()[groupId];
@@ -305,31 +296,7 @@ getExpandedSizes(memref::ExpandShapeOp expandShape, OpBuilder &builder,
305296
return expandedSizes;
306297
}
307298

308-
/// Compute the expanded strides of the given \p expandShape for the
309-
/// \p groupId-th reassociation group.
310-
/// \p origStrides and \p origSizes hold respectively the strides and sizes
311-
/// of the source shape as values.
312-
/// This is used to compute the strides in cases of dynamic shapes and/or
313-
/// dynamic stride for this reassociation group.
314-
///
315-
/// strides#i =
316-
/// origStrides#reassDim * product(expandShapeSizes#j, for j in
317-
/// reassIdx#i+1..reassIdx#i+group.size-1)
318-
///
319-
/// Where reassIdx#i is the reassociation index for at index i in \p groupId
320-
/// and expandShapeSizes#j is either:
321-
/// - The constant size at dimension j, derived directly from the result type of
322-
/// the expand_shape op, or
323-
/// - An affine expression: baseSizes#reassDim / product of all constant sizes
324-
/// in expandShapeSizes. (Remember expandShapeSizes has at most one dynamic
325-
/// element.)
326-
///
327-
/// \post result.size() == expandShape.getReassociationIndices()[groupId].size()
328-
///
329-
/// TODO: Move this utility function directly within ExpandShapeOp. For now,
330-
/// this is not possible because this function uses the Affine dialect and the
331-
/// MemRef dialect cannot depend on the Affine dialect.
332-
SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
299+
SmallVector<OpFoldResult> getExpandedStrides(ExpandShapeOp expandShape,
333300
OpBuilder &builder,
334301
ArrayRef<OpFoldResult> origSizes,
335302
ArrayRef<OpFoldResult> origStrides,
@@ -405,14 +372,7 @@ SmallVector<OpFoldResult> getExpandedStrides(memref::ExpandShapeOp expandShape,
405372
return expandedStrides;
406373
}
407374

408-
/// Produce an OpFoldResult object with \p builder at \p loc representing
409-
/// `prod(valueOrConstant#i, for i in {indices})`,
410-
/// where valueOrConstant#i is maybeConstant[i] when \p isDymamic is false,
411-
/// values[i] otherwise.
412-
///
413-
/// \pre for all index in indices: index < values.size()
414-
/// \pre for all index in indices: index < maybeConstants.size()
415-
static OpFoldResult
375+
OpFoldResult
416376
getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
417377
ArrayRef<int64_t> maybeConstants,
418378
ArrayRef<OpFoldResult> values,
@@ -450,8 +410,8 @@ getProductOfValues(ArrayRef<int64_t> indices, OpBuilder &builder, Location loc,
450410
/// TODO: Move this utility function directly within CollapseShapeOp. For now,
451411
/// this is not possible because this function uses the Affine dialect and the
452412
/// MemRef dialect cannot depend on the Affine dialect.
453-
static SmallVector<OpFoldResult>
454-
getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
413+
SmallVector<OpFoldResult>
414+
getCollapsedSize(CollapseShapeOp collapseShape, OpBuilder &builder,
455415
ArrayRef<OpFoldResult> origSizes, unsigned groupId) {
456416
SmallVector<OpFoldResult> collapsedSize;
457417

@@ -491,8 +451,8 @@ getCollapsedSize(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
491451
///
492452
/// \post result.size() == 1, in other words, each group collapse to one
493453
/// dimension.
494-
static SmallVector<OpFoldResult>
495-
getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
454+
SmallVector<OpFoldResult>
455+
getCollapsedStride(CollapseShapeOp collapseShape, OpBuilder &builder,
496456
ArrayRef<OpFoldResult> origSizes,
497457
ArrayRef<OpFoldResult> origStrides, unsigned groupId) {
498458
SmallVector<int64_t, 2> reassocGroup =
@@ -546,6 +506,10 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
546506

547507
return {lastValidStride};
548508
}
509+
} // namespace memref
510+
} // namespace mlir
511+
512+
namespace {
549513

550514
/// From `reshape_like(memref, subSizes, subStrides))` compute
551515
///

0 commit comments

Comments
 (0)