Skip to content

Commit b6eb2bc

Browse files
Groverksspstarkcdpr
authored andcommitted
[Codegen] Cleanup VectorLayoutAnalysis details (iree-org#22418)
This patch moves the implementation details of VectorLayoutAnalysis to it's implementation file. The class needed to be exposed in an initial implementation, but exposing these details isn't required anymore, and a simple function call is enough.
1 parent d57b9e5 commit b6eb2bc

File tree

9 files changed

+119
-215
lines changed

9 files changed

+119
-215
lines changed

compiler/src/iree/compiler/Codegen/Common/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,6 @@ iree_compiler_cc_library(
184184
"TileInferenceUtils.h",
185185
"Transforms.h",
186186
"UserConfig.h",
187-
"VectorLayoutAnalysis.h",
188187
],
189188
deps = [
190189
":PassHeaders",

compiler/src/iree/compiler/Codegen/Common/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ iree_cc_library(
6565
"TileInferenceUtils.h"
6666
"Transforms.h"
6767
"UserConfig.h"
68-
"VectorLayoutAnalysis.h"
6968
SRCS
7069
"AddFastMathFlags.cpp"
7170
"BlockDynamicDimensions.cpp"

compiler/src/iree/compiler/Codegen/Common/GPU/GPUNestedLayoutDistributionPatterns.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#include <cstdint>
88
#include "iree/compiler/Codegen/Common/GPU/GPUPatterns.h"
99
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
10-
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
1110
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
1211
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
1312
#include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"

compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.cpp

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

77
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
8-
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
8+
#include "iree/compiler/Codegen/Common/Transforms.h"
99
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
1010
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
1111
#include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -17,6 +17,8 @@
1717
#include "mlir/Support/LogicalResult.h"
1818
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1919

20+
#include <deque>
21+
2022
#define DEBUG_TYPE "iree-codegen-gpu-vector-distribution"
2123

2224
using namespace mlir::iree_compiler::IREE::VectorExt;
@@ -34,15 +36,16 @@ constexpr StringLiteral kVectorLayoutRedistributeAttrName =
3436
/// Set signature for the operation based on the analysis. Returns failure if
3537
/// an operation contains vectors that cannot be distributed i.e. they have no
3638
/// layout.
37-
LogicalResult setOpSignature(Operation *op, VectorLayoutAnalysis &analysis,
38-
const VectorLayoutOptions &options) {
39+
LogicalResult
40+
setOpSignature(Operation *op,
41+
const llvm::MapVector<Value, VectorLayoutInterface> &layouts,
42+
const VectorLayoutOptions &options) {
3943
SmallVector<Attribute> operands;
4044
SmallVector<Attribute> results;
4145

4246
for (Value operand : op->getOperands()) {
4347
if (auto vectorOperand = dyn_cast<VectorValue>(operand)) {
44-
if (auto layout =
45-
analysis.getLayout<VectorLayoutInterface>(vectorOperand)) {
48+
if (auto layout = layouts.lookup(vectorOperand)) {
4649
operands.push_back(layout);
4750
continue;
4851
}
@@ -57,8 +60,7 @@ LogicalResult setOpSignature(Operation *op, VectorLayoutAnalysis &analysis,
5760

5861
for (Value result : op->getResults()) {
5962
if (auto vectorResult = dyn_cast<VectorValue>(result)) {
60-
if (auto layout =
61-
analysis.getLayout<VectorLayoutInterface>(vectorResult)) {
63+
if (auto layout = layouts.lookup(vectorResult)) {
6264
results.push_back(layout);
6365
continue;
6466
}
@@ -356,17 +358,19 @@ LogicalResult distributeVectorOps(Operation *root,
356358
VectorLayoutOptions &options) {
357359
// Run the analysis and determine the layouts.
358360
LLVM_DEBUG(llvm::dbgs() << "Running Layout Analysis\n");
359-
VectorLayoutAnalysis analysis(root);
360-
if (failed(analysis.run()))
361+
llvm::MapVector<Value, VectorLayoutInterface> layouts;
362+
if (failed(propagateVectorLayoutInfo(root, layouts))) {
363+
LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Failed\n");
361364
return failure();
365+
}
362366
LLVM_DEBUG(llvm::dbgs() << "Layout Analysis Succeded\n");
363367
LLVM_DEBUG(llvm::dbgs() << "\n\n");
364368

365369
// Go to each operation, and set its distribution signature.
366370
LLVM_DEBUG(
367371
llvm::dbgs() << "Setting distribution signatures for operations\n");
368372
root->walk([&](Operation *op) {
369-
if (failed(setOpSignature(op, analysis, options))) {
373+
if (failed(setOpSignature(op, layouts, options))) {
370374
LLVM_DEBUG({
371375
llvm::dbgs() << "Skipping operation because not all vector "
372376
"operands/results have a layout:\n";

compiler/src/iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
#ifndef IREE_COMPILER_CODEGEN_COMMON_GPU_VECTOR_DISTRIBUTION_H_
88
#define IREE_COMPILER_CODEGEN_COMMON_GPU_VECTOR_DISTRIBUTION_H_
99

10-
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
1110
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
1211
#include "llvm/Support/Debug.h"
1312
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
@@ -16,6 +15,8 @@
1615

1716
namespace mlir::iree_compiler {
1817

18+
using IREE::VectorExt::VectorLayoutInterface;
19+
1920
/// A signature describing the layout for each value of vector type which is
2021
/// an operand or result of this operation.
2122
///

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
#include "iree/compiler/Codegen/Common/GPU/GPUVectorDistribution.h"
1313
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
1414
#include "iree/compiler/Codegen/Common/Transforms.h"
15-
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
1615
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
1716
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
1817
#include "iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.h"

compiler/src/iree/compiler/Codegen/Common/Transforms.h

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,76 @@ FailureOr<IREETilingResult>
7777
tileDispatchUsingSCFForOp(RewriterBase &rewriter, TilingInterface op,
7878
linalg::LinalgTilingOptions options);
7979

80+
namespace IREE::VectorExt {
81+
class VectorLayoutInterface;
82+
} // namespace IREE::VectorExt
83+
84+
/// Analyzes the root op and it's nested ops to propagate vector layouts
85+
/// originating from to_vector operations. Example:
86+
///
87+
/// %root = vector.transfer_read
88+
/// |
89+
/// --> anchored to layout L (using a to_layout op)
90+
/// %root2 = vector.transfer_read
91+
/// %c = arith.mulf %root, %b
92+
/// |
93+
/// --> %root, %b and %c must have the same layout
94+
/// %e = arith.divf %b, %root2
95+
/// |
96+
/// --> %root2, %b and %e must have the same layout
97+
///
98+
/// Here, the user provided an anchor point for %root, fixing it's layout to L.
99+
/// The layout then uses it's inference rules to find the layout of other
100+
/// values:
101+
///
102+
/// %root = vector.transfer_read
103+
/// |
104+
/// --> infered to layout L
105+
/// %root2 = vector.transfer_read
106+
/// |
107+
/// --> infered to layout L
108+
/// %c = arith.mulf %root, %b
109+
/// |
110+
/// --> infered to layout L
111+
/// %e = arith.divf %b, %root2
112+
/// |
113+
/// --> infered to layout L
114+
///
115+
/// If at any point, a value has a layout, but the user of that value requires
116+
/// a different layout, the analysis inserts a resolution operation. This
117+
/// resolution operation is `iree_vector_ext.to_layout`.
118+
/// For Example:
119+
///
120+
/// %0 = vector.transfer_read
121+
/// |
122+
/// --> anchored to layout L
123+
/// %1 = vector.transfer_read
124+
/// |
125+
/// --> anchored to layout L'
126+
/// arith.addf %0, %1
127+
/// |
128+
/// --> %0 and %1 must have the same layout
129+
///
130+
/// To resolve the conflict, the analysis chooses one of the layouts, say
131+
/// L, and inserts a resolution operation to convert the other layout to L.
132+
///
133+
/// %0 = vector.transfer_read
134+
/// |
135+
/// --> anchored to layout L
136+
/// %1 = vector.transfer_read
137+
/// |
138+
/// --> anchored to layout L'
139+
/// %resolved = iree_vector_ext.to_layout %1
140+
/// |
141+
/// --> infered to layout L
142+
/// arith.addf %0, %resolved
143+
///
144+
/// The analysis itself will not try to resolve the conflict, but instead
145+
/// will leave it as a to_layout op, which can be rewritten by the caller.
146+
LogicalResult propagateVectorLayoutInfo(
147+
Operation *root,
148+
llvm::MapVector<Value, IREE::VectorExt::VectorLayoutInterface> &layouts);
149+
80150
/// Transform a `scf.for` loop with a strictly positive step
81151
/// for %i = %lb to %ub step %s
82152
/// into a 0-based loop with step 1

compiler/src/iree/compiler/Codegen/Common/VectorLayoutAnalysis.cpp

Lines changed: 33 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
#include "iree/compiler/Codegen/Common/VectorLayoutAnalysis.h"
87
#include "iree/compiler/Codegen/Common/Passes.h"
8+
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtDialect.h"
9+
#include "iree/compiler/Codegen/Dialect/VectorExt/IR/VectorExtOps.h"
910

1011
#include <cassert>
1112

@@ -249,7 +250,7 @@ ChangeResult DistributionLayout::resolveWithPossibleConflict(
249250
// Create a new value for the resolved value and subscribe it to propagation
250251
// and enforcement.
251252
// We possibly don't need to subscribe this since this value has already
252-
// reached the top of the lattice and shouldn't do anything else. But it's
253+
// reached the top of the lattice and shouldn't do anything else. But its
253254
// nicer to do it to have consistency.
254255
DistributionLayout *resolvedLayout =
255256
propagation->getLatticeElement(resolvedValue);
@@ -296,7 +297,7 @@ ChangeResult DistributionLayout::resolve(const VectorLayoutInterface &rhs,
296297
}
297298
}
298299

299-
// This return will never be reached, but it's here to make the compiler
300+
// This return will never be reached, but its here to make the compiler
300301
// happy.
301302
return ChangeResult::NoChange;
302303
}
@@ -698,7 +699,7 @@ static void enforceLayoutToMultiReductionOp(
698699
return;
699700
}
700701
// Reductions should always propagate value layout to result. Result can
701-
// enforce it's layout on init.
702+
// enforce its layout on init.
702703
const DistributionLayout *result = resultLattices[0];
703704
DistributionLayout *init = operandLattices[1];
704705

@@ -1352,78 +1353,55 @@ DistributionLayout *EnforceLayout::getLatticeElement(Value val) {
13521353
return layout;
13531354
}
13541355

1355-
/// ==========================================================================
1356-
/// VectorLayoutAnalysis
1357-
/// ==========================================================================
1356+
namespace mlir::iree_compiler {
13581357

1359-
LogicalResult VectorLayoutAnalysis::run() {
1358+
LogicalResult propagateVectorLayoutInfo(
1359+
Operation *root, llvm::MapVector<Value, VectorLayoutInterface> &layouts) {
1360+
DataFlowSolver solver;
13601361
// The order of loading matters here, because propagateLayout does anchoring
13611362
// initialization which needs the lattice to know both enforcement and
13621363
// propagation.
13631364
solver.load<PropagateLayout>(root->getContext());
13641365
solver.load<EnforceLayout>(root->getContext());
1365-
return solver.initializeAndRun(root);
1366-
}
1367-
1368-
VectorLayoutInterface VectorLayoutAnalysis::getLayout(Value val) {
1369-
const DistributionLayout *layout =
1370-
solver.lookupState<DistributionLayout>(val);
1371-
if (!layout) {
1372-
return VectorLayoutInterface();
1366+
if (failed(solver.initializeAndRun(root))) {
1367+
return failure();
13731368
}
1374-
return layout->getLayout();
1375-
}
1376-
1377-
void VectorLayoutAnalysis::debugAnnotateLayouts() {
1378-
// Annotate each operation with the layout of it's result.
1369+
// Iterate over all values and extract their layouts.
13791370
root->walk([&](Operation *op) {
1380-
if (op->getNumResults() == 0) {
1381-
return;
1371+
for (Value result : op->getResults()) {
1372+
const DistributionLayout *layout =
1373+
solver.lookupState<DistributionLayout>(result);
1374+
if (layout && layout->hasLayout()) {
1375+
layouts[result] = layout->getLayout();
1376+
}
13821377
}
13831378

1384-
for (auto [index, result] : llvm::enumerate(op->getResults())) {
1385-
if (!isa<VectorType>(result.getType())) {
1379+
for (Value operand : op->getOperands()) {
1380+
// Some operands may not have been visited as results (e.g., block
1381+
// arguments).
1382+
if (layouts.contains(operand)) {
13861383
continue;
13871384
}
1388-
1389-
// Do not annotate to_layout operations since they already have
1390-
// this information in their attributes.
1391-
if (isa<IREE::VectorExt::ToLayoutOp>(op)) {
1392-
continue;
1385+
const DistributionLayout *layout =
1386+
solver.lookupState<DistributionLayout>(operand);
1387+
if (layout && layout->hasLayout()) {
1388+
layouts[operand] = layout->getLayout();
13931389
}
1394-
1395-
Attribute layout = getLayout<Attribute>(result);
1396-
if (!layout) {
1397-
continue;
1398-
}
1399-
1400-
op->setAttr("layout_result_" + std::to_string(index), layout);
14011390
}
14021391
});
1392+
return success();
14031393
}
14041394

1405-
void VectorLayoutAnalysis::print(raw_ostream &os) {
1406-
debugAnnotateLayouts();
1407-
root->print(os);
1408-
}
1409-
1410-
void VectorLayoutAnalysis::dump() {
1411-
print(llvm::dbgs());
1412-
llvm::dbgs() << "\n";
1413-
}
1414-
1415-
namespace mlir::iree_compiler {
1416-
14171395
#define GEN_PASS_DEF_TESTVECTORLAYOUTANALYSISPASS
14181396
#include "iree/compiler/Codegen/Common/Passes.h.inc"
14191397

14201398
struct TestVectorLayoutAnalysisPass final
14211399
: impl::TestVectorLayoutAnalysisPassBase<TestVectorLayoutAnalysisPass> {
14221400
void runOnOperation() override {
14231401
Operation *root = getOperation();
1424-
VectorLayoutAnalysis analysis(getOperation());
1425-
if (failed(analysis.run())) {
1426-
root->emitError("layout analysis failed");
1402+
llvm::MapVector<Value, VectorLayoutInterface> layouts;
1403+
if (failed(propagateVectorLayoutInfo(root, layouts))) {
1404+
root->emitError("Layout Analysis Failed");
14271405
return signalPassFailure();
14281406
}
14291407

@@ -1433,14 +1411,9 @@ struct TestVectorLayoutAnalysisPass final
14331411
}
14341412

14351413
for (OpResult result : op->getOpResults()) {
1436-
if (auto layout = analysis.getLayout<Attribute>(result)) {
1437-
// Print layout attr to a string.
1438-
std::string layoutStr;
1439-
llvm::raw_string_ostream s(layoutStr);
1440-
s << layout;
1441-
// Emit remark.
1442-
op->emitRemark("layout of result #" +
1443-
Twine(result.getResultNumber()) + " is " + s.str());
1414+
if (layouts.contains(result)) {
1415+
op->emitRemark("layout of result #")
1416+
<< result.getResultNumber() << " is " << layouts[result];
14441417
}
14451418
}
14461419
});

0 commit comments

Comments
 (0)