Skip to content

Commit b1857a2

Browse files
committed
address more comments
1 parent 272f512 commit b1857a2

File tree

3 files changed

+77
-89
lines changed

3 files changed

+77
-89
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 61 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -365,10 +365,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
365365

366366
// Add a builder that creates
367367
// offset * elemByteSize + baseAddr
368-
static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
369-
Value baseAddr, Value offset, int64_t elemByteSize) {
368+
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
369+
Location loc, Value baseAddr, Value offset,
370+
int64_t elemByteSize) {
370371
Value byteSize = arith::ConstantIntOp::create(
371-
rewriter, loc, rewriter.getI64Type(), elemByteSize);
372+
rewriter, loc, baseAddr.getType(), elemByteSize);
372373
Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
373374
Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
374375
return newAddr;
@@ -443,7 +444,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
443444
// If offset is provided, we add them to the base pointer.
444445
// Offset is in number of elements, we need to multiply by
445446
// element byte size.
446-
basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
447+
basePtrI64 =
448+
addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
447449
}
448450
// Convert base pointer (i64) to LLVM pointer type.
449451
Value basePtrLLVM =
@@ -516,7 +518,7 @@ class CreateMemDescOpPattern final
516518
LogicalResult
517519
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
518520
ConversionPatternRewriter &rewriter) const override {
519-
TypedValue<MemRefType> src = op.getSource();
521+
520522
auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
521523

522524
// Create the result MemRefType with the same shape, element type, and
@@ -525,7 +527,7 @@ class CreateMemDescOpPattern final
525527

526528
Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
527529
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
528-
Value(src), zero, ValueRange());
530+
op.getSource(), zero, ValueRange());
529531
rewriter.replaceOp(op, viewOp);
530532
return success();
531533
}
@@ -587,88 +589,74 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
587589
Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
588590
rewriter, loc, basePtrStruct);
589591

590-
// Convert base pointer (ptr) to i64
591-
Value basePtrI64 = arith::IndexCastUIOp::create(
592-
rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
592+
// Convert base pointer (ptr) to i32
593+
Value basePtrI32 = arith::IndexCastUIOp::create(
594+
rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
593595

594596
Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
595597
linearOffset = arith::IndexCastUIOp::create(
596-
rewriter, loc, rewriter.getI64Type(), linearOffset);
597-
basePtrI64 =
598-
addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
598+
rewriter, loc, rewriter.getI32Type(), linearOffset);
599+
basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
600+
elemByteSize);
599601

600-
// convert base pointer (i64) to LLVM pointer type
602+
// convert base pointer (i32) to LLVM pointer type
601603
basePtrLLVM =
602-
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
604+
LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
603605

