Skip to content

Commit 9f9744c

Browse files
committed
bug fixes
1 parent 446b951 commit 9f9744c

File tree

4 files changed

+30
-109
lines changed

4 files changed

+30
-109
lines changed

mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@
3333

3434
#include <numeric>
3535

36-
#define DEBUG_TYPE "xegpu-to-xevm"
37-
3836
namespace mlir {
3937
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
4038
#include "mlir/Conversion/Passes.h.inc"
@@ -519,29 +517,17 @@ class CreateMemDescOpPattern final
519517
LogicalResult
520518
matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
521519
ConversionPatternRewriter &rewriter) const override {
522-
// DEBUG: Print operation and types
523-
LLVM_DEBUG(llvm::dbgs()
524-
<< "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n");
525520
TypedValue<MemRefType> src = op.getSource();
526521
auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
527522

528523
// Create the result MemRefType with the same shape, element type, and
529524
// memory space
530525
auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy);
531526

532-
LLVM_DEBUG(llvm::dbgs()
533-
<< "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n");
534-
LLVM_DEBUG(llvm::dbgs()
535-
<< "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n");
536-
LLVM_DEBUG(llvm::dbgs()
537-
<< "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n");
538527
Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
539528
auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
540529
Value(src), zero, ValueRange());
541530
rewriter.replaceOp(op, viewOp);
542-
LLVM_DEBUG(
543-
llvm::dbgs()
544-
<< "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n");
545531
return success();
546532
}
547533
};
@@ -635,16 +621,33 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
635621
// if the attribute 'subgroup_block_io' is set to true, it lowers to
636622
// xevm.blockload
637623
auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
638-
bool subgroup_block_io =
639-
subgroupBlockIoAttr && cast<BoolAttr>(subgroupBlockIoAttr).getValue();
624+
bool subgroup_block_io = static_cast<bool>(subgroupBlockIoAttr);
625+
626+
// BlockLoadOp only supports integer types, so we need to bitcast
627+
// Get integer type with matching bit width
628+
Type elemTy = valOrResVecTy.getElementType();
629+
int64_t bitWidth = elemTy.getIntOrFloatBitWidth();
630+
Type intElemTy = rewriter.getIntegerType(bitWidth);
631+
VectorType intVecTy =
632+
VectorType::get(valOrResVecTy.getShape(), intElemTy);
633+
640634
if (subgroup_block_io) {
641635
if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
642-
Value loadOp = xevm::BlockLoadOp::create(rewriter, loc, valOrResVecTy,
643-
basePtrLLVM);
636+
Value loadOp =
637+
xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
638+
if (intVecTy != valOrResVecTy) {
639+
loadOp =
640+
vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
641+
}
644642
rewriter.replaceOp(op, loadOp);
645643
} else {
646-
xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM,
647-
adaptor.getData(), nullptr);
644+
Value dataToStore = adaptor.getData();
645+
if (valOrResVecTy != intVecTy) {
646+
dataToStore =
647+
vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
648+
}
649+
xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
650+
nullptr);
648651
rewriter.eraseOp(op);
649652
}
650653
} else {

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 5 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ void XeGPUDialect::initialize() {
3737
>();
3838
}
3939

40-
#define DEBUG_TYPE "xegpu"
41-
4240
/// Generates instructions to compute offsets for a subgroup identified by
4341
/// its multidimensional indices (sgId), using the specified subgroup layout
4442
/// (sgLayout), subgroup data dimensions (sizePerSg), and the overall data
@@ -788,56 +786,21 @@ SmallVector<int64_t> MemDescType::getStrides() {
788786
strides.push_back(cast<IntegerAttr>(attr).getInt());
789787
}
790788

791-
llvm::dbgs() << "DEBUG: matrixShape = [";
792-
for (size_t i = 0; i < matrixShape.size(); ++i) {
793-
llvm::dbgs() << matrixShape[i];
794-
if (i < matrixShape.size() - 1)
795-
llvm::dbgs() << ", ";
796-
}
797-
llvm::dbgs() << "]\n";
798-
799-
llvm::dbgs() << "DEBUG: strides = [";
800-
for (size_t i = 0; i < strides.size(); ++i) {
801-
llvm::dbgs() << strides[i];
802-
if (i < strides.size() - 1)
803-
llvm::dbgs() << ", ";
804-
}
805-
llvm::dbgs() << "]\n";
806-
807789
SmallVector<int64_t> innerBlkShape = getBlockSize();
808-
llvm::dbgs() << "DEBUG: innerBlkShape = [";
809-
for (size_t i = 0; i < innerBlkShape.size(); ++i) {
810-
llvm::dbgs() << innerBlkShape[i];
811-
if (i < innerBlkShape.size() - 1)
812-
llvm::dbgs() << ", ";
813-
}
814-
llvm::dbgs() << "]\n";
815790

816791
// get perm from FCD to LCD
817792
// perm[i] = the dim with i-th smallest stride
818793
SmallVector<int, 4> perm =
819794
llvm::to_vector<4>(llvm::seq<int>(0, strides.size()));
820795
llvm::sort(perm, [&](int a, int b) { return strides[a] < strides[b]; });
821796

822-
llvm::dbgs() << "DEBUG: perm = [";
823-
for (size_t i = 0; i < perm.size(); ++i) {
824-
llvm::dbgs() << perm[i];
825-
if (i < perm.size() - 1)
826-
llvm::dbgs() << ", ";
827-
}
828-
llvm::dbgs() << "]\n";
829-
830797
assert(strides[perm[0]] == 1 && "inner most dim must have stride 1");
831798

832-
SmallVector<int64_t> innerBlkStride = computeStrides(innerBlkShape);
833-
834-
llvm::dbgs() << "DEBUG: innerBlkStride = [";
835-
for (size_t i = 0; i < innerBlkStride.size(); ++i) {
836-
llvm::dbgs() << innerBlkStride[i];
837-
if (i < innerBlkStride.size() - 1)
838-
llvm::dbgs() << ", ";
839-
}
840-
llvm::dbgs() << "]\n";
799+
SmallVector<int64_t> innerBlkStride(innerBlkShape.size());
800+
innerBlkStride[perm[0]] = 1;
801+
for (size_t i = 1; i < perm.size(); ++i)
802+
innerBlkStride[perm[i]] =
803+
innerBlkStride[perm[i - 1]] * innerBlkShape[perm[i - 1]];
841804

842805
// compute the original matrix shape using the stride info
843806
// and compute the number of blocks in each dimension
@@ -850,56 +813,22 @@ SmallVector<int64_t> MemDescType::getStrides() {
850813
BlkShapeOrig[perm[i]] = matrixShapeOrig[perm[i]] / innerBlkShape[perm[i]];
851814
}
852815

853-
llvm::dbgs() << "DEBUG: matrixShapeOrig = [";
854-
for (size_t i = 0; i < matrixShapeOrig.size(); ++i) {
855-
llvm::dbgs() << matrixShapeOrig[i];
856-
if (i < matrixShapeOrig.size() - 1)
857-
llvm::dbgs() << ", ";
858-
}
859-
llvm::dbgs() << "]\n";
860-
861-
llvm::dbgs() << "DEBUG: BlkShapeOrig = [";
862-
for (size_t i = 0; i < BlkShapeOrig.size(); ++i) {
863-
llvm::dbgs() << BlkShapeOrig[i];
864-
if (i < BlkShapeOrig.size() - 1)
865-
llvm::dbgs() << ", ";
866-
}
867-
llvm::dbgs() << "]\n";
868-
869816
int64_t innerBlkSize = 1;
870817
for (auto s : innerBlkShape)
871818
innerBlkSize *= s;
872819

873-
llvm::dbgs() << "DEBUG: innerBlkSize = " << innerBlkSize << "\n";
874-
875820
SmallVector<int64_t> outerBlkStride(matrixShape.size());
876821
outerBlkStride[perm[0]] = innerBlkSize;
877822
for (size_t i = 0; i < perm.size() - 1; ++i) {
878823
outerBlkStride[perm[i + 1]] =
879824
outerBlkStride[perm[i]] * BlkShapeOrig[perm[i]];
880825
}
881826

882-
llvm::dbgs() << "DEBUG: outerBlkStride = [";
883-
for (size_t i = 0; i < outerBlkStride.size(); ++i) {
884-
llvm::dbgs() << outerBlkStride[i];
885-
if (i < outerBlkStride.size() - 1)
886-
llvm::dbgs() << ", ";
887-
}
888-
llvm::dbgs() << "]\n";
889-
890827
// combine the inner and outer strides
891828
SmallVector<int64_t> blockedStrides;
892829
blockedStrides.append(outerBlkStride.begin(), outerBlkStride.end());
893830
blockedStrides.append(innerBlkStride.begin(), innerBlkStride.end());
894831

895-
llvm::dbgs() << "DEBUG: blockedStrides = [";
896-
for (size_t i = 0; i < blockedStrides.size(); ++i) {
897-
llvm::dbgs() << blockedStrides[i];
898-
if (i < blockedStrides.size() - 1)
899-
llvm::dbgs() << ", ";
900-
}
901-
llvm::dbgs() << "]\n";
902-
903832
return blockedStrides;
904833
}
905834

