@@ -371,7 +371,8 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
371371 Value baseAddr, Value offset, int64_t elemByteSize) {
372372 Value byteSize = arith::ConstantIntOp::create (
373373 rewriter, loc, rewriter.getI64Type (), elemByteSize);
374- offset = arith::IndexCastUIOp::create (rewriter, loc, rewriter.getI64Type (), offset);
374+ offset = arith::IndexCastUIOp::create (rewriter, loc, rewriter.getI64Type (),
375+ offset);
375376 Value byteOffset = arith::MulIOp::create (rewriter, loc, offset, byteSize);
376377 Value newAddr = arith::AddIOp::create (rewriter, loc, baseAddr, byteOffset);
377378 return newAddr;
@@ -513,29 +514,36 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
513514// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than
514515// 32 bits will be converted to 32 bits.
515516class CreateMemDescOpPattern final
516- : public OpConversionPattern<xegpu::CreateMemDescOp> {
517+ : public OpConversionPattern<xegpu::CreateMemDescOp> {
517518public:
518519 using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
519520 LogicalResult
520521 matchAndRewrite (xegpu::CreateMemDescOp op, OpAdaptor adaptor,
521- ConversionPatternRewriter &rewriter) const override {
522- // DEBUG: Print operation and types
523- LLVM_DEBUG (llvm::dbgs () << " [XeGPUToXeVM] Lowering CreateMemDescOp: " << op << " \n " );
524- TypedValue<MemRefType> src = op.getSource ();
525- auto resTy = cast<xegpu::MemDescType>(op.getResult ().getType ());
526-
527- // Create the result MemRefType with the same shape, element type, and memory space
528- auto newResTy = getTypeConverter ()->convertType <MemRefType>(resTy);
529-
530- LLVM_DEBUG (llvm::dbgs () << " [XeGPUToXeVM] Source MemRefType: " << src.getType () << " \n " );
531- LLVM_DEBUG (llvm::dbgs () << " [XeGPUToXeVM] Result MemDescType: " << resTy << " \n " );
532- LLVM_DEBUG (llvm::dbgs () << " [XeGPUToXeVM] Converted MemRefType: " << newResTy << " \n " );
533- Value zero = arith::ConstantIndexOp::create (rewriter, op.getLoc (), 0 );
534- auto viewOp = memref::ViewOp::create (rewriter, op.getLoc (), newResTy, Value (src), zero,
535- ValueRange ());
536- rewriter.replaceOp (op, viewOp);
537- LLVM_DEBUG (llvm::dbgs () << " [XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n " );
538- return success ();
522+ ConversionPatternRewriter &rewriter) const override {
523+ // DEBUG: Print operation and types
524+ LLVM_DEBUG (llvm::dbgs ()
525+ << " [XeGPUToXeVM] Lowering CreateMemDescOp: " << op << " \n " );
526+ TypedValue<MemRefType> src = op.getSource ();
527+ auto resTy = cast<xegpu::MemDescType>(op.getResult ().getType ());
528+
529+ // Create the result MemRefType with the same shape, element type, and
530+ // memory space
531+ auto newResTy = getTypeConverter ()->convertType <MemRefType>(resTy);
532+
533+ LLVM_DEBUG (llvm::dbgs ()
534+ << " [XeGPUToXeVM] Source MemRefType: " << src.getType () << " \n " );
535+ LLVM_DEBUG (llvm::dbgs ()
536+ << " [XeGPUToXeVM] Result MemDescType: " << resTy << " \n " );
537+ LLVM_DEBUG (llvm::dbgs ()
538+ << " [XeGPUToXeVM] Converted MemRefType: " << newResTy << " \n " );
539+ Value zero = arith::ConstantIndexOp::create (rewriter, op.getLoc (), 0 );
540+ auto viewOp = memref::ViewOp::create (rewriter, op.getLoc (), newResTy,
541+ Value (src), zero, ValueRange ());
542+ rewriter.replaceOp (op, viewOp);
543+ LLVM_DEBUG (
544+ llvm::dbgs ()
545+ << " [XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n " );
546+ return success ();
539547 }
540548};
541549
@@ -551,7 +559,6 @@ class MemDescSubviewOpPattern final
551559 }
552560};
553561
554-
555562template <typename OpType,
556563 typename = std::enable_if_t <llvm::is_one_of<
557564 OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
@@ -577,7 +584,8 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
577584 data = adaptor.getData ();
578585 VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType ());
579586
580- int64_t elemBitWidth = valOrResVecTy.getElementType ().getIntOrFloatBitWidth ();
587+ int64_t elemBitWidth =
588+ valOrResVecTy.getElementType ().getIntOrFloatBitWidth ();
581589 // Element type must be multiple of 8 bits.
582590 if (elemBitWidth % 8 != 0 )
583591 return rewriter.notifyMatchFailure (
@@ -589,14 +597,17 @@ class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
589597 ctxt, getNumericXeVMAddrSpace (xegpu::MemorySpace::SLM));
590598
591599 auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType ());
592-
593- Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create (rewriter, loc, basePtrStruct);
600+
601+ Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create (
602+ rewriter, loc, basePtrStruct);
594603
595604 // Convert base pointer (ptr) to i64
596- Value basePtrI64 = arith::IndexCastUIOp::create (rewriter, loc, rewriter.getI64Type (), basePtrLLVM);
605+ Value basePtrI64 = arith::IndexCastUIOp::create (
606+ rewriter, loc, rewriter.getI64Type (), basePtrLLVM);
597607
598608 Value linearOffset = mdescTy.getLinearOffsets (rewriter, loc, offsets);
599- basePtrI64 = addOffset (rewriter, loc, basePtrI64, linearOffset, elemByteSize);
609+ basePtrI64 =
610+ addOffset (rewriter, loc, basePtrI64, linearOffset, elemByteSize);
600611
601612 // convert base pointer (i64) to LLVM pointer type
602613 basePtrLLVM =
0 commit comments