
Commit 0d0047d

support permanent layout attr in 'get/setDistributeLayout' helpers
Signed-off-by: dchigarev <[email protected]>
1 parent b58eb10 commit 0d0047d

File tree: 3 files changed, +31 -9 lines

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 24 additions & 2 deletions
@@ -105,12 +105,22 @@ mlir::xegpu::getDistributedVectorType(VectorType originalType,
 std::string xegpu::getLayoutName(const OpOperand &operand) {
   const StringRef prefix("layout_operand_");
   unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
-  return llvm::formatv("{0}{1}", prefix, idx).str();
+  auto owner = operand.getOwner();
+  auto tempLayout = llvm::formatv("{0}{1}", prefix, idx).str();
+  if (isa<StoreScatterOp>(operand.getOwner()) && idx == 0 &&
+      !owner->hasAttr(tempLayout))
+    return "layout";
+  return tempLayout;
 }
 
 std::string xegpu::getLayoutName(const OpResult result) {
   const StringRef prefix = "layout_result_";
-  return llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
+  auto owner = result.getOwner();
+  auto tempLayout =
+      llvm::formatv("{0}{1}", prefix, result.getResultNumber()).str();
+  if (isa<LoadGatherOp>(owner) && !owner->hasAttr(tempLayout))
+    return "layout";
+  return tempLayout;
 }
 
 xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
@@ -144,6 +154,11 @@ xegpu::DistributeLayoutAttr xegpu::getDistributeLayoutAttr(const Value value) {
     std::string layoutName = getLayoutName(result);
     if (defOp->hasAttr(layoutName))
       return defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+    // check for "permanent" layout only after "temporary" layout name lookup
+    // for backward compatibility
+    if (auto loadGatherOp = dyn_cast<xegpu::LoadGatherOp>(defOp))
+      return loadGatherOp.getLayoutAttr();
   }
 
   if (auto arg = dyn_cast<BlockArgument>(value)) {
@@ -171,6 +186,13 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
   std::string layoutName = xegpu::getLayoutName(opr);
   if (op->hasAttr(layoutName))
    return op->getAttrOfType<xegpu::DistributeLayoutAttr>(layoutName);
+
+  // check for "permanent" layout only after "temporary" layout name lookup
+  // for backward compatibility
+  if (auto storeScatterOp = dyn_cast<xegpu::StoreScatterOp>(op))
+    if (auto layout = storeScatterOp.getLayoutAttr())
+      return layout;
+
   return getDistributeLayoutAttr(opr.get());
 }
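For orientation, the result-side lookup these hunks introduce can be summarized in one place. This is an illustrative sketch distilled from the diff above, not code from the commit; the function name resolveResultLayout is hypothetical.

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "llvm/Support/FormatVariadic.h"

using namespace mlir;

// Hypothetical one-function summary of the new precedence: the temporary
// discardable attribute ("layout_result_N") still wins when present; the
// inherent "layout" attribute of xegpu.load (LoadGatherOp) is consulted
// only as a fallback, which preserves backward compatibility.
static xegpu::DistributeLayoutAttr resolveResultLayout(OpResult result) {
  Operation *defOp = result.getOwner();
  std::string tempName =
      llvm::formatv("layout_result_{0}", result.getResultNumber()).str();
  if (auto attr = defOp->getAttrOfType<xegpu::DistributeLayoutAttr>(tempName))
    return attr; // temporary name takes precedence
  if (auto load = dyn_cast<xegpu::LoadGatherOp>(defOp))
    return load.getLayoutAttr(); // permanent, inherent attribute
  return {};
}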

mlir/test/Dialect/XeGPU/propagate-layout.mlir

Lines changed: 5 additions & 5 deletions
@@ -97,7 +97,7 @@ func.func @extf_truncf(%arg0: !xegpu.tensor_desc<8x16xf16>, %arg1: !xegpu.tensor
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T2:.*]] = xegpu.create_tdesc %[[ARG1]], %[[CST]] : memref<256xf16>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T2]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}>
 // CHECK-SAME: !xegpu.tensor_desc<16x16xf16, #xegpu.scatter_tdesc_attr<chunk_size = 16 : i64>, #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>>, vector<16xi1> -> vector<16x16xf16>
 func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256xf16>, %arg2: memref<8x16xf32>) {
   %c0 = arith.constant 0 : index
@@ -122,7 +122,7 @@ func.func @load_gather_with_chunksize(%arg0: memref<8x16xf16>, %arg1: memref<256
 // CHECK-NEXT: %[[CST0:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK-NEXT: %[[T0:.*]] = xegpu.create_tdesc %[[ARG0]], %[[CST]] : memref<256xf32>, vector<16xindex> ->
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>
-// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} :
+// CHECK-NEXT: %{{.*}} = xegpu.load %[[T0]], %[[CST0]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> :
 // CHECK-SAME: !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>, #xegpu.layout<lane_layout = [16], lane_data = [1]>>, vector<16xi1> -> vector<16xf32>
 func.func @load_gather_1d(%arg0: memref<256xf32>, %arg1: !xegpu.tensor_desc<16xf32>) {
   %cst = arith.constant dense<[0, 16, 32, 48, 64, 80, 96, 112, 128, 144, 160, 176, 192, 208, 224, 240]> : vector<16xindex>
@@ -167,8 +167,8 @@ func.func @store_scatter_1d(%arg0: vector<16xf32>, %arg1: memref<256xf32>) {
 // CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
-// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}>
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
+// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64,
+// CHECK-SAME: layout = #xegpu.layout<lane_layout = [16, 1], lane_data = [1, 2]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16x8xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{chunk_size = 8 : i64}> : vector<16x8xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
@@ -186,7 +186,7 @@ func.func @scatter_ops_chunksize(%src: memref<256xf16>) {
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
 // CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
 // CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
-// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
+// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
 // CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]] : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
 func.func @scatter_ops(%src: memref<256xf16>) {
   %1 = arith.constant dense<1>: vector<16xi1>
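The CHECK-line churn above is a printing change, not a semantic one: an inherent (permanent) attribute prints inside the op's <{...}> property dictionary, while a discardable (temporary) one prints in the trailing {...} attr-dict. A minimal sketch of the two forms, assuming the ODS-generated setLayoutAttr setter exists on LoadGatherOp (the diff only shows the getter):

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

// Sketch only: the same layout attached two ways, and how each prints.
void attachLayout(mlir::xegpu::LoadGatherOp loadOp,
                  mlir::xegpu::DistributeLayoutAttr layout) {
  // Permanent form, matched by the new CHECK lines:
  //   xegpu.load ... <{layout = #xegpu.layout<...>}> ...
  loadOp.setLayoutAttr(layout); // assumed ODS-generated setter
  // Temporary form, matched by the old CHECK lines:
  //   xegpu.load ... {layout_result_0 = #xegpu.layout<...>} ...
  loadOp->setAttr("layout_result_0", layout);
}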

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 2 additions & 2 deletions
@@ -285,8 +285,8 @@ gpu.module @test_distribution {
 // CHECK: %[[VAL:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<2.550000e+01> : vector<8xf16>
 // CHECK: %[[CST:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<0> : vector<8xindex>
 // CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<inst_data = [8]>} dense<true> : vector<8xi1>
-// CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>}>
-// CHECK-SAME: {layout_operand_0 = #xegpu.layout<inst_data = [8]>, layout_operand_2 = #xegpu.layout<inst_data = [8]>,
+// CHECK: xegpu.store %[[VAL]], %[[ARG0]][%[[CST]]], %[[MASK]] <{chunk_size = 1 : i64, l1_hint = #xegpu.cache_hint<cached>, layout = #xegpu.layout<inst_data = [8]>}>
+// CHECK-SAME: {layout_operand_2 = #xegpu.layout<inst_data = [8]>,
 // CHECK-SAME: layout_operand_3 = #xegpu.layout<inst_data = [8]>}
 // CHECK-SAME: : vector<8xf16>, memref<256xf16>, vector<8xindex>, vector<8xi1>
 %val = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [32], sg_data = [8], inst_data = [8]>} dense<25.5> : vector<256xf16>
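Here the stored value's layout (operand 0) moves into the inherent layout attribute, so the trailing attr-dict keeps only layout_operand_2 and layout_operand_3 (offsets and mask). A small illustrative helper showing which name each operand now resolves to, assuming the op carries no temporary layout_operand_N attributes; dumpOperandLayoutNames is hypothetical:

#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "llvm/Support/raw_ostream.h"

// Illustrative only: after this commit, operand 0 of a StoreScatterOp with
// no temporary attributes resolves to the permanent "layout" name, while
// the remaining operands keep their "layout_operand_N" temporary names.
void dumpOperandLayoutNames(mlir::xegpu::StoreScatterOp storeOp) {
  for (mlir::OpOperand &opr : storeOp->getOpOperands())
    llvm::outs() << opr.getOperandNumber() << " -> "
                 << mlir::xegpu::getLayoutName(opr) << "\n";
}
// Expected: 0 -> layout, 1 -> layout_operand_1, 2 -> layout_operand_2, ...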
