Skip to content

Commit 85e0c4e

Browse files
committed
handle simple cases
1 parent f0e270d commit 85e0c4e

File tree

2 files changed

+65
-40
lines changed

2 files changed

+65
-40
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1510,24 +1510,8 @@ struct VectorExtractStridedSliceDistribution
15101510
return rewriter.notifyMatchFailure(
15111511
warpOp, "Expecting source to be distributed in a single dimension.");
15121512
int64_t distributedDim = distributedDims[0];
1513-
// Check if the distributed dimension is fully extracted. If so, we exit
1514-
// early because this case is already handled by vector distribution patterns.
1515-
// Distributed dimension is fully extracted if:
1516-
// 1) Distributed dim comes after all the extracted dimensions.
1517-
// 2) Or, the size extracted along the distributed dimension is equal to the
1518-
// size of that dim in source vector.
1519-
auto extractedSizes = extractOp.getSizes();
1520-
if (distributedDim >= static_cast<int64_t>(extractedSizes.size()))
1521-
return rewriter.notifyMatchFailure(
1522-
warpOp, "Distributed dimension is fully extracted, skipping.");
1523-
1524-
int distrDimExtractedSize =
1525-
cast<IntegerAttr>(extractOp.getSizes()[distributedDim]).getInt();
15261513
int sourceDistrDimSize =
15271514
extractOp.getSourceVectorType().getShape()[distributedDim];
1528-
if (distrDimExtractedSize == sourceDistrDimSize)
1529-
return rewriter.notifyMatchFailure(
1530-
warpOp, "Distributed dimension is fully extracted, skipping.");
15311515

15321516
auto sourceLayout =
15331517
xegpu::getDistributeLayoutAttr(extractOp->getOpOperand(0));
@@ -1635,14 +1619,7 @@ struct VectorInsertStridedSliceDistribution
16351619
return rewriter.notifyMatchFailure(
16361620
insertOp, "distributed dimension must be in the last k (i.e. source "
16371621
"rank) dims of dest vector");
1638-
// If the distributed dimension is fully inserted, skip. This case is
1639-
// already handled by vector distribution patterns.
1640-
int64_t destDistrDimSize = destType.getDimSize(destDistributedDim);
16411622
int64_t srcDistrDimSize = srcType.getDimSize(sourceDistributedDim);
1642-
if (srcDistrDimSize == destDistrDimSize)
1643-
return rewriter.notifyMatchFailure(
1644-
insertOp, "distributed dimension is fully inserted. This case "
1645-
"is handled by vector distribution.");
16461623
// Obtain the source and dest layouts.
16471624
auto destLayout = xegpu::getDistributeLayoutAttr(insertOp->getOpOperand(1));
16481625
auto sourceLayout =

mlir/test/Dialect/XeGPU/subgroup-distribute-unit.mlir

Lines changed: 65 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -607,25 +607,25 @@ gpu.func @vector_shapecast_unsupported(%laneid: index) {
607607
}
608608

609609

610-
// CHECK-LABEL: gpu.func @vector_extract_strided_slice_outer_distributed
611-
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) {
612-
// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<32x16xf32>
613-
// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32>
614-
// CHECK: }
615-
// CHECK-NEXT: %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
616-
// CHECK-NEXT: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32>
617-
// CHECK-NEXT: "some_use"(%[[T2]]) : (vector<1x16xf32>) -> ()
618-
gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) {
619-
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) {
620-
%0 = "some_def"() : () -> (vector<32x16xf32>)
621-
%1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1],
622-
layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
623-
layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
610+
// CHECK-LABEL: gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted
611+
// CHECK-NEXT: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<8x1xf32>, vector<24x1xf32>) {
612+
// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<24x16xf32>
613+
// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<8x16xf32>, vector<24x16xf32>
614+
// CHECK-NEXT: }
615+
// CHECK-NEXT: %[[T1:.*]] = vector.extract_strided_slice %[[W]]#1
616+
// CHECK-SAME: {offsets = [8, 0], sizes = [8, 1], strides = [1, 1]} : vector<24x1xf32> to vector<8x1xf32>
617+
// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<8x1xf32>) -> ()
618+
gpu.func @vector_extract_strided_slice_distributed_dim_fully_extracted(%laneid: index) {
619+
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<8x1xf32>) {
620+
%0 = "some_def"() : () -> (vector<24x16xf32>)
621+
%1 = vector.extract_strided_slice %0 { offsets = [8, 0], sizes = [8, 16], strides = [1, 1],
622+
layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
623+
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
624624
}
625-
: vector<32x16xf32> to vector<16x16xf32>
626-
gpu.yield %1 : vector<16x16xf32>
625+
: vector<24x16xf32> to vector<8x16xf32>
626+
gpu.yield %1 : vector<8x16xf32>
627627
}
628-
"some_use"(%r) : (vector<1x16xf32>) -> ()
628+
"some_use"(%r) : (vector<8x1xf32>) -> ()
629629
gpu.return
630630
}
631631

