Skip to content

Commit 4c0fd90

Browse files
authored
[Encoding][DT][NFC] Simplify the codes for EncodingAttr creation. (#19147)
Signed-off-by: hanhanW <[email protected]>
1 parent 2a2bd06 commit 4c0fd90

File tree

1 file changed

+7
-12
lines changed

1 file changed

+7
-12
lines changed

compiler/src/iree/compiler/Codegen/Common/EncodingUtils.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
2929
if (!isNarrowNResult(encoding)) {
3030
return tensorType;
3131
}
32-
auto newIndex = encoding.getOperandIndex();
3332
SmallVector<int64_t> newOriginalShape(tensorType.getShape());
3433
auto userIndexingMaps = encoding.getUserIndexingMaps();
3534
SmallVector<AffineMap> maps;
@@ -70,23 +69,19 @@ static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
7069
for (auto &map : maps) {
7170
map = map.compose(permutation);
7271
}
73-
SmallVector<Attribute> newMaps;
74-
for (auto map : maps) {
75-
newMaps.push_back(AffineMapAttr::get(map));
76-
}
77-
ArrayAttr newIndexingMaps = ArrayAttr::get(context, newMaps);
7872
auto elemType = tensorType.getElementType();
79-
OpBuilder builder(context);
73+
auto operandIndex = encoding.getOperandIndex().getInt();
8074

81-
auto opTypeAttr = IREE::Encoding::EncodingOpTypeAttr::get(
82-
context, IREE::Encoding::EncodingOpType::matmul);
8375
// TODO(#17718): Handle the broadcast map for transpose cases. It is on the
8476
// experimental path, so it is not clear what needs to be done here. For now
8577
// just use the original map for the new encoding.
78+
std::optional<AffineMap> newBcastMap;
79+
if (encoding.getBcastMap()) {
80+
newBcastMap = encoding.getBcastMap().getValue();
81+
}
8682
auto newEncoding = IREE::Encoding::EncodingAttr::get(
87-
context, newIndex, opTypeAttr, encoding.getElementTypes(),
88-
newIndexingMaps, encoding.getBcastMap(),
89-
DenseI64ArrayAttr::get(context, newRoundDimsTo));
83+
context, operandIndex, encoding.getOpType().getValue(),
84+
encoding.getElementTypesArray(), maps, newBcastMap, newRoundDimsTo);
9085
return RankedTensorType::get(newShape, elemType, newEncoding);
9186
}
9287

0 commit comments

Comments
 (0)