|
33 | 33 | #include "mlir/IR/OpImplementation.h" |
34 | 34 | #include "mlir/IR/PatternMatch.h" |
35 | 35 | #include "mlir/IR/TypeUtilities.h" |
| 36 | +#include "mlir/IR/ValueRange.h" |
36 | 37 | #include "mlir/Interfaces/SubsetOpInterface.h" |
37 | 38 | #include "mlir/Interfaces/ValueBoundsOpInterface.h" |
38 | 39 | #include "mlir/Support/LLVM.h" |
@@ -2394,60 +2395,89 @@ static LogicalResult rewriteFromElementsAsSplat(FromElementsOp fromElementsOp, |
2394 | 2395 | /// |
2395 | 2396 | /// becomes |
2396 | 2397 | /// %2 = vector.shape_cast %source : vector<1x2xi8> to vector<2xi8> |
2397 | | -static LogicalResult |
2398 | | -rewriteFromElementsAsShapeCast(FromElementsOp fromElementsOp, |
2399 | | - PatternRewriter &rewriter) { |
| 2398 | +/// |
| 2399 | +/// The requirements for this to be valid are |
| 2400 | +/// i) all elements are extracted from the same vector (source), |
| 2401 | +/// ii) source and from_elements result have the same number of elements, |
| 2402 | +/// iii) the elements are extracted in ascending order. |
| 2403 | +/// |
| 2404 | +/// It might be possible to rewrite vector.from_elements as a single |
| 2405 | +/// vector.extract if (ii) is not satisifed, or in some cases as a |
| 2406 | +/// a single vector_extract_strided_slice if (ii) and (iii) are not satisfied, |
| 2407 | +/// this is left for future consideration. |
| 2408 | +class FromElementsToShapCast : public OpRewritePattern<FromElementsOp> { |
| 2409 | +public: |
| 2410 | + using OpRewritePattern::OpRewritePattern; |
2400 | 2411 |
|
2401 | | - // The common source of vector.extract operations (if one exists), as well |
2402 | | - // as its shape and rank. These are set in the first iteration of the loop |
2403 | | - // over the operands (elements) of `fromElementsOp`. |
2404 | | - Value source; |
2405 | | - ArrayRef<int64_t> shape; |
2406 | | - int64_t rank; |
| 2412 | + LogicalResult matchAndRewrite(FromElementsOp fromElements, |
| 2413 | + PatternRewriter &rewriter) const override { |
2407 | 2414 |
|
2408 | | - for (auto [index, element] : llvm::enumerate(fromElementsOp.getElements())) { |
| 2415 | + mlir::OperandRange elements = fromElements.getElements(); |
| 2416 | + assert(!elements.empty() && "must be at least 1 element"); |
2409 | 2417 |
|
2410 | | - // Check that the element is defined by an extract operation, and that |
2411 | | - // the extract is on the same vector as all preceding elements. |
2412 | | - auto extractOp = |
2413 | | - dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp()); |
2414 | | - if (!extractOp) |
2415 | | - return failure(); |
2416 | | - Value currentSource = extractOp.getVector(); |
2417 | | - if (index == 0) { |
2418 | | - source = currentSource; |
2419 | | - shape = extractOp.getSourceVectorType().getShape(); |
2420 | | - rank = shape.size(); |
2421 | | - } else if (currentSource != source) { |
2422 | | - return failure(); |
| 2418 | + Value firstElement = elements.front(); |
| 2419 | + ExtractOp extractOp = |
| 2420 | + dyn_cast_if_present<vector::ExtractOp>(firstElement.getDefiningOp()); |
| 2421 | + if (!extractOp) { |
| 2422 | + return rewriter.notifyMatchFailure( |
| 2423 | + fromElements, "first element not from vector.extract"); |
2423 | 2424 | } |
| 2425 | + VectorType sourceType = extractOp.getSourceVectorType(); |
| 2426 | + Value source = extractOp.getVector(); |
2424 | 2427 |
|
2425 | | - // Check that the (linearized) index of extraction is the same as the index |
2426 | | - // in the result of `fromElementsOp`. |
2427 | | - ArrayRef<int64_t> position = extractOp.getStaticPosition(); |
2428 | | - assert(position.size() == rank && |
2429 | | - "scalar extract must have full rank position"); |
2430 | | - int64_t stride{1}; |
2431 | | - int64_t offset{0}; |
2432 | | - for (auto [pos, size] : |
2433 | | - llvm::zip(llvm::reverse(position), llvm::reverse(shape))) { |
2434 | | - if (pos == ShapedType::kDynamic) |
2435 | | - return failure(); |
2436 | | - offset += pos * stride; |
2437 | | - stride *= size; |
| 2428 | + // Check condition (ii). |
| 2429 | + if (static_cast<size_t>(sourceType.getNumElements()) != elements.size()) { |
| 2430 | + return rewriter.notifyMatchFailure(fromElements, |
| 2431 | + "number of elements differ"); |
2438 | 2432 | } |
2439 | | - if (offset != index) |
2440 | | - return failure(); |
2441 | | - } |
2442 | 2433 |
|
2443 | | - rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElementsOp, |
2444 | | - fromElementsOp.getType(), source); |
2445 | | -} |
| 2434 | + for (auto [indexMinusOne, element] : |
| 2435 | + llvm::enumerate(elements.drop_front(1))) { |
| 2436 | + |
| 2437 | + extractOp = |
| 2438 | + dyn_cast_if_present<vector::ExtractOp>(element.getDefiningOp()); |
| 2439 | + if (!extractOp) { |
| 2440 | + return rewriter.notifyMatchFailure(fromElements, |
| 2441 | + "element not from vector.extract"); |
| 2442 | + } |
| 2443 | + Value currentSource = extractOp.getVector(); |
| 2444 | + // Check condition (i). |
| 2445 | + if (currentSource != source) { |
| 2446 | + return rewriter.notifyMatchFailure(fromElements, |
| 2447 | + "element from different vector"); |
| 2448 | + } |
| 2449 | + |
| 2450 | + ArrayRef<int64_t> position = extractOp.getStaticPosition(); |
| 2451 | + assert(position.size() == static_cast<size_t>(sourceType.getRank()) && |
| 2452 | + "scalar extract must have full rank position"); |
| 2453 | + int64_t stride{1}; |
| 2454 | + int64_t offset{0}; |
| 2455 | + for (auto [pos, size] : llvm::zip(llvm::reverse(position), |
| 2456 | + llvm::reverse(sourceType.getShape()))) { |
| 2457 | + if (pos == ShapedType::kDynamic) { |
| 2458 | + return rewriter.notifyMatchFailure( |
| 2459 | + fromElements, "elements not in ascending order (dynamic order)"); |
| 2460 | + } |
| 2461 | + offset += pos * stride; |
| 2462 | + stride *= size; |
| 2463 | + } |
| 2464 | + // Check condition (iii). |
| 2465 | + if (offset != static_cast<int64_t>(indexMinusOne + 1)) { |
| 2466 | + return rewriter.notifyMatchFailure( |
| 2467 | + fromElements, "elements not in ascending order (static order)"); |
| 2468 | + } |
| 2469 | + } |
| 2470 | + |
| 2471 | + rewriter.replaceOpWithNewOp<ShapeCastOp>(fromElements, |
| 2472 | + fromElements.getType(), source); |
| 2473 | + return success(); |
| 2474 | + } |
| 2475 | +}; |
2446 | 2476 |
|
2447 | 2477 | void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, |
2448 | 2478 | MLIRContext *context) { |
2449 | 2479 | results.add(rewriteFromElementsAsSplat); |
2450 | | - results.add(rewriteFromElementsAsShapeCast); |
| 2480 | + results.add<FromElementsToShapCast>(context); |
2451 | 2481 | } |
2452 | 2482 |
|
2453 | 2483 | //===----------------------------------------------------------------------===// |
|
0 commit comments