Skip to content

Commit d628a22

Browse files
committed
Address feedback
1 parent c7124c1 commit d628a22

File tree

3 files changed

+16
-39
lines changed

3 files changed

+16
-39
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 12 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1225,9 +1225,7 @@ struct WgToSgVectorTransposeOp
12251225
LogicalResult
12261226
matchAndRewrite(vector::TransposeOp op, OneToNOpAdaptor adaptor,
12271227
ConversionPatternRewriter &rewriter) const override {
1228-
VectorType resultType = dyn_cast<VectorType>(op.getResult().getType());
1229-
if (!resultType)
1230-
return failure();
1228+
VectorType resultType = op.getResultVectorType();
12311229

12321230
ArrayRef<int64_t> wgShape = resultType.getShape();
12331231
xegpu::DistributeLayoutAttr layout =
@@ -1242,9 +1240,7 @@ struct WgToSgVectorTransposeOp
12421240

12431241
SmallVector<int64_t> sourceSgLayout =
12441242
sourceLayout.getEffectiveSgLayoutAsInt();
1245-
SmallVector<int64_t> sourceSgData = sourceLayout.getEffectiveSgDataAsInt();
12461243
SmallVector<int64_t> resultSgLayout = layout.getEffectiveSgLayoutAsInt();
1247-
SmallVector<int64_t> resultSgData = layout.getEffectiveSgDataAsInt();
12481244
DenseI32ArrayAttr sourceOrder = sourceLayout.getOrder();
12491245
DenseI32ArrayAttr resultOrder = layout.getOrder();
12501246

@@ -1253,37 +1249,20 @@ struct WgToSgVectorTransposeOp
12531249
op, "Both source and result must have order attributes");
12541250
}
12551251

1256-
SmallVector<int64_t> sourceOrderVec = llvm::to_vector(
1257-
llvm::map_range(sourceOrder.asArrayRef(),
1258-
[](int32_t idx) { return static_cast<int64_t>(idx); }));
1259-
SmallVector<int64_t> resultOrderVec = llvm::to_vector(
1260-
llvm::map_range(resultOrder.asArrayRef(),
1261-
[](int32_t idx) { return static_cast<int64_t>(idx); }));
1262-
12631252
ArrayRef<int64_t> permutation = op.getPermutation();
1264-
size_t expectedSize = permutation.size();
1265-
if (sourceSgLayout.size() != expectedSize ||
1266-
sourceSgData.size() != expectedSize ||
1267-
resultSgLayout.size() != expectedSize ||
1268-
resultSgData.size() != expectedSize ||
1269-
sourceOrderVec.size() != expectedSize ||
1270-
resultOrderVec.size() != expectedSize) {
1253+
size_t permutationSize = permutation.size();
1254+
if (sourceSgLayout.size() != permutationSize ||
1255+
resultSgLayout.size() != permutationSize) {
12711256
return rewriter.notifyMatchFailure(
1272-
op, "All layouts and permutation must have the same rank");
1257+
op, "Layouts and permutation must have the same rank");
12731258
}
12741259

1275-
// Check that sgLayout, sgData & order are properly transposed for operand
1260+
// Check that sgLayout, sgData & order are properly transposed for source
12761261
// and result
1277-
for (size_t i = 0; i < permutation.size(); ++i) {
1278-
int64_t srcDim = permutation[i];
1279-
if (resultSgLayout[i] != sourceSgLayout[srcDim] ||
1280-
resultSgData[i] != sourceSgData[srcDim] ||
1281-
resultOrderVec[i] != sourceOrderVec[srcDim]) {
1282-
return rewriter.notifyMatchFailure(
1283-
op, "Result layout is not a valid transpose of source layout "
1284-
"according to permutation");
1285-
}
1286-
}
1262+
if (!layout.isTransposeOf(sourceLayout, permutation))
1263+
return rewriter.notifyMatchFailure(
1264+
op, "Result layout is not a valid transpose of source layout "
1265+
"according to permutation");
12871266

12881267
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
12891268
VectorType newResultType =
@@ -1292,10 +1271,8 @@ struct WgToSgVectorTransposeOp
12921271
for (auto src : adaptor.getVector()) {
12931272
auto newTranspose = vector::TransposeOp::create(
12941273
rewriter, op.getLoc(), newResultType, src, permutation);
1295-
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1296-
!layout.getEffectiveInstDataAsInt().empty())
1297-
xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
1298-
layout.dropSgLayoutAndData());
1274+
xegpu::setDistributeLayoutAttr(newTranspose->getResult(0),
1275+
layout.dropSgLayoutAndData());
12991276
newTransposeOps.push_back(newTranspose.getResult());
13001277
}
13011278

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,9 @@ gpu.module @test_distribution {
121121
// CHECK-LABEL: vector_transpose
122122
gpu.func @vector_transpose(%src: memref<256x128xf32>) {
123123
%tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
124-
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1], order =[0, 1]>>
124+
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>>
125125
%load = xegpu.load_nd %tdesc[0, 0]
126-
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [1, 16], lane_data = [1, 1], order =[0, 1]>>
126+
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 16], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>>
127127
-> vector<256x128xf32>
128128
// CHECK-COUNT-2: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<32x16xf32> to vector<16x32xf32>
129129
// CHECK-NOT: vector.transpose

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -467,9 +467,9 @@ gpu.module @test_distribution {
467467
// CHECK-LABEL: vector_transpose
468468
gpu.func @vector_transpose(%src: memref<256x32xf32>) {
469469
%tdesc = xegpu.create_nd_tdesc %src : memref<256x32xf32>
470-
-> !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[0, 1]>>
470+
-> !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>>
471471
%load = xegpu.load_nd %tdesc[0, 0]
472-
: !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[0, 1]>>
472+
: !xegpu.tensor_desc<256x32xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [64, 32], lane_layout = [16, 1], lane_data = [1, 1], order =[0, 1]>>
473473
-> vector<256x32xf32>
474474
//CHECK: vector.transpose {{.*}}, [1, 0] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1], order = [1, 0]>} : vector<64x32xf32> to vector<32x64xf32>
475475
%trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 64], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x32xf32> to vector<32x256xf32>

0 commit comments

Comments (0)