@@ -365,10 +365,11 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
 
 // Add a builder that creates
 // offset * elemByteSize + baseAddr
-static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
-                       Value baseAddr, Value offset, int64_t elemByteSize) {
+static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
+                                 Location loc, Value baseAddr, Value offset,
+                                 int64_t elemByteSize) {
   Value byteSize = arith::ConstantIntOp::create(
-      rewriter, loc, rewriter.getI64Type(), elemByteSize);
+      rewriter, loc, baseAddr.getType(), elemByteSize);
   Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
   Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
   return newAddr;
@@ -443,7 +444,8 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
       // If offset is provided, we add them to the base pointer.
       // Offset is in number of elements, we need to multiply by
       // element byte size.
-      basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
+      basePtrI64 =
+          addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
     }
     // Convert base pointer (i64) to LLVM pointer type.
     Value basePtrLLVM =
@@ -516,7 +518,7 @@ class CreateMemDescOpPattern final
   LogicalResult
   matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    TypedValue<MemRefType> src = op.getSource();
+
     auto resTy = cast<xegpu::MemDescType>(op.getResult().getType());
 
     // Create the result MemRefType with the same shape, element type, and
@@ -525,7 +527,7 @@ class CreateMemDescOpPattern final
 
     Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0);
     auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy,
-                                         Value(src), zero, ValueRange());
+                                         op.getSource(), zero, ValueRange());
     rewriter.replaceOp(op, viewOp);
     return success();
   }
@@ -587,88 +589,74 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
     Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(
         rewriter, loc, basePtrStruct);
 
-    // Convert base pointer (ptr) to i64
-    Value basePtrI64 = arith::IndexCastUIOp::create(
-        rewriter, loc, rewriter.getI64Type(), basePtrLLVM);
+    // Convert base pointer (ptr) to i32
+    Value basePtrI32 = arith::IndexCastUIOp::create(
+        rewriter, loc, rewriter.getI32Type(), basePtrLLVM);
 
     Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
     linearOffset = arith::IndexCastUIOp::create(
-        rewriter, loc, rewriter.getI64Type(), linearOffset);
-    basePtrI64 =
-        addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize);
+        rewriter, loc, rewriter.getI32Type(), linearOffset);
+    basePtrI32 = addOffsetToBaseAddr(rewriter, loc, basePtrI32, linearOffset,
+                                     elemByteSize);
 
-    // convert base pointer (i64) to LLVM pointer type
+    // convert base pointer (i32) to LLVM pointer type
     basePtrLLVM =
-        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
+        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
 
-    // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
-    // operation. LLVM load/store does not support vector of size 1, so we need
-    // to handle this case separately.
-    if (valOrResVecTy.getNumElements() == 1) {
-      Type scalarTy = valOrResVecTy.getElementType();
-      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
-        Value loadOp =
-            LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
-        rewriter.replaceOp(op, loadOp);
-      } else {
-        LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
-        rewriter.eraseOp(op);
-      }
-      return success();
-    } else {
+    if (op.getSubgroupBlockIoAttr()) {
       // if the attribute 'subgroup_block_io' is set to true, it lowers to
       // xevm.blockload
-      auto subgroupBlockIoAttr = op.getSubgroupBlockIoAttr();
-      bool subgroup_block_io = static_cast<bool>(subgroupBlockIoAttr);
-
-      // BlockLoadOp only supports integer types, so we need to bitcast
-      // Get integer type with matching bit width
-      Type elemTy = valOrResVecTy.getElementType();
-      int64_t bitWidth = elemTy.getIntOrFloatBitWidth();
-      Type intElemTy = rewriter.getIntegerType(bitWidth);
+
+      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
       VectorType intVecTy =
           VectorType::get(valOrResVecTy.getShape(), intElemTy);
 
-      if (subgroup_block_io) {
-        if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
-          Value loadOp =
-              xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
-          if (intVecTy != valOrResVecTy) {
-            loadOp =
-                vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
-          }
-          rewriter.replaceOp(op, loadOp);
-        } else {
-          Value dataToStore = adaptor.getData();
-          if (valOrResVecTy != intVecTy) {
-            dataToStore =
-                vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
-          }
-          xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
-                                     nullptr);
-          rewriter.eraseOp(op);
+      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+        Value loadOp =
+            xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
+        if (intVecTy != valOrResVecTy) {
+          loadOp =
+              vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
         }
+        rewriter.replaceOp(op, loadOp);
       } else {
-        // if the result is 1D vector, if the vector direction is Column, then
-        // the
-        // memory descriptor should be treated as column major
-        auto chipOpt = xegpu::getChipStr(op);
-        if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
-          // the lowering only works for pvc and bmg
-          return rewriter.notifyMatchFailure(
-              op, "The lowering is specific to pvc or bmg.");
+        Value dataToStore = adaptor.getData();
+        if (valOrResVecTy != intVecTy) {
+          dataToStore =
+              vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
         }
+        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
+                                   nullptr);
+        rewriter.eraseOp(op);
+      }
+      return success();
+    }
 
-        if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
-          Value loadOp =
-              LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
-          rewriter.replaceOp(op, loadOp);
-        } else {
-          LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
-          rewriter.eraseOp(op);
-        }
+    if (valOrResVecTy.getNumElements() >= 1) {
+      auto chipOpt = xegpu::getChipStr(op);
+      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
+        // the lowering for chunk load only works for pvc and bmg
+        return rewriter.notifyMatchFailure(
+            op, "The lowering is specific to pvc or bmg.");
       }
     }
+
+    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
+      // if the size of valOrResVecTy is 1, it lowers to a scalar load/store
+      // operation. LLVM load/store does not support vector of size 1, so we
+      // need to handle this case separately.
+      auto scalarTy = valOrResVecTy.getElementType();
+      LLVM::LoadOp loadOp;
+      if (valOrResVecTy.getNumElements() == 1)
+        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
+      else
+        loadOp =
+            LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
+      rewriter.replaceOp(op, loadOp);
+    } else {
+      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
+      rewriter.eraseOp(op);
+    }
     return success();
   }
 };
@@ -715,8 +703,8 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
                 op, "Expected element type bit width to be multiple of 8.");
           elemByteSize = elemBitWidth / 8;
         }
-        basePtrI64 =
-            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
+        basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
+                                         elemByteSize);
       }
     }
     // Default memory space is global.
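
Note: the renamed helper computes offset * elemByteSize + baseAddr, and the key change is that the byte-size constant now takes baseAddr.getType(), so the same arithmetic serves both 64-bit global addresses and the 32-bit shared-local-memory offsets used by the matrix load/store pattern. A minimal standalone sketch of that arithmetic (plain C++ with illustrative names, not part of this commit):

#include <cstdint>
#include <cstdio>

// Illustrative only: mirrors the arith.constant / arith.muli / arith.addi
// sequence that addOffsetToBaseAddr emits. BaseT stands in for
// baseAddr.getType() -- 64-bit for global memory, 32-bit for SLM offsets.
template <typename BaseT>
BaseT addOffsetToBaseAddr(BaseT baseAddr, BaseT offset,
                          std::int64_t elemByteSize) {
  BaseT byteSize = static_cast<BaseT>(elemByteSize); // constant matches BaseT
  BaseT byteOffset = offset * byteSize;              // arith.muli
  return baseAddr + byteOffset;                      // arith.addi
}

int main() {
  // 64-bit global base, element index 5, 4-byte elements -> 0x1014.
  std::printf("global: 0x%llx\n",
              static_cast<unsigned long long>(
                  addOffsetToBaseAddr<std::uint64_t>(0x1000, 5, 4)));
  // 32-bit SLM offset, same index and element size -> 0x114.
  std::printf("slm:    0x%x\n",
              static_cast<unsigned>(
                  addOffsetToBaseAddr<std::uint32_t>(0x100, 5, 4)));
  return 0;
}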