Skip to content

Commit b01086e

Browse files
committed
Address comments.
1 parent dea2933 commit b01086e

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -151,10 +151,11 @@ class CreateNdDescToXeVMPattern
151151
auto loc = op.getLoc();
152152
auto source = op.getSource();
153153
// Op is lowered to a code sequence that populates payload.
154-
// Payload is a 8xi32 vector.
154+
// Payload is a 8xi32 vector. Offset to individual fields are defined in
155+
// NdTdescOffset enum.
155156
Type payloadElemTy = rewriter.getI32Type();
156-
Type i64Ty = rewriter.getI64Type();
157157
VectorType payloadTy = VectorType::get(8, payloadElemTy);
158+
Type i64Ty = rewriter.getI64Type();
158159
// 4xi64 view is used for inserting the base pointer.
159160
VectorType payloadI64Ty = VectorType::get(4, i64Ty);
160161
// Initialize payload to zero.
@@ -180,12 +181,12 @@ class CreateNdDescToXeVMPattern
180181
// If source is a memref, we need to extract the aligned pointer as index.
181182
// Pointer type is passed as i32 or i64 by type converter.
182183
if (sourceMemrefTy) {
183-
baseAddr =
184-
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
185184
if (!sourceMemrefTy.hasStaticShape()) {
186185
op.emitError() << "Expected static memref shape.";
187186
return failure();
188187
}
188+
baseAddr =
189+
memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
189190
} else {
190191
baseAddr = adaptor.getSource();
191192
}
@@ -198,8 +199,8 @@ class CreateNdDescToXeVMPattern
198199
};
199200
// Offsets can be either 2D or not provided (0 is used).
200201
if (mixedOffsets.size() == 2) {
201-
offsetW = createOffset(mixedOffsets, rank - 1);
202-
offsetH = createOffset(mixedOffsets, rank - 2);
202+
offsetW = createOffset(mixedOffsets, 1);
203+
offsetH = createOffset(mixedOffsets, 0);
203204
} else if (mixedOffsets.size() == 0) {
204205
offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
205206
offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
@@ -208,8 +209,8 @@ class CreateNdDescToXeVMPattern
208209
"Expected 2D offsets or no offsets.");
209210
}
210211
// Get shape values from op fold results.
211-
baseShapeW = createOffset(mixedSizes, rank - 1);
212-
baseShapeH = createOffset(mixedSizes, rank - 2);
212+
baseShapeW = createOffset(mixedSizes, 1);
213+
baseShapeH = createOffset(mixedSizes, 0);
213214
if (sourceMemrefTy)
214215
// Cast index to i64.
215216
baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);

0 commit comments

Comments
 (0)