@@ -101,8 +101,7 @@ struct LayoutInfo {
101101
102102 bool isAssigned () const { return storage != nullptr ; }
103103
104- LayoutInfo transpose (ArrayRef<int64_t > permutation,
105- LayoutKind layoutKind) const ;
104+ LayoutInfo transpose (ArrayRef<int64_t > permutation) const ;
106105
107106 SmallVector<int > getLaneLayout () const ;
108107
@@ -169,9 +168,9 @@ LayoutInfo LayoutInfo::join(const LayoutInfo &lhs, const LayoutInfo &rhs) {
169168 llvm_unreachable (" Join should not be triggered by layout propagation." );
170169}
171170
172- // / Construct a new layout with the transposed lane layout and lane data.
173- LayoutInfo LayoutInfo::transpose (ArrayRef< int64_t > permutation,
174- LayoutKind layoutKind ) const {
171+ // / Construct a new layout with the transposed inst_data or lane_layout,
172+ // / lane_data.
173+ LayoutInfo LayoutInfo::transpose (ArrayRef< int64_t > permutation ) const {
175174 if (!isAssigned ())
176175 return {};
177176 // Check if the permutation is valid.
@@ -190,19 +189,19 @@ LayoutInfo LayoutInfo::transpose(ArrayRef<int64_t> permutation,
190189 SmallVector<int32_t > laneData;
191190 SmallVector<int32_t > instData;
192191 for (int64_t idx : permutation) {
193- if (layoutKind == LayoutKind::Lane ) {
192+ if (getLaneLayout (). size () ) {
194193 laneLayout.push_back (static_cast <int32_t >(getLaneLayout ()[idx]));
195194 laneData.push_back (static_cast <int32_t >(getLaneData ()[idx]));
196- } else if (layoutKind == LayoutKind::InstData)
195+ }
196+ if (getInstData ().size ())
197197 instData.push_back (static_cast <int32_t >(getInstData ()[idx]));
198198 }
199199 xegpu::LayoutAttr layoutAttr;
200- if (layoutKind == LayoutKind::Lane) {
200+ if (getLaneLayout (). size ())
201201 layoutAttr =
202202 xegpu::LayoutAttr::get (storage.getContext (), laneLayout, laneData);
203- } else if (layoutKind == LayoutKind::InstData) {
203+ if (getInstData (). size ())
204204 layoutAttr = xegpu::LayoutAttr::get (storage.getContext (), instData);
205- }
206205 return LayoutInfo (layoutAttr);
207206}
208207
@@ -748,7 +747,7 @@ void LayoutInfoPropagation::visitLoadNdOp(
748747 if (auto transpose = load.getTranspose ()) {
749748 load.emitWarning (" Transpose effect is not expected for LoadNdOp at "
750749 " LayoutInfoPropagation stage." );
751- tensorDescLayout = valueLayout.transpose (transpose.value (), layoutKind );
750+ tensorDescLayout = valueLayout.transpose (transpose.value ());
752751 }
753752 // Propagate the new layout to the tensor descriptor operand.
754753 propagateIfChanged (operands[0 ], operands[0 ]->meet (tensorDescLayout));
@@ -763,8 +762,7 @@ void LayoutInfoPropagation::visitTransposeOp(
763762 LayoutInfo resultLayout = results[0 ]->getValue ();
764763 if (!resultLayout.isAssigned ())
765764 return ;
766- LayoutInfo newLayout =
767- resultLayout.transpose (transpose.getPermutation (), layoutKind);
765+ LayoutInfo newLayout = resultLayout.transpose (transpose.getPermutation ());
768766 // Propagate the new layout to the vector operand.
769767 propagateIfChanged (operands[0 ], operands[0 ]->meet (newLayout));
770768}
@@ -1207,18 +1205,17 @@ struct XeGPUPropagateLayoutPass final
12071205
12081206void XeGPUPropagateLayoutPass::runOnOperation () {
12091207 LayoutKind layoutKind;
1210- if (this ->layoutKind == " lane" )
1208+ if (this ->layoutKind == " lane" ) {
12111209 layoutKind = LayoutKind::Lane;
1212- else if (this ->layoutKind == " inst" )
1210+ } else if (this ->layoutKind == " inst" ) {
12131211 layoutKind = LayoutKind::InstData;
1214- else {
1215- signalPassFailure ();
1212+ } else {
12161213 getOperation ()->emitError (" Unsupported layout kind option: " +
12171214 this ->layoutKind );
1215+ signalPassFailure ();
12181216 return ;
12191217 }
12201218 RunLayoutInfoPropagation analysis (getOperation (), layoutKind);
1221- // auto &analysis = getAnalysis<RunLayoutInfoPropagation>();
12221219 // Print the analysis result and exit. (for debugging purposes)
12231220 if (printOnly) {
12241221 auto &os = llvm::outs ();
0 commit comments