
Commit 8889f74

hanhanW authored and Groverkss committed
[DT][NFC] Refactor encoding utilities. (1/n) (iree-org#19310)
The revision shuffles the utilities into the Encoding dialect and the Codegen dialect:

1. Move the TileMxNxK struct and the getEncodingInfoForMatmul method to the Codegen dialect (i.e., `Dialect/Codegen/*`).
2. Move isNarrowNResult to the Encoding dialect, because it does not depend on any dialect other than the Encoding dialect.
3. Move lowerContractionOpWithEncoding to the Codegen dialect utils as preparation: all the materialization logic will be moved to the Codegen dialect, and the two sides share the utilities during the transition period.

To accomplish (3), the revision introduces the ResolveEncodingInfoFn function type, which decouples the utility from MaterializeEncodingTypeConverter. This is required because the type converter uses HAL, and we do not want the Codegen dialect to depend on HAL. The dependency will no longer be needed once all the logic moves to the attribute implementations.

Minor cleanups:
- Remove the `rank` argument from getEncodingInfoForMatmul; it is not used at all.
- Add the `static` keyword to the local `getExpandedType` function.

Note that the `lowerSetEncodingOpToPackOp` and `lowerUnsetEncodingToUnpackOp` functions are not moved because they require more changes; they will be moved in a separate patch.

---------

Signed-off-by: hanhanW <[email protected]>
1 parent 70534da commit 8889f74
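To make the decoupling described in the message concrete, the following sketch stitches together the relevant pieces of this diff: the new `ResolveEncodingInfoFn` alias added to IREECodegenTypes.h and the caller-side wrapper added in MaterializeEncodingIntoPackUnPack.cpp. It is a fragment rather than a standalone program; names such as `converter`, `rewriter`, `op`, and `operands` are assumed to be in scope exactly as they are at the call site shown later in the diff.

// New alias (Codegen dialect): resolve a tensor type to its MaterializeEncodingInfo
// without the Codegen dialect ever seeing the HAL-dependent type converter.
using ResolveEncodingInfoFn =
    std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType type)>;

// Caller side: the type converter stays behind a plain std::function.
auto getEncodingInfoWrapper =
    [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
  return converter->getEncodingInfo(type);
};
FailureOr<Operation *> convertedOp =
    IREE::Codegen::lowerContractionOpWithEncoding(
        rewriter, op, operands, converter->getTransposeNarrowN(),
        getEncodingInfoWrapper);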

File tree: 12 files changed, +271 -234 lines changed


compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodings.cpp

Lines changed: 2 additions & 2 deletions
@@ -31,6 +31,7 @@
 namespace mlir::iree_compiler {
 
 using IREE::Codegen::MaterializeEncodingInfo;
+using IREE::Codegen::TileMxNxK;
 
 #define GEN_PASS_DEF_CPUMATERIALIZEDEVICEENCODINGPASS
 #define GEN_PASS_DEF_CPUMATERIALIZEHOSTENCODINGPASS
@@ -445,8 +446,7 @@ materializeEncodingForTarget(RankedTensorType tensorType,
 
   // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
   // based on its operand index in the matmul.
-  auto rank = tensorType.getRank();
-  return getEncodingInfoForMatmul(encoding, rank, chosenTileMxNxK);
+  return IREE::Codegen::getEncodingInfoForMatmul(encoding, chosenTileMxNxK);
 }
 
 static FailureOr<MaterializeEncodingValueInfo>

compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp

Lines changed: 0 additions & 48 deletions
@@ -153,52 +153,4 @@ RankedTensorType dropEncoding(RankedTensorType type) {
   return RankedTensorType::get(type.getShape(), type.getElementType());
 }
 
-MaterializeEncodingInfo getEncodingInfoForMatmul(EncodingAttr encoding,
-                                                 int64_t rank,
-                                                 TileMxNxK tileMxNxK) {
-  MaterializeEncodingInfo encodingInfo;
-  auto cDims = getEncodingContractionDims(encoding);
-  // The following expects M, N, K, and Batch sizes of at most 1 for now
-  assert(cDims->m.size() <= 1 && cDims->n.size() <= 1 && cDims->k.size() == 1 &&
-         cDims->batch.size() <= 1 &&
-         "Expected at most one M, N, K, and Batch dimension");
-  std::optional<unsigned> batchDim =
-      cDims->batch.empty() ? std::nullopt
-                           : encoding.mapDimToOperandIndex(cDims->batch[0]);
-  std::optional<unsigned> mDim =
-      cDims->m.empty() ? std::nullopt
-                       : encoding.mapDimToOperandIndex(cDims->m[0]);
-  std::optional<unsigned> nDim =
-      cDims->n.empty() ? std::nullopt
-                       : encoding.mapDimToOperandIndex(cDims->n[0]);
-  std::optional<unsigned> kDim = encoding.mapDimToOperandIndex(cDims->k[0]);
-  if (batchDim.has_value()) {
-    encodingInfo.outerDimsPerm.push_back(batchDim.value());
-  }
-  if (mDim.has_value()) {
-    encodingInfo.outerDimsPerm.push_back(mDim.value());
-    encodingInfo.innerDimsPos.push_back(mDim.value());
-    encodingInfo.innerTileSizes.push_back(tileMxNxK.M);
-  }
-  if (nDim.has_value()) {
-    encodingInfo.outerDimsPerm.push_back(nDim.value());
-    encodingInfo.innerDimsPos.push_back(nDim.value());
-    encodingInfo.innerTileSizes.push_back(tileMxNxK.N);
-  }
-  if (kDim.has_value()) {
-    encodingInfo.outerDimsPerm.push_back(kDim.value());
-    encodingInfo.innerDimsPos.push_back(kDim.value());
-    encodingInfo.innerTileSizes.push_back(tileMxNxK.K);
-  }
-  return encodingInfo;
-}
-
-bool isNarrowNResult(EncodingAttr encoding) {
-  if (encoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RESULT) {
-    return false;
-  }
-
-  return IREE::Encoding::getMatmulNarrowDim(encoding).isN();
-}
-
 } // namespace mlir::iree_compiler
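As context for the move, the following standalone sketch mirrors what the relocated getEncodingInfoForMatmul computes: each contraction dimension that maps onto an operand contributes one entry to outerDimsPerm, innerDimsPos, and innerTileSizes. It uses plain standard-library types in place of MaterializeEncodingInfo and TileMxNxK, and the operand dim indices and the 16x16x2 tile below are illustrative assumptions, not values taken from the patch.

// Standalone illustration (not IREE code) of the mapping performed above.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct Tile { int64_t M = 1, N = 1, K = 1; };
struct Info {
  std::vector<int64_t> outerDimsPerm, innerDimsPos, innerTileSizes;
};

// m/n/k are operand-relative dim indices; std::nullopt means the contraction
// dim does not appear on this operand (e.g. K on a matmul result).
static Info infoForOperand(std::optional<unsigned> m, std::optional<unsigned> n,
                           std::optional<unsigned> k, Tile tile) {
  Info info;
  auto add = [&](unsigned dim, int64_t size) {
    info.outerDimsPerm.push_back(dim);
    info.innerDimsPos.push_back(dim);
    info.innerTileSizes.push_back(size);
  };
  if (m) add(*m, tile.M);
  if (n) add(*n, tile.N);
  if (k) add(*k, tile.K);
  return info;
}

int main() {
  // Hypothetical RESULT operand of a plain matmul: M -> dim 0, N -> dim 1,
  // K absent. With a 16x16x2 tile, both result dims get packed by 16.
  Info result = infoForOperand(0, 1, std::nullopt, Tile{16, 16, 2});
  for (size_t i = 0; i < result.innerDimsPos.size(); ++i)
    std::cout << "dim " << result.innerDimsPos[i] << " -> tile "
              << result.innerTileSizes[i] << "\n"; // dim 0 -> 16, dim 1 -> 16
  return 0;
}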

compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h

Lines changed: 0 additions & 14 deletions
@@ -85,16 +85,6 @@ class OpMaterializeEncodingPattern : public OpConversionPattern<OpTy> {
 /// Returns the RankedTensorType without encodings.
 RankedTensorType dropEncoding(RankedTensorType type);
 
-struct TileMxNxK {
-  int64_t M = 1;
-  int64_t N = 1;
-  int64_t K = 1;
-};
-
-IREE::Codegen::MaterializeEncodingInfo
-getEncodingInfoForMatmul(IREE::Encoding::EncodingAttr encoding, int64_t rank,
-                         TileMxNxK tileMxNxK);
-
 /// Utility method to convert from `set_encoding` op to `pack` operation.
 /// For now this takes a `paddingValue` as input. The source is also taken
 /// as input so that these could be used with `OpConversionPatterns`.
@@ -126,10 +116,6 @@ void populateShapeIndependentMaterializeEncodingPatterns(
     MaterializeEncodingTypeConverter &typeConverter,
     MaterializeEncodingValueFn materializeEncodingValueFn);
 
-// Returns true if `encoding` represents a narrow-N matmul RESULT, e.g. the
-// result of a matvec.
-bool isNarrowNResult(IREE::Encoding::EncodingAttr encoding);
-
 } // namespace mlir::iree_compiler
 
 #endif // IREE_COMPILER_SRC_IREE_COMPILER_CODEGEN_COMMON_ENCODINGUTILS_H_

compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp

Lines changed: 3 additions & 2 deletions
@@ -43,6 +43,7 @@ namespace mlir::iree_compiler {
 #include "iree/compiler/Codegen/Common/GPU/Passes.h.inc"
 
 using IREE::Codegen::MaterializeEncodingInfo;
+using IREE::Codegen::TileMxNxK;
 using IREE::Codegen::TileSwizzle;
 
 static IREE::GPU::MMAAttr chooseIntrinsicMMAAttr(TypeRange eTypes,
@@ -245,10 +246,10 @@ materializeEncodingForTarget(RankedTensorType tensorType,
 
   // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
   // based on its operand index in the matmul.
-  auto rank = tensorType.getRank();
   TileMxNxK innerTile;
   std::tie(innerTile.M, innerTile.N, innerTile.K) = mma.getMNKShape();
-  auto encodingInfo = getEncodingInfoForMatmul(encoding, rank, innerTile);
+  auto encodingInfo =
+      IREE::Codegen::getEncodingInfoForMatmul(encoding, innerTile);
   auto fragment =
       static_cast<IREE::GPU::MMAFragment>(encoding.getOperandIndex().getInt());
   encodingInfo.swizzle = getSwizzle(mma, fragment);

compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingIntoPackUnPack.cpp

Lines changed: 10 additions & 168 deletions
@@ -61,19 +61,6 @@
   return newShape;
 }
 
-static Operation *dropEncodingAndCloneOp(OpBuilder &builder, Operation *op,
-                                         ValueRange convertedInputOperands,
-                                         ValueRange convertedOutputOperands) {
-  SmallVector<Value> operands;
-  operands.append(convertedInputOperands.begin(), convertedInputOperands.end());
-  operands.append(convertedOutputOperands.begin(),
-                  convertedOutputOperands.end());
-  return mlir::clone(builder, op,
-                     {dropEncoding(cast<RankedTensorType>(
-                         convertedOutputOperands[0].getType()))},
-                     operands);
-}
-
 static FailureOr<SmallVector<OpFoldResult>>
 getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
                      RankedTensorType tensorType,
@@ -111,91 +98,6 @@ getInnerTileSizesOfr(OpBuilder &rewriter, Location loc,
   return result;
 }
 
-RankedTensorType getExpandedType(RankedTensorType type, bool isBatched,
-                                 bool isTransposed,
-                                 SmallVectorImpl<ReassociationIndices> &ri) {
-  if (!isBatched) {
-    ri.assign({{0, 1}, {2, 3}});
-    if (!isTransposed) {
-      return RankedTensorType::get(
-          {1, type.getDimSize(0), 1, type.getDimSize(1)},
-          type.getElementType());
-    }
-    return RankedTensorType::get({type.getDimSize(0), 1, type.getDimSize(1), 1},
-                                 type.getElementType());
-  }
-
-  ri.assign({{0}, {1, 2}, {3, 4}});
-  if (!isTransposed) {
-    return RankedTensorType::get(
-        {type.getDimSize(0), 1, type.getDimSize(1), 1, type.getDimSize(2)},
-        type.getElementType());
-  }
-  return RankedTensorType::get(
-      {type.getDimSize(0), type.getDimSize(1), 1, type.getDimSize(2), 1},
-      type.getElementType());
-}
-
-/// Given an input Value and a desired output element type, create and return
-/// an element-wise linalg::GenericOp that extends the input Value to the
-/// output element type.
-static Value createElementWiseExtUIOp(RewriterBase &rewriter, Value input,
-                                      Location loc, Type outElemType) {
-  auto inputType = cast<RankedTensorType>(input.getType());
-  SmallVector<AffineMap> maps(
-      2, rewriter.getMultiDimIdentityMap(inputType.getRank()));
-  SmallVector<utils::IteratorType> iteratorTypes(inputType.getRank(),
-                                                 utils::IteratorType::parallel);
-  auto castedType = inputType.clone(outElemType);
-  SmallVector<OpFoldResult> inputMixedSizes =
-      tensor::getMixedSizes(rewriter, loc, input);
-  Value init =
-      rewriter.create<tensor::EmptyOp>(loc, inputMixedSizes, outElemType);
-  return rewriter
-      .create<linalg::GenericOp>(
-          loc, castedType, input, init, maps, iteratorTypes,
-          [&](OpBuilder &b, Location nestedLoc, ValueRange args) {
-            Value castRes =
-                b.create<arith::ExtUIOp>(nestedLoc, outElemType, args[0])
-                    ->getResult(0);
-            b.create<linalg::YieldOp>(nestedLoc, castRes);
-          })
-      .getResult(0);
-}
-
-/// If needed, expand and the input Value, and return the resulting input with
-/// the canonical mmt4d input shape. If the input element type is unsigned,
-/// create a producer Linalg::GenericOp on the input that unsigned extends the
-/// input to the output element type. This extension is required to keep the
-/// unsignedness information on the input for ukernels. If `transpose` is true,
-/// the `linalgOp`'s indexing maps are transposed.
-static Value getMmt4dOperand(Value value, linalg::LinalgOp linalgOp,
-                             bool transpose, RewriterBase &rewriter,
-                             SmallVectorImpl<ReassociationIndices> &ri,
-                             ArrayRef<Type> elemTypes, int operandIdx) {
-  assert(linalgOp.getNumDpsInputs() == 2);
-  assert(linalgOp.getNumDpsInits() == 1);
-  auto cDims = linalg::inferContractionDims(linalgOp);
-  Location loc = linalgOp->getLoc();
-  Value expandedValue = value;
-  // If vecmat with non-rhs operandIdx or matvec with non-lhs operandIdx, the
-  // operand is a vector and must be extended
-  if ((cDims->m.empty() && operandIdx != 1) ||
-      (cDims->n.empty() && operandIdx != 0)) {
-    auto type = cast<RankedTensorType>(value.getType());
-    RankedTensorType newType = getExpandedType(
-        type, /*isBatched=*/!cDims->batch.empty(),
-        /*isTransposed=*/operandIdx == 2 && (transpose ^ cDims->n.empty()), ri);
-    expandedValue =
-        rewriter.create<tensor::ExpandShapeOp>(loc, newType, value, ri);
-  }
-  if (elemTypes[operandIdx].isUnsignedInteger()) {
-    return createElementWiseExtUIOp(rewriter, expandedValue, loc,
-                                    elemTypes.back());
-  }
-  return expandedValue;
-}
-
 static void transposeInPlace(MaterializeEncodingInfo &info) {
   // Vector cases: nothing to do.
   if (info.innerTileSizes.size() < 2) {
@@ -297,75 +199,6 @@ FailureOr<tensor::UnPackOp> lowerUnsetEncodingToUnpackOp(
       encodingInfo->outerDimsPerm);
 }
 
-static FailureOr<Operation *> lowerContractionOpWithEncoding(
-    RewriterBase &rewriter, linalg::LinalgOp linalgOp, ValueRange operands,
-    const MaterializeEncodingTypeConverter &typeConverter) {
-  if (!linalgOp.hasPureTensorSemantics())
-    return failure();
-
-  auto inputs = linalgOp.getDpsInputOperands();
-  auto outputs = linalgOp.getDpsInits();
-
-  auto lhsType = cast<RankedTensorType>(inputs[0]->get().getType());
-  auto rhsType = cast<RankedTensorType>(inputs[1]->get().getType());
-  auto resultType = cast<RankedTensorType>(outputs[0].getType());
-  auto lhsEncoding = IREE::Encoding::getEncodingAttr(lhsType);
-  auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
-  auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
-  if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
-    return failure();
-  }
-
-  if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
-      rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
-      resultEncoding.getOperandIndex().getValue() !=
-          IREE::Encoding::MATMUL_RESULT) {
-    return failure();
-  }
-
-  FailureOr<MaterializeEncodingInfo> encodingInfo =
-      typeConverter.getEncodingInfo(
-          cast<RankedTensorType>(linalgOp->getResultTypes()[0]));
-
-  Operation *result;
-  if (failed(encodingInfo)) {
-    result = dropEncodingAndCloneOp(rewriter, linalgOp,
-                                    operands.take_front(inputs.size()),
-                                    operands.drop_front(inputs.size()));
-  } else {
-    bool transpose =
-        typeConverter.getTransposeNarrowN() && isNarrowNResult(resultEncoding);
-    SmallVector<Type> elemTypes = lhsEncoding.getElementTypesArray();
-    SmallVector<ReassociationIndices> ri;
-    Value newLhs = getMmt4dOperand(operands[0], linalgOp, transpose, rewriter,
-                                   ri, elemTypes, /*operandIdx=*/0);
-    Value newRhs = getMmt4dOperand(operands[1], linalgOp, transpose, rewriter,
-                                   ri, elemTypes, /*operandIdx=*/1);
-    Value newResult =
-        getMmt4dOperand(operands[2], linalgOp, transpose, rewriter, ri,
-                        elemTypes, /*operandIdx=*/2);
-    if (transpose) {
-      std::swap(newLhs, newRhs);
-    }
-    Type newResultType = newResult.getType();
-    auto cDims = IREE::Encoding::getEncodingContractionDims(lhsEncoding);
-    if (cDims->batch.empty()) {
-      result = rewriter.create<linalg::Mmt4DOp>(
-          linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-          ValueRange{newResult});
-    } else {
-      result = rewriter.create<linalg::BatchMmt4DOp>(
-          linalgOp.getLoc(), newResultType, ValueRange{newLhs, newRhs},
-          ValueRange{newResult});
-    }
-    if (!ri.empty()) {
-      result = rewriter.create<tensor::CollapseShapeOp>(
-          linalgOp->getLoc(), operands[2].getType(), result->getResult(0), ri);
-    }
-  }
-  return result;
-}
-
 /// Utility method to convert `tensor.empty` with encoding to a `tensor.empty`
 /// of the materialized type.
 static FailureOr<Operation *>
@@ -901,8 +734,17 @@ class MaterializeContractionOp
 
     auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
         this->getTypeConverter());
+    // TODO(hanchung): This is a transition state for moving the implementation
+    // details to backend attributes. We won't need the function type argument
+    // after all the backends that support encodings implement the attribute.
+    auto getEncodingInfoWrapper =
+        [&](RankedTensorType type) -> FailureOr<MaterializeEncodingInfo> {
+      return converter->getEncodingInfo(type);
+    };
    FailureOr<Operation *> convertedOp =
-        lowerContractionOpWithEncoding(rewriter, op, operands, *converter);
+        IREE::Codegen::lowerContractionOpWithEncoding(
+            rewriter, op, operands, converter->getTransposeNarrowN(),
+            getEncodingInfoWrapper);
    if (failed(convertedOp)) {
      return failure();
    }
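The new Utils.h/Utils.cpp in the Codegen dialect are not among the hunks shown on this page, so the relocated declarations themselves are not visible here. Inferred from the call sites above (parameter names and exact placement are assumptions, not copied from the patch), they would look roughly like:

// Presumed shape only; the authoritative declarations live in
// Codegen/Dialect/Codegen/Utils/Utils.h in the patch itself.
namespace mlir::iree_compiler::IREE::Codegen {

MaterializeEncodingInfo
getEncodingInfoForMatmul(IREE::Encoding::EncodingAttr encoding,
                         TileMxNxK tileMxNxK);

FailureOr<Operation *> lowerContractionOpWithEncoding(
    RewriterBase &rewriter, linalg::LinalgOp linalgOp, ValueRange operands,
    bool transposeNarrowN, ResolveEncodingInfoFn getEncodingInfo);

} // namespace mlir::iree_compiler::IREE::Codegen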

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenTypes.h

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,7 @@
 #include <cstdint>
 
 #include "llvm/ADT/SmallVector.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Support/LLVM.h"
 
 namespace mlir::iree_compiler::IREE::Codegen {
@@ -89,5 +90,8 @@ struct MaterializeEncodingInfo {
   std::optional<TileSwizzle> swizzle;
 };
 
+using ResolveEncodingInfoFn =
+    std::function<FailureOr<MaterializeEncodingInfo>(RankedTensorType type)>;
+
 } // namespace mlir::iree_compiler::IREE::Codegen
 #endif // IREE_COMPILER_CODEGEN_DIALECT_CODEGEN_IR_IREECODEGENTYPES_H_

compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/BUILD.bazel

Lines changed: 4 additions & 0 deletions
@@ -22,8 +22,12 @@ iree_compiler_cc_library(
     ],
     deps = [
         "//compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR:IREECodegenDialect",
+        "//compiler/src/iree/compiler/Dialect/Encoding/IR",
         "@llvm-project//llvm:Support",
+        "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:DialectUtils",
         "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:LinalgDialect",
+        "@llvm-project//mlir:TensorDialect",
     ],
 )

compiler/src/iree/compiler/Codegen/Dialect/Codegen/Utils/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
@@ -19,8 +19,12 @@ iree_cc_library(
     "Utils.cpp"
   DEPS
     LLVMSupport
+    MLIRArithDialect
     MLIRIR
+    MLIRLinalgDialect
+    MLIRTensorDialect
     iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect
+    iree::compiler::Dialect::Encoding::IR
   PUBLIC
 )
