@@ -70,23 +70,6 @@ namespace {
7070struct WgToSgCreateNdOp : public OpConversionPattern <xegpu::CreateNdDescOp> {
7171 using OpConversionPattern<xegpu::CreateNdDescOp>::OpConversionPattern;
7272
73- // Helper to extract mixed offsets into a Value array
74- SmallVector<Value> extractOffsets (ConversionPatternRewriter &rewriter,
75- xegpu::CreateNdDescOp op) const {
76- llvm::SmallVector<Value> offsets;
77- auto staticOffsets = op.getStaticOffsets ();
78- auto dynamicOffsets = op.getOffsets ();
79-
80- for (size_t i = 0 , j = 0 ; i != staticOffsets.size (); i++) {
81- if (ShapedType::isDynamic (staticOffsets[i]))
82- offsets.push_back (dynamicOffsets[j++]);
83- else
84- offsets.push_back (rewriter.create <arith::ConstantIndexOp>(
85- op.getLoc (), staticOffsets[i]));
86- }
87- return offsets;
88- }
89-
9073 // Convert linear subgroup ID to 2D coordinates
9174 // TODO: Delinearize for nD
9275 SmallVector<Value> delinearizeSubgroupId (ConversionPatternRewriter &rewriter,
@@ -99,7 +82,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
9982 // Calculate offset for each subgroup
10083 SmallVector<OpFoldResult>
10184 calculateGlobalOffsets (ConversionPatternRewriter &rewriter, Location loc,
102- const SmallVector<Value > &originalOffsets,
85+ const SmallVector<OpFoldResult > &originalOffsets,
10386 const SmallVector<Value> &localOffset,
10487 const SmallVector<int64_t > &distUnitBaseAddr) const {
10588
@@ -116,10 +99,24 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
11699 size_t lastDimIndex = originalOffsets.size () - 1 ;
117100 size_t secondLastDimIndex = lastDimIndex - 1 ;
118101
119- Value globalOffsetX = rewriter.createOrFold <index::AddOp>(
120- loc, originalOffsets[secondLastDimIndex], offsetX);
121- Value globalOffsetY = rewriter.createOrFold <index::AddOp>(
122- loc, originalOffsets[lastDimIndex], offsetY);
102+ // Convert originalOffsets to Value
103+ auto getValueFromOpFoldResult = [&](OpFoldResult ofr) -> Value {
104+ if (auto val = ofr.dyn_cast <Value>())
105+ return val;
106+ if (auto attr = ofr.dyn_cast <Attribute>()) {
107+ int64_t staticOffset = cast<IntegerAttr>(attr).getInt ();
108+ return rewriter.create <arith::ConstantIndexOp>(loc, staticOffset);
109+ }
110+ llvm_unreachable (" Unsupported OpFoldResult kind" );
111+ };
112+
113+ Value origOffsetX =
114+ getValueFromOpFoldResult (originalOffsets[secondLastDimIndex]);
115+ Value origOffsetY = getValueFromOpFoldResult (originalOffsets[lastDimIndex]);
116+ Value globalOffsetX =
117+ rewriter.createOrFold <index::AddOp>(loc, origOffsetX, offsetX);
118+ Value globalOffsetY =
119+ rewriter.createOrFold <index::AddOp>(loc, origOffsetY, offsetY);
123120
124121 SmallVector<OpFoldResult> globalOffsets (originalOffsets.begin (),
125122 originalOffsets.end ());
@@ -172,7 +169,7 @@ struct WgToSgCreateNdOp : public OpConversionPattern<xegpu::CreateNdDescOp> {
172169 rewriter.createOrFold <index::MulOp>(loc, sgIds[i], sgDataDim[i]);
173170 }
174171
175- SmallVector<Value > originalOffsets = extractOffsets (rewriter, op );
172+ SmallVector<OpFoldResult > originalOffsets = op. getMixedOffsets ( );
176173
177174 xegpu::TensorDescType newTdescTy =
178175 xegpu::TensorDescType::get (ctx, sgShape, elemTy, tdescTy.getEncoding (),
0 commit comments