@@ -446,9 +446,10 @@ class CreateDescToXeVMPattern
446446 matchAndRewrite (xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
447447 ConversionPatternRewriter &rewriter) const override {
448448 auto eTy = op.getTensorDescType ().getElementType ();
449- if (eTy.getIntOrFloatBitWidth () % 8 != 0 ) {
450- return rewriter.notifyMatchFailure (op,
451- " Expected element type bit width to be multiple of 8." );
449+ auto eBw = eTy.getIntOrFloatBitWidth ();
450+ if (eBw % 8 != 0 ) {
451+ return rewriter.notifyMatchFailure (
452+ op, " Expected element type bit width to be multiple of 8." );
452453 }
453454 auto loc = op.getLoc ();
454455 // offsets are provided as scalar i64 by type converter.
@@ -458,10 +459,8 @@ class CreateDescToXeVMPattern
458459 Value addr = adaptor.getSource ();
459460 // ui32 or i32 are passed as i32 so they need to be casted to i64.
460461 if (addr.getType () != rewriter.getI64Type ())
461- addr = arith::IndexCastUIOp::create (
462- rewriter, loc, rewriter.getI64Type (), addr);
463- auto laneAddr =
464- addOffset (rewriter, loc, addr, offsets, getElemByteSize (op));
462+ addr = arith::ExtUIOp::create (rewriter, loc, rewriter.getI64Type (), addr);
463+ auto laneAddr = addOffset (rewriter, loc, addr, offsets, eBw / 8 );
465464 rewriter.replaceOp (op, laneAddr);
466465 return success ();
467466 }
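Note: the addOffset helper used above is defined elsewhere in this file and is not shown in this diff. A minimal sketch of what it presumably does, assuming it scales the per-lane offset (given in elements) by the element byte size and adds the result to the i64 base address; the helper name matches the diff, but the body below is an assumption, not the actual implementation:

static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
                       Value baseAddr, Value offset, int64_t elemByteSize) {
  // Assumed behavior: offsets are in elements, so scale to bytes first.
  Value byteSize = arith::ConstantOp::create(
      rewriter, loc, rewriter.getI64IntegerAttr(elemByteSize));
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  return arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
}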
@@ -475,16 +474,16 @@ class UpdateOffsetToXeVMPattern
475474 xegpu::UpdateOffsetOp::Adaptor adaptor,
476475 ConversionPatternRewriter &rewriter) const override {
477476 auto eTy = op.getTensorDescType ().getElementType ();
478- if (eTy.getIntOrFloatBitWidth () % 8 != 0 ) {
479- return rewriter.notifyMatchFailure (op,
480- " Expected element type bit width to be multiple of 8." );
477+ auto eBw = eTy.getIntOrFloatBitWidth ();
478+ if (eBw % 8 != 0 ) {
479+ return rewriter.notifyMatchFailure (
480+ op, " Expected element type bit width to be multiple of 8." );
481481 }
482482 auto loc = op.getLoc ();
483483 // scatter descriptor is provided as scalar i64 by type converter.
484484 // offsets are provided as scalar i64 by type converter.
485- Value newOffset =
486- addOffset (rewriter, loc, adaptor.getTensorDesc (), adaptor.getOffsets (),
487- getElemByteSize (op));
485+ Value newOffset = addOffset (rewriter, loc, adaptor.getTensorDesc (),
486+ adaptor.getOffsets (), eBw / 8 );
488487 rewriter.replaceOp (op, newOffset);
489488 return success ();
490489 }
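Note: getNumericXeVMAddrSpace, used by the patterns below, is likewise defined outside this diff. A hedged sketch of the assumed mapping from XeGPU memory spaces to the numeric address spaces used for the LLVM pointer types; the concrete numbers are illustrative assumptions only:

static unsigned getNumericXeVMAddrSpace(xegpu::MemorySpace space) {
  switch (space) {
  case xegpu::MemorySpace::Global:
    return 1; // assumed numeric value for the global address space
  case xegpu::MemorySpace::SLM:
    return 3; // assumed numeric value for shared local memory
  }
  llvm_unreachable("unhandled xegpu::MemorySpace");
}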
@@ -501,12 +500,35 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
501500 auto loc = op.getLoc ();
502501 auto ctxt = rewriter.getContext ();
503502 auto tdescTy = op.getTensorDescType ();
504- LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get (
505- ctxt, getNumericXeVMAddrSpace (xegpu::MemorySpace::Global));
506- if (tdescTy)
507- ptrTypeLLVM = LLVM::LLVMPointerType::get (
508- ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
509503 Value basePtrI64;
504+ // Load result or store value type can be a vector or a scalar.
505+ Type valOrResTy;
506+ if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
507+ valOrResTy = op.getResult ().getType ();
508+ } else {
509+ valOrResTy = adaptor.getValue ().getType ();
510+ }
511+ VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
512+ bool hasScalarVal = !valOrResVecTy;
513+ int64_t elemBitWidth =
514+ hasScalarVal ? valOrResTy.getIntOrFloatBitWidth ()
515+ : valOrResVecTy.getElementType ().getIntOrFloatBitWidth ();
516+ // Element type must be multiple of 8 bits.
517+ if (elemBitWidth % 8 != 0 ) {
518+ return rewriter.notifyMatchFailure (
519+ op, " Expected element type bit width to be multiple of 8." );
520+ }
521+ int64_t elemByteSize = elemBitWidth / 8 ;
522+ // Default memory space is global.
523+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get (
524+ ctxt, getNumericXeVMAddrSpace (xegpu::MemorySpace::Global));
525+ // If tensor descriptor is available, we use its memory space.
526+ if (tdescTy) {
527+ ptrTypeLLVM = LLVM::LLVMPointerType::get (
528+ ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
529+ }
530+ // Base pointer can come from the source (load) or the dest (store).
531+ // If it is a memref, we use its memory space.
510532 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
511533 basePtrI64 = adaptor.getSource ();
512534 if (auto memRefTy = dyn_cast<MemRefType>(op.getSource ().getType ())) {
@@ -522,76 +544,79 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
522544 ptrTypeLLVM = LLVM::LLVMPointerType::get (ctxt, addrSpace);
523545 }
524546 }
547+ // Base pointer is passed as i32 or i64 by the adaptor; cast to i64 if needed.
525548 if (basePtrI64.getType () != rewriter.getI64Type ()) {
526- basePtrI64 = arith::IndexCastUIOp ::create (rewriter, loc, rewriter.getI64Type (),
527- basePtrI64);
549+ basePtrI64 = arith::ExtUIOp ::create (rewriter, loc, rewriter.getI64Type (),
550+ basePtrI64);
528551 }
529- basePtrI64.dump ();
530552 Value offsets = adaptor.getOffsets ();
531- offsets.dump ();
532553 Value mask = adaptor.getMask ();
533- mask.dump ();
534554 if (offsets) {
535- if (dyn_cast<VectorType>(offsets.getType ())){
536- // Offset needs be scalar.
555+ if (dyn_cast<VectorType>(offsets.getType ())) {
556+ // Offset needs to be a scalar. A single-element vector is converted
557+ // to a scalar by the type converter.
537558 return rewriter.notifyMatchFailure (op,
538559 " Expected offsets to be a scalar." );
539560 } else {
561+ // If offsets are provided, we add them to the base pointer.
562+ // Offsets are given in number of elements; multiply by the
563+ // element byte size to convert to bytes.
540564 basePtrI64 =
541- addOffset (rewriter, loc, basePtrI64, offsets, getElemByteSize (op) );
565+ addOffset (rewriter, loc, basePtrI64, offsets, elemByteSize );
542566 }
543567 }
544- basePtrI64. dump ();
568+ // Convert base pointer (i64) to LLVM pointer type.
545569 Value basePtrLLVM =
546570 LLVM::IntToPtrOp::create (rewriter, loc, ptrTypeLLVM, basePtrI64);
547- basePtrLLVM.dump ();
548- VectorType srcOrDstVecTy = op.getValueType ();
549- VectorType srcOrDstFlatVecTy = VectorType::get (
550- srcOrDstVecTy.getNumElements (), srcOrDstVecTy.getElementType ());
571+
551572 Value maskForLane;
552573 VectorType maskVecTy = dyn_cast<VectorType>(mask.getType ());
553574 if (maskVecTy) {
575+ // Mask needs to be a scalar. A single-element vector is converted to a
576+ // scalar by the type converter.
554577 return rewriter.notifyMatchFailure (op, " Expected mask to be a scalar." );
555- } else
578+ } else {
556579 maskForLane = mask;
580+ }
557581 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
558- scf::IfOp ifOp = scf::IfOp::create (rewriter, loc, {srcOrDstVecTy },
582+ scf::IfOp ifOp = scf::IfOp::create (rewriter, loc, {valOrResTy },
559583 maskForLane, true , true );
584+ // If mask is true (then clause), load from memory and yield.
560585 rewriter.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
586+ if (!hasScalarVal)
587+ valOrResTy = VectorType::get ({valOrResVecTy.getNumElements ()},
588+ valOrResVecTy.getElementType ());
561589 Value loaded =
562- LLVM::LoadOp::create (rewriter, loc, srcOrDstFlatVecTy, basePtrLLVM);
590+ LLVM::LoadOp::create (rewriter, loc, valOrResTy, basePtrLLVM);
591+ // Set cache control attribute on the load operation.
563592 loaded.getDefiningOp ()->setAttr (
564593 " cache_control" , xevm::LoadCacheControlAttr::get (
565594 ctxt, translateLoadXeGPUCacheHint (
566595 op.getL1Hint (), op.getL3Hint ())));
567- if (srcOrDstVecTy != srcOrDstFlatVecTy) {
568- loaded =
569- vector::ShapeCastOp::create (rewriter, loc, srcOrDstVecTy, loaded);
570- }
571596 scf::YieldOp::create (rewriter, loc, ValueRange{loaded});
572597 rewriter.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
573- // If mask is false, we yield a vector of zeros.
574- auto eTy = srcOrDstVecTy.getElementType ();
575- loaded = arith::ConstantOp::create (
576- rewriter, loc,
577- eTy.isFloat ()
578- ? DenseElementsAttr::get (srcOrDstVecTy, FloatAttr::get (eTy, 0.0 ))
579- : DenseElementsAttr::get (srcOrDstVecTy,
580- IntegerAttr::get (eTy, 0 )));
598+ // If mask is false (else clause), yield a zero constant of the result type.
599+ auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType ();
600+ TypedAttr eVal;
601+ if (eTy.isFloat ())
602+ eVal = FloatAttr::get (eTy, 0.0 );
603+ else
604+ eVal = IntegerAttr::get (eTy, 0 );
605+ if (hasScalarVal)
606+ loaded = arith::ConstantOp::create (rewriter, loc, eVal);
607+ else
608+ loaded = arith::ConstantOp::create (
609+ rewriter, loc, DenseElementsAttr::get (valOrResVecTy, eVal));
581610 scf::YieldOp::create (rewriter, loc, ValueRange{loaded});
582611 rewriter.replaceOp (op, ifOp.getResult (0 ));
583612 } else {
613+ // If mask is true, perform the store.
584614 scf::IfOp ifOp = scf::IfOp::create (rewriter, loc, maskForLane, false );
585615 auto body = ifOp.getBody ();
586616 rewriter.setInsertionPointToStart (body);
587- VectorType valTy = op.getValue ().getType ();
588- Value srcFlatVec = op.getValue ();
589- if (valTy != srcOrDstFlatVecTy) {
590- srcFlatVec = vector::ShapeCastOp::create (rewriter, loc,
591- srcOrDstFlatVecTy, srcFlatVec);
592- }
593617 auto storeOp =
594- LLVM::StoreOp::create (rewriter, loc, srcFlatVec, basePtrLLVM);
618+ LLVM::StoreOp::create (rewriter, loc, adaptor.getValue (), basePtrLLVM);
619+ // Set cache control attribute on the store operation.
595620 storeOp.getOperation ()->setAttr (
596621 " cache_control" , xevm::StoreCacheControlAttr::get (
597622 ctxt, translateStoreXeGPUCacheHint (
@@ -610,27 +635,64 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
610635 auto loc = op.getLoc ();
611636 auto ctxt = rewriter.getContext ();
612637 auto tdescTy = op.getTensorDescType ();
613- auto ptrTypeLLVM = LLVM::LLVMPointerType::get (
614- ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
615638 Value basePtrI64 = adaptor.getSource ();
616- Value offsets = adaptor. getOffsets ();
639+ // Base pointer is passed as i32 or i64 by the adaptor; cast to i64 if needed.
617640 if (basePtrI64.getType () != rewriter.getI64Type ()) {
618- basePtrI64 = arith::IndexCastUIOp ::create (rewriter, loc, rewriter.getI64Type (),
619- basePtrI64);
641+ basePtrI64 = arith::ExtUIOp ::create (rewriter, loc, rewriter.getI64Type (),
642+ basePtrI64);
620643 }
644+ Value offsets = adaptor.getOffsets ();
621645 if (offsets) {
622646 VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType ());
623647 if (offsetsVecTy) {
624648 // Offset needs be scalar.
625649 return rewriter.notifyMatchFailure (op,
626650 " Expected offsets to be a scalar." );
627651 } else {
652+ int64_t elemBitWidth{0 };
653+ int64_t elemByteSize;
654+ // Element byte size can come from three sources:
655+ if (tdescTy) {
656+ // If tensor descriptor is available, we use its element type to
657+ // determine element byte size.
658+ elemBitWidth = tdescTy.getElementType ().getIntOrFloatBitWidth ();
659+ } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType ())) {
660+ // If memref is available, we use its element type to
661+ // determine element byte size.
662+ elemBitWidth = memRefTy.getElementType ().getIntOrFloatBitWidth ();
663+ } else {
664+ // Otherwise, we use the provided offset byte alignment.
665+ elemByteSize = *op.getOffsetAlignByte ();
666+ }
667+ if (elemBitWidth != 0 ) {
668+ if (elemBitWidth % 8 != 0 ) {
669+ return rewriter.notifyMatchFailure (
670+ op, " Expected element type bit width to be multiple of 8." );
671+ }
672+ elemByteSize = elemBitWidth / 8 ;
673+ }
628674 basePtrI64 =
629- addOffset (rewriter, loc, basePtrI64, offsets, getElemByteSize (op) );
675+ addOffset (rewriter, loc, basePtrI64, offsets, elemByteSize );
630676 }
631677 }
678+ // Default memory space is global.
679+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get (
680+ ctxt, getNumericXeVMAddrSpace (xegpu::MemorySpace::Global));
681+ // If tensor descriptor is available, we use its memory space.
682+ if (tdescTy) {
683+ ptrTypeLLVM = LLVM::LLVMPointerType::get (
684+ ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
685+ }
686+ // If source is a memref, we use its memory space.
687+ if (auto memRefTy = dyn_cast<MemRefType>(op.getSource ().getType ())) {
688+ auto addrSpace = memRefTy.getMemorySpaceAsInt ();
689+ if (addrSpace != 0 )
690+ ptrTypeLLVM = LLVM::LLVMPointerType::get (ctxt, addrSpace);
691+ }
692+ // Convert base pointer (i64) to LLVM pointer type.
632693 Value ptrLLVM =
633694 LLVM::IntToPtrOp::create (rewriter, loc, ptrTypeLLVM, basePtrI64);
695+ // Create the prefetch op with cache control attribute.
634696 xevm::PrefetchOp::create (
635697 rewriter, loc, ptrLLVM,
636698 xevm::LoadCacheControlAttr::get (
@@ -863,17 +925,17 @@ struct ConvertXeGPUToXeVMPass
863925 });
864926
865927 auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
866- ValueRange inputs,
867- Location loc) -> Value {
928+ ValueRange inputs,
929+ Location loc) -> Value {
868930 if (inputs.size () != 1 )
869931 return {};
870932 auto input = inputs.front ();
871933 if (auto memrefTy = dyn_cast<MemRefType>(input.getType ())) {
872934
873- Value addr = memref::ExtractAlignedPointerAsIndexOp::create (
874- builder, loc, input);
875- return arith::IndexCastUIOp::create (builder, loc, type,
876- addr) .getResult ();
935+ Value addr =
936+ memref::ExtractAlignedPointerAsIndexOp::create ( builder, loc, input);
937+ return arith::IndexCastUIOp::create (builder, loc, type, addr)
938+ .getResult ();
877939 }
878940 return {};
879941 };
@@ -888,7 +950,8 @@ struct ConvertXeGPUToXeVMPass
888950 Value cast =
889951 index::CastUOp::create (builder, loc, builder.getIndexType (), input)
890952 .getResult ();
891- return arith::IndexCastUIOp::create (builder, loc, type, cast).getResult ();
953+ return arith::IndexCastUIOp::create (builder, loc, type, cast)
954+ .getResult ();
892955 }
893956 return {};
894957 };
@@ -903,7 +966,8 @@ struct ConvertXeGPUToXeVMPass
903966 Value cast =
904967 index::CastUOp::create (builder, loc, builder.getIndexType (), input)
905968 .getResult ();
906- return arith::IndexCastUIOp::create (builder, loc, type, cast).getResult ();
969+ return arith::IndexCastUIOp::create (builder, loc, type, cast)
970+ .getResult ();
907971 }
908972 return {};
909973 };
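These materialization casts bridge unconverted memref and index-like integer values to the i64 addresses the converted ops expect. A minimal sketch of how such callbacks are typically registered with the type converter; the registration calls and the names of the two integer casts are assumptions (the actual hookup sits outside this hunk, and the pass may use addTargetMaterialization instead):

typeConverter.addSourceMaterialization(memrefMaterializationCast);
typeConverter.addSourceMaterialization(ui64MaterializationCast); // hypothetical name
typeConverter.addSourceMaterialization(ui32MaterializationCast); // hypothetical name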