Skip to content

Commit 027557d

Browse files
committed
respect permanent layout in 'setDistributeLayoutAttr'
Signed-off-by: dchigarev <[email protected]>
1 parent 15f6907 commit 027557d

File tree

4 files changed

+79
-9
lines changed

4 files changed

+79
-9
lines changed

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,15 @@ void removeLayoutAttrs(Operation *op);
104104

105105
/// Sets the DistributeLayoutAttr for a given OpOperand or OpResult by attaching
106106
/// it to the owner's dictionary attributes
107+
/// If `respectPermLayout` is true, the existing permanent layout
108+
/// attribute will be kept and assigned to the attribute dict instead
109+
/// of the provided layout.
107110
template <typename T,
108111
typename = std::enable_if_t<std::is_same_v<T, OpOperand> ||
109112
std::is_same_v<T, OpResult>>>
110113
void setDistributeLayoutAttr(const T &operandOrResult,
111-
const DistributeLayoutAttr layout);
114+
const DistributeLayoutAttr layout,
115+
bool respectPermLayout = false);
112116

113117
/// Set the DistributeLayoutAttr for each OpOperand and OpResult of the given
114118
/// operation. If the operation contains regions, it is also applied recursively

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1048,7 +1048,7 @@ static LogicalResult updateOp(mlir::OpBuilder &builder, mlir::Operation *op,
10481048
}
10491049
// If the result is a vector type, add a temporary layout attribute to the
10501050
// op.
1051-
xegpu::setDistributeLayoutAttr(result, layout);
1051+
xegpu::setDistributeLayoutAttr(result, layout, /*respectPermLayout*/ true);
10521052
}
10531053
return success();
10541054
}

mlir/lib/Dialect/XeGPU/Utils/XeGPUUtils.cpp

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -185,24 +185,68 @@ xegpu::getDistributeLayoutAttr(const OpOperand &opr) {
185185
return getDistributeLayoutAttr(opr.get());
186186
}
187187

188+
xegpu::DistributeLayoutAttr
189+
maybePickPermamentLayout(xegpu::DistributeLayoutAttr layout,
190+
const OpResult &result, bool respectPermLayout,
191+
mlir::Operation *owner, const std::string &name) {
192+
if (!respectPermLayout)
193+
return layout;
194+
xegpu::DistributeLayoutAttr candidate = layout;
195+
196+
if (auto loadOp = dyn_cast<xegpu::LoadGatherOp>(owner)) {
197+
if (auto perm = loadOp.getLayoutAttr())
198+
candidate = perm;
199+
}
200+
201+
return candidate;
202+
}
203+
204+
xegpu::DistributeLayoutAttr
205+
maybePickPermamentLayout(xegpu::DistributeLayoutAttr layout,
206+
const OpOperand &operand, bool respectPermLayout,
207+
mlir::Operation *owner, const std::string &name) {
208+
if (!respectPermLayout)
209+
return layout;
210+
211+
xegpu::DistributeLayoutAttr candidate = layout;
212+
unsigned idx = const_cast<OpOperand &>(operand).getOperandNumber();
213+
214+
if (auto storeOp = dyn_cast<xegpu::StoreScatterOp>(owner)) {
215+
if (idx == 0) {
216+
if (auto perm = storeOp.getLayoutAttr())
217+
candidate = perm;
218+
}
219+
}
220+
221+
return candidate;
222+
}
223+
188224
template <typename T, typename>
189225
void xegpu::setDistributeLayoutAttr(const T &operandOrResult,
190-
const DistributeLayoutAttr layout) {
226+
const DistributeLayoutAttr layout,
227+
bool respectPermLayout) {
191228
Operation *owner = operandOrResult.getOwner();
192229
std::string name = xegpu::getLayoutName(operandOrResult);
193-
if (layout && !owner->hasAttrOfType<DistributeLayoutAttr>(name))
194-
owner->setAttr(name, layout);
230+
231+
if (owner->hasAttrOfType<DistributeLayoutAttr>(name))
232+
return;
233+
234+
auto candidate = maybePickPermamentLayout(layout, operandOrResult,
235+
respectPermLayout, owner, name);
236+
237+
if (candidate)
238+
owner->setAttr(name, candidate);
195239
}
196240

197241
// Explicit instantiation for OpResult
198242
template void xegpu::setDistributeLayoutAttr<mlir::OpResult>(
199243
const mlir::OpResult &result,
200-
const mlir::xegpu::DistributeLayoutAttr layout);
244+
const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);
201245

202246
// Explicit instantiation for OpOperand
203247
template void xegpu::setDistributeLayoutAttr<mlir::OpOperand>(
204248
const mlir::OpOperand &operand,
205-
const mlir::xegpu::DistributeLayoutAttr layout);
249+
const mlir::xegpu::DistributeLayoutAttr layout, bool respectPermLayout);
206250

207251
void xegpu::setDistributeLayoutAttrs(
208252
Operation *op, function_ref<DistributeLayoutAttr(Value)> getLayoutImpl) {

mlir/test/Dialect/XeGPU/propagate-layout.mlir

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -221,13 +221,35 @@ gpu.module @test {
221221
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
222222
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
223223
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
224-
// CHECK: xegpu.store %[[LOAD_VEC]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
224+
// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
225+
// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
225226
// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
226227
func.func @scatter_ops_custom_perm_layout(%src: memref<256xf16>) {
227228
%1 = arith.constant dense<1>: vector<16xi1>
228229
%offset = arith.constant dense<12> : vector<16xindex>
229230
%3 = xegpu.load %src[%offset], %1 : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
230-
xegpu.store %3, %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
231+
%4 = arith.addf %3, %3 : vector<16xf16>
232+
xegpu.store %4, %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
233+
return
234+
}
235+
}
236+
// -----
237+
gpu.module @test {
238+
// CHECK-LABEL: func.func @scatter_ops_preserve_load_perm_layout(
239+
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z]+]]: memref<256xf16>) {
240+
// CHECK: %[[MASK:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<true> : vector<16xi1>
241+
// CHECK: %[[OFFSETS:.*]] = arith.constant {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} dense<12> : vector<16xindex>
242+
// CHECK: %[[LOAD_VEC:.*]] = xegpu.load %[[ARG0]][%[[OFFSETS]]], %[[MASK]] <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}>
243+
// CHECK-SAME: {layout_result_0 = #xegpu.layout<lane_layout = [16], lane_data = [1]>} : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
244+
// CHECK: %[[ADD_RES:.*]] = arith.addf %[[LOAD_VEC]], %[[LOAD_VEC]] {layout_result_0 = #xegpu.layout<lane_layout = [8], lane_data = [1]>} : vector<16xf16>
245+
// CHECK: xegpu.store %[[ADD_RES]], %[[ARG0]][%[[OFFSETS]]], %[[MASK]]
246+
// CHECK-SAME: <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
247+
func.func @scatter_ops_preserve_load_perm_layout(%src: memref<256xf16>) {
248+
%1 = arith.constant dense<1>: vector<16xi1>
249+
%offset = arith.constant dense<12> : vector<16xindex>
250+
%3 = xegpu.load %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [16], lane_data = [1]>}> : memref<256xf16>, vector<16xindex>, vector<16xi1> -> vector<16xf16>
251+
%4 = arith.addf %3, %3 : vector<16xf16>
252+
xegpu.store %4, %src[%offset], %1 <{layout = #xegpu.layout<lane_layout = [8], lane_data = [1]>}> : vector<16xf16>, memref<256xf16>, vector<16xindex>, vector<16xi1>
231253
return
232254
}
233255
}

0 commit comments

Comments
 (0)