@@ -29,7 +29,6 @@ static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
2929 if (!isNarrowNResult (encoding)) {
3030 return tensorType;
3131 }
32- auto newIndex = encoding.getOperandIndex ();
3332 SmallVector<int64_t > newOriginalShape (tensorType.getShape ());
3433 auto userIndexingMaps = encoding.getUserIndexingMaps ();
3534 SmallVector<AffineMap> maps;
@@ -70,23 +69,19 @@ static RankedTensorType transposeIfNarrowNResult(RankedTensorType tensorType) {
7069 for (auto &map : maps) {
7170 map = map.compose (permutation);
7271 }
73- SmallVector<Attribute> newMaps;
74- for (auto map : maps) {
75- newMaps.push_back (AffineMapAttr::get (map));
76- }
77- ArrayAttr newIndexingMaps = ArrayAttr::get (context, newMaps);
7872 auto elemType = tensorType.getElementType ();
79- OpBuilder builder (context );
73+ auto operandIndex = encoding. getOperandIndex (). getInt ( );
8074
81- auto opTypeAttr = IREE::Encoding::EncodingOpTypeAttr::get (
82- context, IREE::Encoding::EncodingOpType::matmul);
8375 // TODO(#17718): Handle the broadcast map for transpose cases. It is on the
8476 // experimental path, so it is not clear what needs to be done here. For now
8577 // just use the original map for the new encoding.
78+ std::optional<AffineMap> newBcastMap;
79+ if (encoding.getBcastMap ()) {
80+ newBcastMap = encoding.getBcastMap ().getValue ();
81+ }
8682 auto newEncoding = IREE::Encoding::EncodingAttr::get (
87- context, newIndex, opTypeAttr, encoding.getElementTypes (),
88- newIndexingMaps, encoding.getBcastMap (),
89- DenseI64ArrayAttr::get (context, newRoundDimsTo));
83+ context, operandIndex, encoding.getOpType ().getValue (),
84+ encoding.getElementTypesArray (), maps, newBcastMap, newRoundDimsTo);
9085 return RankedTensorType::get (newShape, elemType, newEncoding);
9186}
9287
0 commit comments