@@ -1225,9 +1225,7 @@ struct WgToSgVectorTransposeOp
12251225 LogicalResult
12261226 matchAndRewrite (vector::TransposeOp op, OneToNOpAdaptor adaptor,
12271227 ConversionPatternRewriter &rewriter) const override {
1228- VectorType resultType = dyn_cast<VectorType>(op.getResult ().getType ());
1229- if (!resultType)
1230- return failure ();
1228+ VectorType resultType = op.getResultVectorType ();
12311229
12321230 ArrayRef<int64_t > wgShape = resultType.getShape ();
12331231 xegpu::DistributeLayoutAttr layout =
@@ -1242,9 +1240,7 @@ struct WgToSgVectorTransposeOp
12421240
12431241 SmallVector<int64_t > sourceSgLayout =
12441242 sourceLayout.getEffectiveSgLayoutAsInt ();
1245- SmallVector<int64_t > sourceSgData = sourceLayout.getEffectiveSgDataAsInt ();
12461243 SmallVector<int64_t > resultSgLayout = layout.getEffectiveSgLayoutAsInt ();
1247- SmallVector<int64_t > resultSgData = layout.getEffectiveSgDataAsInt ();
12481244 DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder ();
12491245 DenseI32ArrayAttr resultOrder = layout.getOrder ();
12501246
@@ -1253,37 +1249,20 @@ struct WgToSgVectorTransposeOp
12531249 op, " Both source and result must have order attributes" );
12541250 }
12551251
1256- SmallVector<int64_t > sourceOrderVec = llvm::to_vector (
1257- llvm::map_range (sourceOrder.asArrayRef (),
1258- [](int32_t idx) { return static_cast <int64_t >(idx); }));
1259- SmallVector<int64_t > resultOrderVec = llvm::to_vector (
1260- llvm::map_range (resultOrder.asArrayRef (),
1261- [](int32_t idx) { return static_cast <int64_t >(idx); }));
1262-
12631252 ArrayRef<int64_t > permutation = op.getPermutation ();
1264- size_t expectedSize = permutation.size ();
1265- if (sourceSgLayout.size () != expectedSize ||
1266- sourceSgData.size () != expectedSize ||
1267- resultSgLayout.size () != expectedSize ||
1268- resultSgData.size () != expectedSize ||
1269- sourceOrderVec.size () != expectedSize ||
1270- resultOrderVec.size () != expectedSize) {
1253+ size_t permutationSize = permutation.size ();
1254+ if (sourceSgLayout.size () != permutationSize ||
1255+ resultSgLayout.size () != permutationSize) {
12711256 return rewriter.notifyMatchFailure (
1272- op, " All layouts and permutation must have the same rank" );
1257+ op, " Layouts and permutation must have the same rank" );
12731258 }
12741259
1275- // Check that sgLayout, sgData & order are properly transposed for operand
1260+ // Check that sgLayout, sgData & order are properly transposed for source
12761261 // and result
1277- for (size_t i = 0 ; i < permutation.size (); ++i) {
1278- int64_t srcDim = permutation[i];
1279- if (resultSgLayout[i] != sourceSgLayout[srcDim] ||
1280- resultSgData[i] != sourceSgData[srcDim] ||
1281- resultOrderVec[i] != sourceOrderVec[srcDim]) {
1282- return rewriter.notifyMatchFailure (
1283- op, " Result layout is not a valid transpose of source layout "
1284- " according to permutation" );
1285- }
1286- }
1262+ if (!layout.isTransposeOf (sourceLayout, permutation))
1263+ return rewriter.notifyMatchFailure (
1264+ op, " Result layout is not a valid transpose of source layout "
1265+ " according to permutation" );
12871266
12881267 SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
12891268 VectorType newResultType =
@@ -1292,10 +1271,8 @@ struct WgToSgVectorTransposeOp
12921271 for (auto src : adaptor.getVector ()) {
12931272 auto newTranspose = vector::TransposeOp::create (
12941273 rewriter, op.getLoc (), newResultType, src, permutation);
1295- if (!layout.getEffectiveLaneLayoutAsInt ().empty () ||
1296- !layout.getEffectiveInstDataAsInt ().empty ())
1297- xegpu::setDistributeLayoutAttr (newTranspose->getResult (0 ),
1298- layout.dropSgLayoutAndData ());
1274+ xegpu::setDistributeLayoutAttr (newTranspose->getResult (0 ),
1275+ layout.dropSgLayoutAndData ());
12991276 newTransposeOps.push_back (newTranspose.getResult ());
13001277 }
13011278
0 commit comments