@@ -67,6 +67,43 @@ getStaticOrReifiedInputDims(OpBuilder &builder, Location loc, Value input,
6767 return success ();
6868}
6969
70+ // / Returns the rank of the Value's type, and 0 if it is not a ShapedType.
71+ static int64_t getRank (Value v) {
72+ auto type = dyn_cast<ShapedType>(v.getType ());
73+ if (type) {
74+ return type.getRank ();
75+ }
76+ return 0 ;
77+ }
78+
79+ // / Method similar to `LinalgOp`s that concatenates shapes of all operands.
80+ static SmallVector<OpFoldResult>
81+ createFlatListOfOperandDims (OpBuilder &b, Location loc, Operation *op) {
82+ SmallVector<OpFoldResult> res;
83+ for (OpOperand &opOperand : op->getOpOperands ()) {
84+ for (auto dim : llvm::seq (getRank (opOperand.get ()))) {
85+ res.push_back (linalg::createFoldedDimOp (b, loc, opOperand.get (), dim));
86+ }
87+ }
88+ return res;
89+ }
90+
91+ // / Permutes the offset and size arrays by the result indexes of the provided
92+ // / affine map.
93+ static SmallVector<Range> getPermutedRange (AffineMap permutation,
94+ ArrayRef<OpFoldResult> offsets,
95+ ArrayRef<OpFoldResult> sizes) {
96+ auto one = IntegerAttr::get (IndexType::get (permutation.getContext ()), 1 );
97+ assert (permutation.isProjectedPermutation () &&
98+ " Affine map should be a projected permutation" );
99+ SmallVector<Range> output;
100+ for (AffineExpr dimExpr : permutation.getResults ()) {
101+ int dim = cast<AffineDimExpr>(dimExpr).getPosition ();
102+ output.push_back (Range{offsets[dim], sizes[dim], one});
103+ }
104+ return output;
105+ }
106+
70107// ===----------------------------------------------------------------------===//
71108// ScatterOp
72109// ===----------------------------------------------------------------------===//
@@ -1891,6 +1928,101 @@ SmallVector<Range> UnPackOp::getIterationDomain(OpBuilder &builder) {
18911928 return LinalgExt::getIterationDomain (*this , builder);
18921929}
18931930
1931+ // ===----------------------------------------------------------------------===//
1932+ // ExpReductionOp
1933+ // ===----------------------------------------------------------------------===//
1934+
1935+ SmallVector<utils::IteratorType> ExpReductionOp::getLoopIteratorTypes () {
1936+ return llvm::to_vector (getIteratorTypes ()
1937+ .getAsValueRange <IREE::LinalgExt::IteratorTypeAttr,
1938+ utils::IteratorType>());
1939+ }
1940+
1941+ SmallVector<Range> ExpReductionOp::getIterationDomain (OpBuilder &b) {
1942+ Location loc = getLoc ();
1943+ OpFoldResult zero = b.getIndexAttr (0 );
1944+ OpFoldResult one = b.getIndexAttr (1 );
1945+
1946+ SmallVector<OpFoldResult> allShapesSizes =
1947+ createFlatListOfOperandDims (b, loc, getOperation ());
1948+ AffineMap map = getShapesToLoopsMap ();
1949+ return llvm::map_to_vector (map.getResults (), [&](AffineExpr loopExpr) {
1950+ OpFoldResult ofr =
1951+ affine::makeComposedFoldedAffineApply (b, loc, loopExpr, allShapesSizes);
1952+ return Range{zero, ofr, one};
1953+ });
1954+ }
1955+
1956+ FailureOr<TilingResult>
1957+ ExpReductionOp::getTiledImplementation (OpBuilder &b,
1958+ ArrayRef<OpFoldResult> offsets,
1959+ ArrayRef<OpFoldResult> sizes) {
1960+ Location loc = getLoc ();
1961+ auto indexingMapOp = cast<IndexingMapOpInterface>(getOperation ());
1962+ SmallVector<Value> tiledOperands;
1963+ SmallVector<Operation *> generatedSlices;
1964+ for (OpOperand &opOperand : getOperation ()->getOpOperands ()) {
1965+ AffineMap map = indexingMapOp.getMatchingIndexingMap (&opOperand);
1966+ SmallVector<Range> slice = getPermutedRange (map, offsets, sizes);
1967+ Operation *sliceOp = getSlice (b, loc, opOperand.get (), slice);
1968+ tiledOperands.emplace_back (sliceOp->getResult (0 ));
1969+ generatedSlices.push_back (sliceOp);
1970+ }
1971+
1972+ SmallVector<Type, 4 > resultTensorTypes;
1973+ if (getNumResults ()) {
1974+ resultTensorTypes = llvm::map_to_vector<4 >(
1975+ getDpsInitsMutable (), [&generatedSlices](OpOperand &opOperand) {
1976+ return generatedSlices[opOperand.getOperandNumber ()]
1977+ ->getResultTypes ()[0 ];
1978+ });
1979+ }
1980+
1981+ Operation *tiledOp = mlir::clone (b, *this , resultTensorTypes, tiledOperands);
1982+ return TilingResult{
1983+ {tiledOp}, SmallVector<Value>(tiledOp->getResults ()), generatedSlices};
1984+ }
1985+
1986+ LogicalResult ExpReductionOp::getResultTilePosition (
1987+ OpBuilder &b, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
1988+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
1989+ SmallVector<OpFoldResult> &resultSizes) {
1990+ auto indexingMapOp = cast<IndexingMapOpInterface>(getOperation ());
1991+ OpOperand *outOperand = getDpsInitOperand (resultNumber);
1992+ AffineMap indexingMap = indexingMapOp.getMatchingIndexingMap (outOperand);
1993+ SmallVector<Range> range = getPermutedRange (indexingMap, offsets, sizes);
1994+ resultOffsets.resize (range.size ());
1995+ resultSizes.resize (range.size ());
1996+ for (auto [index, r] : llvm::enumerate (range)) {
1997+ resultOffsets[index] = r.offset ;
1998+ resultSizes[index] = r.size ;
1999+ }
2000+ return success ();
2001+ }
2002+
2003+ FailureOr<TilingResult>
2004+ ExpReductionOp::generateResultTileValue (OpBuilder &b, unsigned resultNumber,
2005+ ArrayRef<OpFoldResult> offsets,
2006+ ArrayRef<OpFoldResult> sizes) {
2007+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
2008+ if (failed (getIterationDomainTileFromResultTile (
2009+ b, resultNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
2010+ return failure ();
2011+ }
2012+ FailureOr<TilingResult> tilingResult =
2013+ getTiledImplementation (b, mappedOffsets, mappedSizes);
2014+ if (failed (tilingResult)) {
2015+ return failure ();
2016+ }
2017+ if (tilingResult->tiledOps .size () != 1 ) {
2018+ return emitOpError (" failed to generate tiled implementation" );
2019+ }
2020+ return TilingResult{
2021+ tilingResult->tiledOps ,
2022+ SmallVector<Value>{tilingResult->tiledValues [resultNumber]},
2023+ tilingResult->generatedSlices };
2024+ }
2025+
18942026// ===----------------------------------------------------------------------===//
18952027// Im2colOp
18962028// ===----------------------------------------------------------------------===//
@@ -2449,24 +2581,6 @@ getAttentionIteratorTypes(int64_t domainRank, AffineMap qMap, AffineMap kMap,
24492581 return iteratorTypes;
24502582}
24512583
2452- static SmallVector<Range> getPermutedRange (AffineMap permutation,
2453- ArrayRef<OpFoldResult> offsets,
2454- ArrayRef<OpFoldResult> sizes) {
2455- auto one = IntegerAttr::get (IndexType::get (permutation.getContext ()), 1 );
2456- assert (permutation.isProjectedPermutation () &&
2457- " Indexing map should be a projected permutation" );
2458- SmallVector<Range> output;
2459- for (AffineExpr dimExpr : permutation.getResults ()) {
2460- int dim = cast<AffineDimExpr>(dimExpr).getPosition ();
2461- Range dimRange;
2462- dimRange.offset = offsets[dim];
2463- dimRange.size = sizes[dim];
2464- dimRange.stride = one;
2465- output.push_back (dimRange);
2466- }
2467- return output;
2468- }
2469-
24702584static Operation *getPermutedSlice (OpBuilder &b, Location loc, Value val,
24712585 AffineMap permutation,
24722586 ArrayRef<OpFoldResult> offsets,
@@ -3088,19 +3202,6 @@ SmallVector<utils::IteratorType> CustomOp::getLoopIteratorTypes() {
30883202 });
30893203}
30903204
3091- // / Method similar to `LinalgOp`s that concatenates shapes of all operands.
3092- static SmallVector<OpFoldResult>
3093- createFlatListOfOperandDims (OpBuilder &builder, Location loc,
3094- CustomOp customOp) {
3095- SmallVector<OpFoldResult> result;
3096- for (Value operand : customOp->getOperands ()) {
3097- for (auto dim : llvm::seq<unsigned >(customOp.getRank (operand))) {
3098- result.push_back (getDim (builder, loc, operand, dim));
3099- }
3100- }
3101- return result;
3102- }
3103-
31043205SmallVector<Range> CustomOp::getIterationDomainForDimensions (
31053206 OpBuilder &builder, ArrayRef<unsigned > dims, ArrayRef<unsigned > symbols) {
31063207 CustomOp customOp = *this ;
0 commit comments