@@ -911,12 +840,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
911840
SmallVector<int64_t> blockShape = getBlockSize();
912841
SmallVector<int64_t> strides = getStrides();
913842

914-
LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blockShape=[";
915-
llvm::interleaveComma(blockShape, llvm::dbgs());
916-
llvm::dbgs() << "], strides=[";
917-
llvm::interleaveComma(strides, llvm::dbgs());
918-
llvm::dbgs() << "]\n");
919-
920843
// blockshape equal to matrixshape means no blocking
921844
if (llvm::equal(blockShape, matrixShape)) {
922845
// remove the outer dims from strides
@@ -937,8 +860,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
937860
blockedOffsets.append(rems.begin(), rems.end());
938861

939862
offsets = blockedOffsets;
940-
LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: blocked offsets size="
941-
<< offsets.size() << "\n");
942863
}
943864

944865
// Start with initial value as matrix descriptor's base offset.
@@ -949,9 +870,6 @@ Value MemDescType::getLinearOffsets(OpBuilder &builder, Location loc,
949870
linearOffset = arith::AddIOp::create(builder, loc, mulVal, linearOffset);
950871
}
951872

952-
LLVM_DEBUG(llvm::dbgs() << "getLinearOffsets: final linearOffset="
953-
<< linearOffset << "\n");
954-
955873
return linearOffset;
956874
}
957875

