Commit 77f3261

Add check
1 parent 2739fa8 commit 77f3261

2 files changed (+30, -16 lines)

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 17 additions & 0 deletions
@@ -1000,6 +1000,23 @@ struct WgToSgVectorShapeCastOp
     if (!onlyUnitDims(srcType.getShape(), sgShape))
       return failure();
 
+    // Check that when expanding dims, the source operand's layout is a
+    // SliceAttr, and when reducing dims, the result's layout is a
+    // SliceAttr.
+    int srcRank = srcType.getRank();
+    int dstRank = sgShape.size();
+    if (dstRank > srcRank) {
+      // Expanding dims: source operand's layout must be a SliceAttr.
+      auto srcLayout = xegpu::getDistributeLayoutAttr(op.getSource());
+      if (!srcLayout || !isa<xegpu::SliceAttr>(srcLayout))
+        return failure();
+    } else if (dstRank < srcRank) {
+      // Reducing dims: result's layout must be a SliceAttr.
+      auto resLayout = xegpu::getDistributeLayoutAttr(op.getResult());
+      if (!resLayout || !isa<xegpu::SliceAttr>(resLayout))
+        return failure();
+    }
+
     SmallVector<Value> newShapeCastOps;
     for (auto src : adaptor.getSource()) {
       auto newShapeCast =
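For illustration, here is a minimal hypothetical IR sketch of an expanding shape_cast that the new check turns away (not part of this commit; the function name and layouts are invented): the 1-D source carries a plain #xegpu.layout rather than a #xegpu.slice, so, assuming the pattern reaches the new check, it returns failure() and leaves the op for other patterns.

  // Hypothetical: the source of the expanding shape_cast has a plain
  // #xegpu.layout (not a #xegpu.slice), so WgToSgVectorShapeCastOp is
  // expected to bail out instead of distributing the op.
  gpu.func @vector_shape_cast_no_slice() {
    %step = vector.step {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [4]>} : vector<128xindex>
    %cast = vector.shape_cast %step {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>} : vector<128xindex> to vector<1x1x1x128xindex>
    gpu.return
  }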

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 13 additions & 16 deletions
@@ -408,23 +408,20 @@ gpu.module @test_distribution {
   }
 
   // CHECK-LABEL: vector_shape_cast
-  gpu.func @vector_shape_cast(%src: memref<256x128xf32>) {
-    %tdesc = xegpu.create_nd_tdesc %src : memref<256x128xf32>
-      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-    %load = xegpu.load_nd %tdesc[0, 0]
-      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [1, 16], lane_data = [1, 1]>>
-      -> vector<256x128xf32>
-    //CHECK: vector.shape_cast {{.*}} : vector<32x32xf32> to vector<32x1x32x1xf32>
-    %cast = vector.shape_cast %load {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 4, 1], sg_data = [32, 1, 32, 1]>} : vector<256x128xf32> to vector<256x1x128x1xf32>
+  gpu.func @vector_shape_cast() {
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} dense<10> : vector<128xindex>
+    %step = vector.step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
+    %muli = arith.muli %cst, %step {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>, dims = [0, 1, 2]>} : vector<128xindex>
+    //CHECK: vector.shape_cast {{.*}} : vector<32xindex> to vector<1x1x1x32xindex>
+    %shape_cast = vector.shape_cast %muli {layout_result_0 = #xegpu.layout<sg_layout = [8, 1, 1, 4], sg_data = [1, 1, 1, 32]>} : vector<128xindex> to vector<1x1x1x128xindex>
     gpu.return
   }
 
-  // CHECK-LABEL: broadcast
-  // CHECK-SAME: %[[ARG_0:.*]]: index, %[[ARG_1:.*]]: index
-  gpu.func @broadcast(%arg0: index, %arg1: index) {
-    %muli = arith.muli %arg0, %arg1 : index
-    // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
-    %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
-    gpu.return
-  }
+  // CHECK-LABEL: vector_broadcast
+  gpu.func @vector_broadcast(%arg0: index, %arg1: index) {
+    %muli = arith.muli %arg0, %arg1 : index
+    // CHECK: vector.broadcast {{.*}} : index to vector<1x1x1x32xindex>
+    %broadcast = vector.broadcast %muli {layout_result_0 = #xegpu.layout<sg_layout = [4, 2, 6, 1], sg_data = [1, 1, 1, 32]>} : index to vector<4x2x6x32xindex>
+    gpu.return
+  }
 }
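Usage note: this test file is normally driven by a FileCheck RUN line along the following lines; the pass option is assumed from the pass name (XeGPUWgToSgDistribute) and the exact flags are not shown in this diff.

  // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s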
