Skip to content

Commit b1b52a6

Browse files
committed
Address reviwer comments.
1 parent 9752a4d commit b1b52a6

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ class CreateNdDescToXeVMPattern
204204
val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
205205
return val;
206206
};
207-
// Offsets are not supported not (0 is used).
207+
// Offsets are not supported (0 is used).
208208
offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
209209
offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
210210
// Get shape values from op fold results.
@@ -379,10 +379,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
379379
LogicalResult
380380
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
381381
ConversionPatternRewriter &rewriter) const override {
382-
Value offsets = adaptor.getOffsets();
383-
if (!offsets)
382+
Value offset = adaptor.getOffsets();
383+
if (!offset)
384384
return rewriter.notifyMatchFailure(op,
385-
"Expected offsets to be provided.");
385+
"Expected offset to be provided.");
386386
auto loc = op.getLoc();
387387
auto ctxt = rewriter.getContext();
388388
auto tdescTy = op.getTensorDescType();
@@ -433,16 +433,16 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
433433
basePtrI64);
434434
}
435435
Value mask = adaptor.getMask();
436-
if (dyn_cast<VectorType>(offsets.getType())) {
436+
if (dyn_cast<VectorType>(offset.getType())) {
437437
// Offset needs be scalar. Single element vector is converted to scalar
438438
// by type converter.
439439
return rewriter.notifyMatchFailure(op,
440-
"Expected offsets to be a scalar.");
440+
"Expected offset to be a scalar.");
441441
} else {
442-
// If offsets are provided, we add them to the base pointer.
443-
// Offsets are in number of elements, we need to multiply by
442+
// If offset is provided, we add them to the base pointer.
443+
// Offset is in number of elements, we need to multiply by
444444
// element byte size.
445-
basePtrI64 = addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
445+
basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
446446
}
447447
// Convert base pointer (i64) to LLVM pointer type.
448448
Value basePtrLLVM =

0 commit comments

Comments
 (0)