mlir/test/Conversion/XeGPUToXeVM/loadstore_matrix.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ gpu.module @test_kernel [#xevm.target<chip = "pvc">] {
5555
//CHECK-LABEL: load_store_matrix_5
5656
gpu.func @load_store_matrix_5(%arg0: memref<4096xi8, 3>) -> vector<8xf16> {
5757
%0 = xegpu.create_mem_desc %arg0 : memref<4096xi8, 3> -> !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>
58-
//CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
58+
//CHECK: xevm.blockload {{.*}} : (!llvm.ptr<3>) -> vector<8xi16>
5959
%c16 = arith.constant 16 : index
6060
%c48 = arith.constant 48 : index
6161
%1 = xegpu.load_matrix %0[%c16, %c48] {subgroup_block_io}: !xegpu.mem_desc<32x64xf16, #xegpu.mem_layout<block = [16, 16]>>, index, index -> vector<8xf16>

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,7 @@ func.func @load_mem_desc_mismatch_element_type(%arg0: !xegpu.mem_desc<16x64xf16>
858858

859859
// -----
860860
func.func @load_mem_desc_invalid_result_size(%arg0: !xegpu.mem_desc<16x64xf16>) {
861-
// expected-error@+1 {{result shape must not exceed mem_desc shape}}
861+
// expected-error@+1 {{data shape must not exceed mem_desc shape}}
862862
%data = xegpu.load_matrix %arg0[8, 8]: !xegpu.mem_desc<16x64xf16> -> vector<32x16xf16>
863863
return
864864
}

0 commit comments

Comments
 (0)