@@ -651,6 +651,28 @@ gpu.func @vector_extract_strided_slice_inner_distributed(%laneid: index) {
651651
gpu.return
652652
}
653653

654+
// CHECK-LABEL: gpu.func @vector_extract_strided_slice_outer_distributed
655+
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<1x16xf32>, vector<2x16xf32>) {
656+
// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<32x16xf32>
657+
// CHECK: gpu.yield %{{.*}}, %[[S]] : vector<16x16xf32>, vector<32x16xf32>
658+
// CHECK: }
659+
// CHECK-NEXT: %[[T1:.*]] = vector.extract %[[W]]#1[1] : vector<16xf32> from vector<2x16xf32>
660+
// CHECK-NEXT: %[[T2:.*]] = vector.shape_cast %[[T1]] : vector<16xf32> to vector<1x16xf32>
661+
// CHECK-NEXT: "some_use"(%[[T2]]) : (vector<1x16xf32>) -> ()
662+
gpu.func @vector_extract_strided_slice_outer_distributed(%laneid: index) {
663+
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<1x16xf32>) {
664+
%0 = "some_def"() : () -> (vector<32x16xf32>)
665+
%1 = vector.extract_strided_slice %0 { offsets = [16], sizes = [16], strides = [1],
666+
layout_operand_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>,
667+
layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>
668+
}
669+
: vector<32x16xf32> to vector<16x16xf32>
670+
gpu.yield %1 : vector<16x16xf32>
671+
}
672+
"some_use"(%r) : (vector<1x16xf32>) -> ()
673+
gpu.return
674+
}
675+
654676
// CHECK-LABEL: gpu.func @vector_extract_strided_slice_1d
655677
// CHECK: %[[W:.*]]:2 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<2xf32>, vector<4xf32>) {
656678
// CHECK: %[[S:.*]] = "some_def"() : () -> vector<64xf32>
@@ -709,6 +731,32 @@ gpu.func @vector_extract_strided_slice_unsopported_source(%laneid: index) {
709731
gpu.return
710732
}
711733

734+
735+
// CHECK-LABEL: gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted
736+
// CHECK-NEXT: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x1xf32>, vector<16x1xf32>, vector<64x1xf32>) {
737+
// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>
738+
// CHECK-NEXT: %[[D:.*]] = "some_def"() : () -> vector<64x16xf32>
739+
// CHECK: gpu.yield %{{.*}}, %[[S]], %[[D]] : vector<64x16xf32>, vector<16x16xf32>, vector<64x16xf32>
740+
// CHECK-NEXT: }
741+
// CHECK-NEXT: %[[T1:.*]] = vector.insert_strided_slice %[[W]]#1, %[[W]]#2
742+
// CHECK-SAME: {offsets = [24, 0], strides = [1, 1]} : vector<16x1xf32> into vector<64x1xf32>
743+
// CHECK-NEXT: "some_use"(%[[T1]]) : (vector<64x1xf32>) -> ()
744+
gpu.func @vector_insert_strided_slice_distributed_dim_fully_inserted(%laneid: index) {
745+
%r = gpu.warp_execute_on_lane_0(%laneid)[16] -> (vector<64x1xf32>) {
746+
%0 = "some_def"() : () -> (vector<16x16xf32>)
747+
%1 = "some_def"() : () -> (vector<64x16xf32>)
748+
%2 = vector.insert_strided_slice %0, %1 { offsets = [24, 0], strides = [1, 1],
749+
layout_operand_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
750+
layout_operand_1 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>,
751+
layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>
752+
}
753+
: vector<16x16xf32> into vector<64x16xf32>
754+
gpu.yield %2 : vector<64x16xf32>
755+
}
756+
"some_use"(%r) : (vector<64x1xf32>) -> ()
757+
gpu.return
758+
}
759+
712760
// CHECK-LABEL: gpu.func @vector_insert_strided_slice_inner_distributed
713761
// CHECK: %[[W:.*]]:3 = gpu.warp_execute_on_lane_0(%{{.*}})[16] -> (vector<64x2xf32>, vector<16x1xf32>, vector<64x2xf32>) {
714762
// CHECK-NEXT: %[[S:.*]] = "some_def"() : () -> vector<16x16xf32>

0 commit comments

Comments (0)