@@ -77,19 +77,8 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
7777 const SmallVector<OpFoldResult> &originalOffsets,
7878 const SmallVector<Value> &localOffset,
7979 const SmallVector<int64_t > &distUnitBaseAddr) const {
80-
81- Value constOffsetX =
82- rewriter.create <arith::ConstantIndexOp>(loc, distUnitBaseAddr[0 ]);
83- Value constOffsetY =
84- rewriter.create <arith::ConstantIndexOp>(loc, distUnitBaseAddr[1 ]);
85-
86- Value offsetX =
87- rewriter.createOrFold <index::AddOp>(loc, localOffset[0 ], constOffsetX);
88- Value offsetY =
89- rewriter.createOrFold <index::AddOp>(loc, localOffset[1 ], constOffsetY);
90-
91- size_t lastDimIndex = originalOffsets.size () - 1 ;
92- size_t secondLastDimIndex = lastDimIndex - 1 ;
80+ assert (localOffset.size () == distUnitBaseAddr.size () &&
81+ " localOffset and distUnitBaseAddr must have the same rank" );
9382
9483 // Convert originalOffsets to Value
9584 auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
@@ -102,18 +91,20 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
10291 llvm_unreachable (" Unsupported OpFoldResult kind" );
10392 };
10493
105- Value origOffsetX =
106- getValueFromOpFoldResult (originalOffsets[secondLastDimIndex]);
107- Value origOffsetY = getValueFromOpFoldResult (originalOffsets[lastDimIndex]);
108- Value globalOffsetX =
109- rewriter.createOrFold <index::AddOp>(loc, origOffsetX, offsetX);
110- Value globalOffsetY =
111- rewriter.createOrFold <index::AddOp>(loc, origOffsetY, offsetY);
112-
11394 SmallVector<OpFoldResult> globalOffsets (originalOffsets.begin (),
11495 originalOffsets.end ());
115- globalOffsets[secondLastDimIndex] = globalOffsetX;
116- globalOffsets[lastDimIndex] = globalOffsetY;
96+ size_t rank = localOffset.size ();
97+ for (size_t i = 0 ; i < rank; ++i) {
98+ size_t dimIdx = originalOffsets.size () - rank + i;
99+ Value constOffset =
100+ rewriter.create <arith::ConstantIndexOp>(loc, distUnitBaseAddr[i]);
101+ Value offset =
102+ rewriter.createOrFold <index::AddOp>(loc, localOffset[i], constOffset);
103+ Value origOffset = getValueFromOpFoldResult (originalOffsets[dimIdx]);
104+ Value globalOffset =
105+ rewriter.createOrFold <index::AddOp>(loc, origOffset, offset);
106+ globalOffsets[dimIdx] = globalOffset;
107+ }
117108
118109 return globalOffsets;
119110 }
@@ -283,7 +274,7 @@ struct WgToSgDpasOp : public OpConversionPattern<xegpu::DpasOp> {
283274 tmpC = rewriter.create <xegpu::DpasOp>(
284275 loc, resTy, operands,
285276 llvm::ArrayRef<NamedAttribute>(
286- {" layout " , originalLayout.dropSgLayoutAndData ()}));
277+ {" layout_result_0 " , originalLayout.dropSgLayoutAndData ()}));
287278 newDpasOps.push_back (tmpC);
288279 }
289280 }
0 commit comments