604-
// if the size of valOrResVecTy is 1, it lowers to a scalar load/store
605-
// operation. LLVM load/store does not support vector of size 1, so we need
606-
// to handle this case separately.
607-
if (valOrResVecTy.getNumElements() == 1) {
608-
Type scalarTy = valOrResVecTy.getElementType();
609-
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
610-
Value loadOp =
611-
LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
612-
rewriter.replaceOp(op, loadOp);
613-
} else {
614-
LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
615-
rewriter.eraseOp(op);
616-
}
617-
return success();
618-
} else {
606+
if (op.getSubgroupBlockIoAttr()) {
619607
// if the attribute 'subgroup_block_io' is set to true, it lowers to
620608
// xevm.blockload
621-
auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
622-
bool subgroup_block_io = static_cast<bool>(subgroupBlockIoAttr);
623-
624-
// BlockLoadOp only supports integer types, so we need to bitcast
625-
// Get integer type with matching bit width
626-
Type elemTy = valOrResVecTy.getElementType();
627-
int64_t bitWidth = elemTy.getIntOrFloatBitWidth();
628-
Type intElemTy = rewriter.getIntegerType(bitWidth);
609+
610+
Type intElemTy = rewriter.getIntegerType(elemBitWidth);
629611
VectorType intVecTy =
630612
VectorType::get(valOrResVecTy.getShape(), intElemTy);
631613

632-
if (subgroup_block_io) {
633-
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
634-
Value loadOp =
635-
xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
636-
if (intVecTy != valOrResVecTy) {
637-
loadOp =
638-
vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
639-
}
640-
rewriter.replaceOp(op, loadOp);
641-
} else {
642-
Value dataToStore = adaptor.getData();
643-
if (valOrResVecTy != intVecTy) {
644-
dataToStore =
645-
vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
646-
}
647-
xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
648-
nullptr);
649-
rewriter.eraseOp(op);
614+
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
615+
Value loadOp =
616+
xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
617+
if (intVecTy != valOrResVecTy) {
618+
loadOp =
619+
vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
650620
}
621+
rewriter.replaceOp(op, loadOp);
651622
} else {
652-
// if the result is 1D vector, if the vector direction is Column, then
653-
// the
654-
// memory descriptor should be treated as column major
655-
auto chipOpt = xegpu::getChipStr(op);
656-
if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
657-
// the lowering only works for pvc and bmg
658-
return rewriter.notifyMatchFailure(
659-
op, "The lowering is specific to pvc or bmg.");
623+
Value dataToStore = adaptor.getData();
624+
if (valOrResVecTy != intVecTy) {
625+
dataToStore =
626+
vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
660627
}
628+
xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
629+
nullptr);
630+
rewriter.eraseOp(op);
631+
}
632+
return success();
633+
}
661634

662-
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
663-
Value loadOp =
664-
LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
665-
rewriter.replaceOp(op, loadOp);
666-
} else {
667-
LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
668-
rewriter.eraseOp(op);
669-
}
635+
if (valOrResVecTy.getNumElements() >= 1) {
636+
auto chipOpt = xegpu::getChipStr(op);
637+
if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
638+
// the lowering for chunk load only works for pvc and bmg
639+
return rewriter.notifyMatchFailure(
640+
op, "The lowering is specific to pvc or bmg.");
670641
}
671642
}
643+
644+
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
645+
// if the size of valOrResVecTy is 1, it lowers to a scalar load/store
646+
// operation. LLVM load/store does not support vector of size 1, so we
647+
// need to handle this case separately.
648+
auto scalarTy = valOrResVecTy.getElementType();
649+
LLVM::LoadOp loadOp;
650+
if (valOrResVecTy.getNumElements() == 1)
651+
loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
652+
else
653+
loadOp =
654+
LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
655+
rewriter.replaceOp(op, loadOp);
656+
} else {
657+
LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
658+
rewriter.eraseOp(op);
659+
}
672660
return success();
673661
}
674662
};
@@ -715,8 +703,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
715703
op, "Expected element type bit width to be multiple of 8.");
716704
elemByteSize = elemBitWidth / 8;
717705
}
718-
basePtrI64 =
719-
addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
706+
basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
707+
elemByteSize);
720708
}
721709
}
722710
// Default memory space is global.

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,9 @@ isValidGatherScatterBufferParams(Type offsetsTy, Type maskTy,
174174
}
175175

