Skip to content

Commit ba98c14

Browse files
committed
Address comments
1 parent f47b1c6 commit ba98c14

File tree

2 files changed

+16
-19
lines changed

2 files changed

+16
-19
lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def XeGPUPropagateLayout : Pass<"xegpu-propagate-layout"> {
4747
Option<
4848
"layoutKind", "layout-kind", "std::string",
4949
/*default=*/"\"lane\"",
50-
"Propagate a `sg` / `inst` / `lane` level of xegpu layouts.">
50+
"Propagate `inst` / `lane` level of xegpu layouts.">
5151
];
5252
}
5353

mlir/lib/Dialect/XeGPU/Transforms/XeGPUPropagateLayout.cpp

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,8 +101,7 @@ struct LayoutInfo {
101101

102102
bool isAssigned() const { return storage != nullptr; }
103103

104-
LayoutInfo transpose(ArrayRef<int64_t> permutation,
105-
LayoutKind layoutKind) const;
104+
LayoutInfo transpose(ArrayRef<int64_t> permutation) const;
106105

107106
SmallVector<int> getLaneLayout() const;
108107

@@ -169,9 +168,9 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
169168
llvm_unreachable("Join should not be triggered by layout propagation.");
170169
}
171170

172-
/// Construct a new layout with the transposed lane layout and lane data.
173-
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation,
174-
LayoutKind layoutKind) const {
171+
/// Construct a new layout with the transposed inst_data or lane_layout,
172+
/// lane_data.
173+
LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation) const {
175174
if (!isAssigned())
176175
return {};
177176
// Check if the permutation is valid.
@@ -190,19 +189,19 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation,
190189
SmallVector<int32_t> laneData;
191190
SmallVector<int32_t> instData;
192191
for (int64_t idx : permutation) {
193-
if (layoutKind == LayoutKind::Lane) {
192+
if (getLaneLayout().size()) {
194193
laneLayout.push_back(static_cast<int32_t>(getLaneLayout()[idx]));
195194
laneData.push_back(static_cast<int32_t>(getLaneData()[idx]));
196-
} else if (layoutKind == LayoutKind::InstData)
195+
}
196+
if (getInstData().size())
197197
instData.push_back(static_cast<int32_t>(getInstData()[idx]));
198198
}
199199
xegpu::LayoutAttr layoutAttr;
200-
if (layoutKind == LayoutKind::Lane) {
200+
if (getLaneLayout().size())
201201
layoutAttr =
202202
xegpu::LayoutAttr::get(storage.getContext(), laneLayout, laneData);
203-
} else if (layoutKind == LayoutKind::InstData) {
203+
if (getInstData().size())
204204
layoutAttr = xegpu::LayoutAttr::get(storage.getContext(), instData);
205-
}
206205
return LayoutInfo(layoutAttr);
207206
}
208207

@@ -748,7 +747,7 @@ void LayoutInfoPropagation::visitLoadNdOp(
748747
if (auto transpose = load.getTranspose()) {
749748
load.emitWarning("Transpose effect is not expected for LoadNdOp at "
750749
"LayoutInfoPropagation stage.");
751-
tensorDescLayout = valueLayout.transpose(transpose.value(), layoutKind);
750+
tensorDescLayout = valueLayout.transpose(transpose.value());
752751
}
753752
// Propagate the new layout to the tensor descriptor operand.
754753
propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
@@ -763,8 +762,7 @@ void LayoutInfoPropagation::visitTransposeOp(
763762
LayoutInfo resultLayout = results[0]->getValue();
764763
if (!resultLayout.isAssigned())
765764
return;
766-
LayoutInfo newLayout =
767-
resultLayout.transpose(transpose.getPermutation(), layoutKind);
765+
LayoutInfo newLayout = resultLayout.transpose(transpose.getPermutation());
768766
// Propagate the new layout to the vector operand.
769767
propagateIfChanged(operands[0], operands[0]->meet(newLayout));
770768
}
@@ -1207,18 +1205,17 @@ struct XeGPUPropagateLayoutPass final
12071205

12081206
void XeGPUPropagateLayoutPass::runOnOperation() {
12091207
LayoutKind layoutKind;
1210-
if (this->layoutKind == "lane")
1208+
if (this->layoutKind == "lane") {
12111209
layoutKind = LayoutKind::Lane;
1212-
else if (this->layoutKind == "inst")
1210+
} else if (this->layoutKind == "inst") {
12131211
layoutKind = LayoutKind::InstData;
1214-
else {
1215-
signalPassFailure();
1212+
} else {
12161213
getOperation()->emitError("Unsupported layout kind option: " +
12171214
this->layoutKind);
1215+
signalPassFailure();
12181216
return;
12191217
}
12201218
RunLayoutInfoPropagation analysis(getOperation(), layoutKind);
1221-
// auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
12221219
// Print the analysis result and exit. (for debugging purposes)
12231220
if (printOnly) {
12241221
auto &os = llvm::outs();

0 commit comments

Comments
 (0)