|
20 | 20 | #include "mlir/IR/AffineMap.h" |
21 | 21 | #include "mlir/IR/BuiltinAttributes.h" |
22 | 22 | #include "mlir/IR/Location.h" |
| 23 | +#include "mlir/IR/PatternMatch.h" |
23 | 24 | #include "mlir/IR/TypeRange.h" |
| 25 | +#include "mlir/Interfaces/DestinationStyleOpInterface.h" |
24 | 26 | #include "mlir/Support/LLVM.h" |
25 | 27 |
|
26 | 28 | // Pull in all enum type definitions and utility function declarations. |
@@ -158,6 +160,76 @@ Operation *cloneWithoutRegions(OpBuilder &b, Operation *op, |
158 | 160 | SmallVector<NamedAttribute> |
159 | 161 | getPrunedAttributeList(Operation *op, ArrayRef<StringRef> elidedAttrs); |
160 | 162 |
|
| 163 | +/// Folds cast-like operations into a consuming DestinationStyleOpInterface op |
| 164 | +/// if `isPreservingCast` is true. If the cast appears on a 'DPS-init operand', |
| 165 | +/// then the tied result type is updated as well to the type of the cast source, |
| 166 | +/// and a new cast must be inserted on the new op's result. `createCast` is used |
| 167 | +/// to build such required cast ops. |
| 168 | +/// |
| 169 | +/// ### Example |
| 170 | +/// If the `isPreservingCast` returns true if the cast is a "generalizing" |
| 171 | +/// `tensor.cast`, then this function would be have as follows: |
| 172 | +/// |
| 173 | +/// ```mlir |
| 174 | +/// %1 = tensor.cast %0 : tensor<8x16xf32> to tensor<?x?xf32> |
| 175 | +/// %2 = dps_op %1 ... : tensor<?x?xf32> ... |
| 176 | +/// ``` |
| 177 | +/// |
| 178 | +/// folds into: |
| 179 | +/// |
| 180 | +/// ```mlir |
| 181 | +/// %2 = dps_op %0 ... : tensor<8x16xf32> ... |
| 182 | +/// ``` |
| 183 | +LogicalResult foldCastProducers( |
| 184 | + RewriterBase &rewriter, DestinationStyleOpInterface consumerOp, |
| 185 | + llvm::function_ref<bool(Operation *)> isPreservingCast, |
| 186 | + llvm::function_ref<Value(RewriterBase &rewriter, Type originalType, |
| 187 | + Value replacement)> |
| 188 | + createCast); |
| 189 | + |
| 190 | +/// Folds `tensor.cast` ops into a consuming DestinationStyleOpInterface op |
| 191 | +/// if the casts make their operands less static. See also isPreservingCast |
| 192 | +/// above. |
| 193 | +template <typename CastOpType> |
| 194 | +LogicalResult foldCastProducers(DestinationStyleOpInterface op, |
| 195 | + RewriterBase &rewriter) { |
| 196 | + return foldCastProducers( |
| 197 | + rewriter, op, |
| 198 | + [](Operation *castOp) -> bool { |
| 199 | + auto concreteCast = dyn_cast<CastOpType>(castOp); |
| 200 | + if (!concreteCast) |
| 201 | + return false; |
| 202 | + RankedTensorType resultType = |
| 203 | + dyn_cast<RankedTensorType>(concreteCast.getType()); |
| 204 | + RankedTensorType sourceType = |
| 205 | + dyn_cast<RankedTensorType>(concreteCast->getOperand(0).getType()); |
| 206 | + if (!resultType || !sourceType) |
| 207 | + return false; |
| 208 | + return resultType.isGeneralizationOf(sourceType); |
| 209 | + }, |
| 210 | + [](RewriterBase &rewriter, Type resultType, Value operand) -> Value { |
| 211 | + return rewriter.create<CastOpType>(operand.getLoc(), resultType, |
| 212 | + operand); |
| 213 | + }); |
| 214 | +} |
| 215 | + |
| 216 | +/// A generic pattern for an Operation type that implements |
| 217 | +/// DestinationStyleOpInterface, allowing for absorbing cast-like operations |
| 218 | +/// that are producers of operands. |
| 219 | +template <typename OpType, typename CastOpType> |
| 220 | +struct FoldTensorCastIntoConsumerPattern : public OpRewritePattern<OpType> { |
| 221 | + using OpRewritePattern<OpType>::OpRewritePattern; |
| 222 | + |
| 223 | + LogicalResult matchAndRewrite(OpType op, |
| 224 | + PatternRewriter &rewriter) const override { |
| 225 | + DestinationStyleOpInterface dpsOp = |
| 226 | + llvm::dyn_cast<DestinationStyleOpInterface>(op.getOperation()); |
| 227 | + if (!dpsOp) |
| 228 | + return failure(); |
| 229 | + return foldCastProducers<CastOpType>(dpsOp, rewriter); |
| 230 | + } |
| 231 | +}; |
| 232 | + |
161 | 233 | } // namespace mlir |
162 | 234 |
|
163 | 235 | #endif // MLIR_DIALECT_UTILS_STRUCTUREDOPSUTILS_H |
0 commit comments