176176
LogicalResult
177-
IsValidStoreMatrixParams(VectorType dataTy, MemDescType mdescTy,
178-
UnitAttr subgroup_block_io,
179-
function_ref<InFlightDiagnostic()> emitError) {
177+
IsValidMatrixOpParams(VectorType dataTy, MemDescType mdescTy,
178+
UnitAttr subgroup_block_io,
179+
function_ref<InFlightDiagnostic()> emitError) {
180180

181181
if (!dataTy) {
182182
if (subgroup_block_io)
@@ -1107,8 +1107,8 @@ LogicalResult LoadMatrixOp::verify() {
11071107
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
11081108
MemDescType mdescTy = getMemDesc().getType();
11091109

1110-
return IsValidStoreMatrixParams(resTy, mdescTy, subgroup_block_io,
1111-
[&]() { return emitError(); });
1110+
return IsValidMatrixOpParams(resTy, mdescTy, subgroup_block_io,
1111+
[&]() { return emitError(); });
11121112
}
11131113

11141114
//===----------------------------------------------------------------------===//
@@ -1131,8 +1131,8 @@ LogicalResult StoreMatrixOp::verify() {
11311131
auto dataTy = dyn_cast<VectorType>(getData().getType());
11321132
UnitAttr subgroup_block_io = getSubgroupBlockIoAttr();
11331133
MemDescType mdescTy = getMemDesc().getType();
1134-
return IsValidStoreMatrixParams(dataTy, mdescTy, subgroup_block_io,
1135-
[&]() { return emitError(); });
1134+
return IsValidMatrixOpParams(dataTy, mdescTy, subgroup_block_io,
1135+
[&]() { return emitError(); });
11361136
}
11371137

11381138
//===----------------------------------------------------------------------===//

mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
1111
//CHECK: %[[TID:.*]] = gpu.thread_id x
1212
//CHECK: %[[C1:.*]] = arith.constant 1 : index
1313
//CHECK: %[[MUL1:.*]] = arith.muli %[[TID]], %[[C1]] : index
14-
//CHECK: %[[C4:.*]] = arith.constant 4 : i64
15-
//CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i64
14+
//CHECK: %[[C4:.*]] = arith.constant 4 : i32
15+
//CHECK: %[[MUL2:.*]] = arith.muli {{.*}}, %[[C4]] : i32
1616
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> f32
1717

1818
%tid_x = gpu.thread_id x
@@ -80,7 +80,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
8080
%c19 = arith.constant 19: index
8181

8282
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
83-
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
83+
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
8484
//CHECK: %[[c16:.*]] = arith.constant 16 : index
8585
//CHECK: %[[offsetx_0:.*]] = arith.divsi %[[c19]], %[[c16]] : index
8686
//CHECK: %[[offsetx_1:.*]] = arith.remsi %[[c19]], %[[c16]] : index
@@ -164,7 +164,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
164164
%c48 = arith.constant 48 : index
165165

166166
//CHECK: %[[intptr:.*]] = memref.extract_aligned_pointer_as_index %[[view]] : memref<2048xf16, 3> -> index
167-
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i64
167+
//CHECK: %[[basePtrI64:.*]] = arith.index_castui %[[intptr]] : index to i32
168168
//CHECK: %[[offset0:.*]] = arith.divsi %[[c16]], %[[c16]] : index
169169
//CHECK: %[[offset1:.*]] = arith.remsi %[[c16]], %[[c16]] : index
170170
//CHECK: %[[offset2:.*]] = arith.divsi %[[c48]], %[[c16]] : index
@@ -180,11 +180,11 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
180180
//CHECK: %[[c1:.*]] = arith.constant 1 : index
181181
//CHECK: %[[mul3:.*]] = arith.muli %[[offset3]], %[[c1]] : index
182182
//CHECK: %[[linearOffset:.*]] = arith.addi %[[mul3]], %[[add2]] : index
183-
//CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i64
184-
//CHECK: %[[c2:.*]] = arith.constant 2 : i64
185-
//CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i64
186-
//CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i64
187-
//CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i64 to !llvm.ptr<3>
183+
//CHECK: %[[linearOffsetI64:.*]] = arith.index_castui %[[linearOffset]] : index to i32
184+
//CHECK: %[[c2:.*]] = arith.constant 2 : i32
185+
//CHECK: %[[byteOffset:.*]] = arith.muli %[[linearOffsetI64]], %[[c2]] : i32
186+
//CHECK: %[[finalPtr:.*]] = arith.addi %[[basePtrI64]], %[[byteOffset]] : i32
187+
//CHECK: %[[ptr:.*]] = llvm.inttoptr %[[finalPtr]] : i32 to !llvm.ptr<3>
188188
//CHECK: %[[loadedI16:.*]] = xevm.blockload %[[ptr]] : (!llvm.ptr<3>) -> vector<8xi16>
189189
//CHECK: %[[loaded:.*]] = vector.bitcast %[[loadedI16]] : vector<8xi16> to vector<8xf16>
190190

0 commit comments

Comments
 (0)