Commit 2da2c6d

save work

1 parent 5a683b4 commit 2da2c6d

2 files changed: +88 -11 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 10 additions & 9 deletions
@@ -187,8 +187,8 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
   // Check if the permutation is valid.
   llvm::SmallSet<int64_t, 4> seen(permutation.begin(), permutation.end());
   bool hasDuplicates = seen.size() != permutation.size();
-  bool withinRange = llvm::all_of(permutation, [&](size_t idx) {
-    return idx >= 0 && idx < permutation.size();
+  bool withinRange = llvm::all_of(permutation, [&](int64_t idx) {
+    return idx >= 0 && idx < static_cast<int64_t>(permutation.size());
   });

   if (!withinRange || hasDuplicates) {
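
For context on the hunk above: the lambda parameter changes from `size_t` to `int64_t` because, with an unsigned parameter, `idx >= 0` is always true and a negative permutation entry would wrap around to a large unsigned value instead of being rejected, and comparing it against `permutation.size()` also mixed signedness. A minimal, self-contained sketch (illustration only, not part of the patch) of the signed check the new code performs:

#include <cstdint>
#include <cstdio>
#include <vector>

// Signed range check mirroring the fixed lambda: a negative entry such as -1
// is rejected rather than wrapping around to a huge unsigned index.
static bool isValidPermutationEntry(int64_t idx,
                                    const std::vector<int64_t> &perm) {
  return idx >= 0 && idx < static_cast<int64_t>(perm.size());
}

int main() {
  std::vector<int64_t> perm = {1, 0, -1}; // -1 is deliberately invalid.
  for (int64_t idx : perm)
    std::printf("idx=%lld valid=%d\n", static_cast<long long>(idx),
                isValidPermutationEntry(idx, perm));
  return 0;
}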
@@ -577,7 +577,7 @@ void LayoutInfoPropagation::visitShapeCastOp(
   int sourceDistributedDim =
       sourceShape[0] % xegpu::targetinfo::subgroupSize == 0
           ? 0
-          : (sourceShape[1] % xegpu::targetinfo::subgroupSize ? 1 : -1);
+          : (sourceShape[1] % xegpu::targetinfo::subgroupSize == 0 ? 1 : -1);
   if (sourceDistributedDim == -1) {
     shapeCast.emitWarning(
         "Source vector can not be evenly distributed across lanes.");
@@ -597,16 +597,17 @@ void LayoutInfoPropagation::visitShapeCastOp(
   // [subgroupSize][1]. Otherwise, data is shared accross lanes (broadcasted).
   // We use slice attribute for the broadcast case.
   int64_t distributedDim = resultLaneLayout[0] == 1 ? 1 : 0;
-  xegpu::LayoutAttr plainLayout = xegpu::LayoutAttr::get(
-      shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
   if (resultShape[distributedDim] % xegpu::targetinfo::subgroupSize != 0) {
+    xegpu::LayoutAttr parentLayout = xegpu::LayoutAttr::get(
+        shapeCast->getContext(), resultLaneLayout, resultLayout.getLaneData());
     xegpu::SliceAttr sliceLayout = xegpu::SliceAttr::get(
-        shapeCast->getContext(), plainLayout,
+        shapeCast->getContext(), parentLayout,
         DenseI64ArrayAttr::get(shapeCast->getContext(), {distributedDim}));
     propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(sliceLayout)));
     return;
   }
-  propagateIfChanged(operands[0], operands[0]->meet(LayoutInfo(plainLayout)));
+  propagateIfChanged(operands[0], operands[0]->meet(getDefaultSIMTLayoutInfo(
+                                      shapeCast.getSourceVectorType())));
 }

 /// Propagate the layout of the result tensor to the source tensor descriptor
@@ -711,9 +712,9 @@ void LayoutInfoPropagation::visitVectorBitcastOp(
   bool isNarrowing = inElemTyBitWidth > outElemTyBitWidth;
   int bitCastRatio = isNarrowing ? inElemTyBitWidth / outElemTyBitWidth
                                  : outElemTyBitWidth / inElemTyBitWidth;
-  ArrayRef<int> sourceLaneLayout =
+  SmallVector<int> sourceLaneLayout =
       resultLayout.getLaneLayout(); // Lane layout does not change for bitcast.
-  ArrayRef<int> outData = resultLayout.getLaneData();
+  SmallVector<int> outData = resultLayout.getLaneData();

   // TODO: Currently we assume that bitcasts does not require cross lane
   // communication. So each lane must own the required number of elements to
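
The last hunk replaces the `ArrayRef<int>` views with owning `SmallVector<int>` copies. A plausible motivation, not stated in the commit, is that binding a non-owning view to a container returned by value leaves the view dangling once the temporary is destroyed. A minimal sketch of that hazard using standard C++ types (the getter name is hypothetical):

#include <cstdio>
#include <span>
#include <vector>

// Hypothetical getter that materializes a fresh vector on every call, loosely
// analogous to a layout query returning its lane layout by value.
static std::vector<int> getLaneLayout() { return {1, 16}; }

int main() {
  // Hazard: a non-owning view of the temporary would dangle after this
  // statement, so reading through it is undefined behavior:
  //   std::span<const int> view = getLaneLayout();

  // Safe: keep an owning copy, as the patch does with SmallVector.
  std::vector<int> owned = getLaneLayout();
  std::span<const int> view(owned);
  std::printf("%d %d\n", view[0], view[1]);
  return 0;
}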

mlir/test/Dialect/XeGPU/propagate-layout.mlir

Lines changed: 78 additions & 2 deletions
@@ -455,7 +455,7 @@ func.func @prefetch_1d(%arg0: memref<256xf16>){
 }

 // -----
-// CHECK-LABEL: func.func @test_scf_while_and_condition(
+// CHECK-LABEL: func.func @scf_while_and_condition(
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf32>, %[[ARG1:[0-9a-zA-Z]+]]: memref<256xf32>) {
 // CHECK: %{{.*}}:3 = scf.while ({{.*}}) : (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>)
 // CHECK-SAME: -> (vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
@@ -464,7 +464,7 @@ func.func @prefetch_1d(%arg0: memref<256xf16>){
 // CHECK-NEXT: ^bb0(%{{.*}}: vector<16xf32>, %{{.*}}: i32, %{{.*}}: !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>):
 // CHECK: scf.yield {{.*}} : vector<16xf32>, i32, !xegpu.tensor_desc<16xf32, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
 // CHECK-NEXT: } attributes {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
-func.func @test_scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<256xf32>) {
+func.func @scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<256xf32>) {
   %c0 = arith.constant 0 : i32
   %c16 = arith.constant 16 : i32
   %c256 = arith.constant 256 : i32
@@ -486,3 +486,79 @@ func.func @test_scf_while_and_condition(%arg0: memref<256xf32>, %arg1: memref<25
   }
   return
 }
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_2d_to_1d_dim0_distributed(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x1xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]]
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>} :
+// CHECK-SAME: !xegpu.tensor_desc<16x1xf16, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 1]>> -> vector<16x1xf16>
+// CHECK-NEXT: %{{.*}} = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK-SAME: : vector<16x1xf16> to vector<16xf16>
+func.func @vector_shape_cast_2d_to_1d_dim0_distributed(%arg0: !xegpu.tensor_desc<16x1xf16>, %arg1: !xegpu.tensor_desc<16xf16>) {
+  %c0 = arith.constant 0 : index
+  %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x1xf16> -> vector<16x1xf16>
+  %2 = vector.shape_cast %3 : vector<16x1xf16> to vector<16xf16>
+  xegpu.store_nd %2, %arg1 : vector<16xf16>, !xegpu.tensor_desc<16xf16>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_2d_to_1d_dim1_distributed(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<1x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16xf16, #xegpu.layout<lane_layout = [16], lane_data = [1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<1x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<1x16xf16>
+// CHECK: %{{.*}} = vector.shape_cast %[[LOAD]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>}
+// CHECK-SAME: vector<1x16xf16> to vector<16xf16>
+func.func @vector_shape_cast_2d_to_1d_dim1_distributed(%arg0: !xegpu.tensor_desc<1x16xf16>, %arg1: !xegpu.tensor_desc<16xf16>) {
+  %c0 = arith.constant 0 : index
+  %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<1x16xf16> -> vector<1x16xf16>
+  %2 = vector.shape_cast %3 : vector<1x16xf16> to vector<16xf16>
+  xegpu.store_nd %2, %arg1 : vector<16xf16>, !xegpu.tensor_desc<16xf16>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim1_distributed(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[ARG0]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} [0] : vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: vector<16xf16> to vector<1x16xf16>
+func.func @vector_shape_cast_1d_to_2d_dim1_distributed(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0000> : vector<16xf16>
+  %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.multi_reduction <add>, %3, %cst [0] : vector<16x16xf16> to vector<16xf16>
+  %2 = vector.shape_cast %4 : vector<16xf16> to vector<1x16xf16>
+  %5 = vector.broadcast %2 : vector<1x16xf16> to vector<16x16xf16>
+  xegpu.store_nd %5, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
+
+// -----
+// CHECK-LABEL: func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(
+// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>,
+// CHECK-SAME: %[[ARG1:[0-9a-zA-Z]+]]: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>) {
+// CHECK: %[[LOAD:.*]] = xegpu.load_nd %arg0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
+// CHECK-NEXT: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %{{[0-9a-zA-Z]+}}
+// CHECK-SAME: {layout_result_0 = #xegpu.slice<#xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>, dims = [1]>} [1]
+// CHECK-SAME: vector<16x16xf16> to vector<16xf16>
+// CHECK-NEXT: %[[CAST:.*]] = vector.shape_cast %[[REDUCE]] {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>}
+// CHECK-SAME: vector<16xf16> to vector<16x1xf16>
+func.func @vector_shape_cast_1d_to_2d_dim0_broadcasted(%arg0: !xegpu.tensor_desc<16x16xf16>, %arg1: !xegpu.tensor_desc<16x16xf16>) {
+  %c0 = arith.constant 0 : index
+  %cst = arith.constant dense<0.0000> : vector<16xf16>
+  %3 = xegpu.load_nd %arg0 : !xegpu.tensor_desc<16x16xf16> -> vector<16x16xf16>
+  %4 = vector.multi_reduction <add>, %3, %cst [1] : vector<16x16xf16> to vector<16xf16>
+  %2 = vector.shape_cast %4 : vector<16xf16> to vector<16x1xf16>
+  %5 = vector.broadcast %2 : vector<16x1xf16> to vector<16x16xf16>
+  xegpu.store_nd %5, %arg1 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16>
+  return
+}
