|
14 | 14 | #include "mlir/Dialect/Affine/IR/AffineOps.h"
|
15 | 15 | #include "mlir/Dialect/Arith/IR/Arith.h"
|
16 | 16 | #include "mlir/IR/Value.h"
|
| 17 | +#include "llvm/ADT/DenseMap.h" |
17 | 18 |
|
18 | 19 | #include <numeric>
|
19 | 20 |
|
@@ -57,26 +58,29 @@ WarpDistributionPattern::moveRegionToNewWarpOpAndAppendReturns(
|
57 | 58 | warpOp.getResultTypes().end());
|
58 | 59 | auto yield = cast<gpu::YieldOp>(
|
59 | 60 | warpOp.getBodyRegion().getBlocks().begin()->getTerminator());
|
60 |
| - llvm::SmallSetVector<Value, 32> yieldValues(yield.getOperands().begin(), |
61 |
| - yield.getOperands().end()); |
| 61 | + SmallVector<Value> yieldValues(yield.getOperands().begin(), |
| 62 | + yield.getOperands().end()); |
| 63 | + llvm::SmallDenseMap<Value, unsigned> indexLookup; |
| 64 | + // Record the value -> first index mapping for faster lookup. |
| 65 | + for (auto [i, v] : llvm::enumerate(yieldValues)) { |
| 66 | + if (!indexLookup.count(v)) |
| 67 | + indexLookup[v] = i; |
| 68 | + } |
| 69 | + |
62 | 70 | for (auto [value, type] : llvm::zip_equal(newYieldedValues, newReturnTypes)) {
|
63 |
| - if (yieldValues.insert(value)) { |
| 71 | + // If the value already exists in the yield, don't create a new output. |
| 72 | + if (indexLookup.count(value)) { |
| 73 | + indices.push_back(indexLookup[value]); |
| 74 | + } else { |
| 75 | + // If the value is new, add it to the yield and to the types. |
| 76 | + yieldValues.push_back(value); |
64 | 77 | types.push_back(type);
|
65 | 78 | indices.push_back(yieldValues.size() - 1);
|
66 |
| - } else { |
67 |
| - // If the value already exit the region don't create a new output. |
68 |
| - for (auto [idx, yieldOperand] : |
69 |
| - llvm::enumerate(yieldValues.getArrayRef())) { |
70 |
| - if (yieldOperand == value) { |
71 |
| - indices.push_back(idx); |
72 |
| - break; |
73 |
| - } |
74 |
| - } |
75 | 79 | }
|
76 | 80 | }
|
77 |
| - yieldValues.insert_range(newYieldedValues); |
| 81 | + |
78 | 82 | WarpExecuteOnLane0Op newWarpOp = moveRegionToNewWarpOpAndReplaceReturns(
|
79 |
| - rewriter, warpOp, yieldValues.getArrayRef(), types); |
| 83 | + rewriter, warpOp, yieldValues, types); |
80 | 84 | rewriter.replaceOp(warpOp,
|
81 | 85 | newWarpOp.getResults().take_front(warpOp.getNumResults()));
|
82 | 86 | return newWarpOp;
|
|
0 commit comments