77// ===----------------------------------------------------------------------===//
88
99#include " mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
10+ #include " mlir/Dialect/LLVMIR/LLVMTypes.h"
1011#include " mlir/Dialect/LLVMIR/XeVMDialect.h"
1112
1213#include " mlir/Conversion/LLVMCommon/Pattern.h"
@@ -426,18 +427,6 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
426427 }
427428};
428429
429- template <
430- typename OpType,
431- typename = std::enable_if_t <llvm::is_one_of<
432- OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp, xegpu::CreateDescOp,
433- xegpu::UpdateOffsetOp, xegpu::PrefetchOp>::value>>
434- int64_t getElemByteSize (OpType op) {
435- // Get the element byte size from the tensor descriptor.
436- auto elemBitWidth =
437- op.getTensorDesc ().getType ().getElementType ().getIntOrFloatBitWidth ();
438- return elemBitWidth / 8 ;
439- }
440-
441430// Add a builder that creates
442431// offset * elemByteSize + baseAddr
443432auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
@@ -456,23 +445,23 @@ class CreateDescToXeVMPattern
456445 LogicalResult
457446 matchAndRewrite (xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
458447 ConversionPatternRewriter &rewriter) const override {
448+ auto eTy = op.getTensorDescType ().getElementType ();
449+ if (eTy.getIntOrFloatBitWidth () % 8 != 0 ) {
450+ return rewriter.notifyMatchFailure (op,
451+ " Expected element type bit width to be multiple of 8." );
452+ }
459453 auto loc = op.getLoc ();
454+ // offsets are provided as scalar i64 by type converter.
460455 auto offsets = adaptor.getOffsets ();
461- // Source type can be a 1D memref or ui64
462- // Using "op" instead of "adaptor" since we want to access memref type
463- // instead of LLVM struct type.
464- auto memrefTy = dyn_cast<MemRefType>(op.getSource ().getType ());
465- Value subGroupAddr;
466- if (memrefTy) {
467- subGroupAddr = memref::ExtractAlignedPointerAsIndexOp::create (
468- rewriter, loc, op.getSource ());
469- subGroupAddr = arith::IndexCastUIOp::create (
470- rewriter, loc, rewriter.getI64Type (), subGroupAddr);
471- } else {
472- subGroupAddr = adaptor.getSource ();
473- }
456+ // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
457+ // But type converter will convert them to integer types.
458+ Value addr = adaptor.getSource ();
459+ // ui32 or i32 are passed as i32 so they need to be casted to i64.
460+ if (addr.getType () != rewriter.getI64Type ())
461+ addr = arith::IndexCastUIOp::create (
462+ rewriter, loc, rewriter.getI64Type (), addr);
474463 auto laneAddr =
475- addOffset (rewriter, loc, subGroupAddr , offsets, getElemByteSize (op));
464+ addOffset (rewriter, loc, addr , offsets, getElemByteSize (op));
476465 rewriter.replaceOp (op, laneAddr);
477466 return success ();
478467 }
@@ -485,11 +474,18 @@ class UpdateOffsetToXeVMPattern
485474 matchAndRewrite (xegpu::UpdateOffsetOp op,
486475 xegpu::UpdateOffsetOp::Adaptor adaptor,
487476 ConversionPatternRewriter &rewriter) const override {
477+ auto eTy = op.getTensorDescType ().getElementType ();
478+ if (eTy.getIntOrFloatBitWidth () % 8 != 0 ) {
479+ return rewriter.notifyMatchFailure (op,
480+ " Expected element type bit width to be multiple of 8." );
481+ }
488482 auto loc = op.getLoc ();
489- Value newOffsetForLane =
483+ // scatter descriptor is provided as scalar i64 by type converter.
484+ // offsets are provided as scalar i64 by type converter.
485+ Value newOffset =
490486 addOffset (rewriter, loc, adaptor.getTensorDesc (), adaptor.getOffsets (),
491487 getElemByteSize (op));
492- rewriter.replaceOp (op, newOffsetForLane );
488+ rewriter.replaceOp (op, newOffset );
493489 return success ();
494490 }
495491};
@@ -505,19 +501,38 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
505501 auto loc = op.getLoc ();
506502 auto ctxt = rewriter.getContext ();
507503 auto tdescTy = op.getTensorDescType ();
508- auto ptrTypeLLVM = LLVM::LLVMPointerType::get (
509- ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
// Default to the global address space; refined below when a tensor
// descriptor or an address-space-annotated memref is available.
504+ LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get (
505+ ctxt, getNumericXeVMAddrSpace (xegpu::MemorySpace::Global));
506+ if (tdescTy)
507+ ptrTypeLLVM = LLVM::LLVMPointerType::get (
508+ ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
510509 Value basePtrI64;
511510 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
512511 basePtrI64 = adaptor.getSource ();
// A non-default memref address space overrides the pointer type; note this
// inspects the original op (not the adaptor) to see the memref type.
512+ if (auto memRefTy = dyn_cast<MemRefType>(op.getSource ().getType ())) {
513+ auto addrSpace = memRefTy.getMemorySpaceAsInt ();
514+ if (addrSpace != 0 )
515+ ptrTypeLLVM = LLVM::LLVMPointerType::get (ctxt, addrSpace);
516+ }
513517 } else {
514518 basePtrI64 = adaptor.getDest ();
519+ if (auto memRefTy = dyn_cast<MemRefType>(op.getDest ().getType ())) {
520+ auto addrSpace = memRefTy.getMemorySpaceAsInt ();
521+ if (addrSpace != 0 )
522+ ptrTypeLLVM = LLVM::LLVMPointerType::get (ctxt, addrSpace);
523+ }
515524 }
// NOTE(review): arith.index_castui requires an index-typed operand or
// result; if basePtrI64 is i32 here this fails verification -- confirm
// the materialized type, arith.extui may be intended.
525+ if (basePtrI64.getType () != rewriter.getI64Type ()) {
526+ basePtrI64 = arith::IndexCastUIOp::create (rewriter, loc, rewriter.getI64Type (),
527+ basePtrI64);
528+ }
// NOTE(review): leftover debug output -- all .dump() calls in this method
// (basePtrI64, offsets, mask, basePtrLLVM) print IR to stderr on every
// pattern application and must be removed before merging.
529+ basePtrI64.dump ();
516530 Value offsets = adaptor.getOffsets ();
// NOTE(review): offsets.dump() runs BEFORE the `if (offsets)` null check
// below -- dumping a null Value crashes. Same concern applies to mask if
// it can be absent.
531+ offsets.dump ();
517532 Value mask = adaptor.getMask ();
533+ mask.dump ();
518534 if (offsets) {
519- VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType ());
520- if (offsetsVecTy) {
// NOTE(review): dyn_cast result is discarded -- isa<VectorType>(...) is
// the idiomatic form here.
535+ if (dyn_cast<VectorType>(offsets.getType ())){
521536 // Offset needs be scalar.
522537 return rewriter.notifyMatchFailure (op,
523538 " Expected offsets to be a scalar." );
@@ -526,8 +541,10 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
526541 addOffset (rewriter, loc, basePtrI64, offsets, getElemByteSize (op));
527542 }
528543 }
544+ basePtrI64.dump ();
529545 Value basePtrLLVM =
530546 LLVM::IntToPtrOp::create (rewriter, loc, ptrTypeLLVM, basePtrI64);
547+ basePtrLLVM.dump ();
531548 VectorType srcOrDstVecTy = op.getValueType ();
532549 VectorType srcOrDstFlatVecTy = VectorType::get (
533550 srcOrDstVecTy.getNumElements (), srcOrDstVecTy.getElementType ());
@@ -597,6 +614,10 @@ class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
597614 ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
598615 Value basePtrI64 = adaptor.getSource ();
599616 Value offsets = adaptor.getOffsets ();
// Widen a 32-bit materialized base pointer to i64, mirroring the
// LoadStore pattern.
// NOTE(review): arith.index_castui requires an index-typed operand or
// result; if basePtrI64 is i32 here this fails verification -- confirm
// the materialized type, arith.extui may be intended.
617+ if (basePtrI64.getType () != rewriter.getI64Type ()) {
618+ basePtrI64 = arith::IndexCastUIOp::create (rewriter, loc, rewriter.getI64Type (),
619+ basePtrI64);
620+ }
600621 if (offsets) {
601622 VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType ());
602623 if (offsetsVecTy) {
@@ -836,6 +857,26 @@ struct ConvertXeGPUToXeVMPass
836857 auto i32Type = IntegerType::get (&getContext (), 32 );
837858 return VectorType::get (8 , i32Type);
838859 });
// NOTE(review): converting every MemRefType to i64 discards memory-space
// information at the type level; downstream patterns must recover the
// address space from the original op's memref type -- confirm all do.
860+ typeConverter.addConversion ([&](MemRefType type) -> Type {
861+ // Convert MemRefType to i64 type.
862+ return IntegerType::get (&getContext (), 64 );
863+ });
864+
// Materializes a memref as a raw i64 address.
// NOTE(review): extract_aligned_pointer_as_index yields the aligned base
// pointer and ignores any non-zero memref offset -- confirm sources are
// always zero-offset.
865+ auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
866+ ValueRange inputs,
867+ Location loc) -> Value {
868+ if (inputs.size () != 1 )
869+ return {};
870+ auto input = inputs.front ();
871+ if (auto memrefTy = dyn_cast<MemRefType>(input.getType ())) {
872+
873+ Value addr = memref::ExtractAlignedPointerAsIndexOp::create (
874+ builder, loc, input);
875+ return arith::IndexCastUIOp::create (builder, loc, type,
876+ addr).getResult ();
877+ }
878+ return {};
879+ };
839880
// Materializes a ui64 value as the converted integer type via an index
// round-trip. IndexCastOp was changed to IndexCastUIOp here: addresses are
// unsigned, so zero-extension (not sign-extension) is the correct widening.
840881 auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
841882 ValueRange inputs,
@@ -847,7 +888,22 @@ struct ConvertXeGPUToXeVMPass
847888 Value cast =
848889 index::CastUOp::create (builder, loc, builder.getIndexType (), input)
849890 .getResult ();
850- return arith::IndexCastOp::create (builder, loc, type, cast).getResult ();
891+ return arith::IndexCastUIOp::create (builder, loc, type, cast).getResult ();
892+ }
893+ return {};
894+ };
895+
// NOTE(review): this lambda is identical to ui64MaterializationCast except
// for the 32-bit unsigned type check -- consider one helper parameterized
// on bit width to avoid the duplication.
896+ auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
897+ ValueRange inputs,
898+ Location loc) -> Value {
899+ if (inputs.size () != 1 )
900+ return {};
901+ auto input = inputs.front ();
902+ if (input.getType () == builder.getIntegerType (32 , false )) {
903+ Value cast =
904+ index::CastUOp::create (builder, loc, builder.getIndexType (), input)
905+ .getResult ();
906+ return arith::IndexCastUIOp::create (builder, loc, type, cast).getResult ();
851907 }
852908 return {};
853909 };
@@ -864,15 +920,19 @@ struct ConvertXeGPUToXeVMPass
864920 Value cast =
865921 vector::ExtractOp::create (builder, loc, input, 0 ).getResult ();
866922 if (vecTy.getElementType () == builder.getIndexType ())
867- cast = arith::IndexCastOp ::create (builder, loc, type, cast)
923+ cast = arith::IndexCastUIOp ::create (builder, loc, type, cast)
868924 .getResult ();
869925 return cast;
870926 }
871927 }
872928 return {};
873929 };
// Register each materialization for both directions (source and target).
// NOTE(review): ui64MaterializationCast is missing from no direction here,
// but ordering differs between the source and target registrations --
// harmless if the casts are mutually exclusive, which their type guards
// suggest.
930+ typeConverter.addSourceMaterialization (memrefMaterializationCast);
874931 typeConverter.addSourceMaterialization (ui64MaterializationCast);
932+ typeConverter.addSourceMaterialization (ui32MaterializationCast);
875933 typeConverter.addSourceMaterialization (vector1DMaterializationCast);
934+ typeConverter.addTargetMaterialization (memrefMaterializationCast);
935+ typeConverter.addTargetMaterialization (ui32MaterializationCast);
876936 typeConverter.addTargetMaterialization (ui64MaterializationCast);
877937 typeConverter.addTargetMaterialization (vector1DMaterializationCast);
878938 ConversionTarget target (getContext ());
0 commit comments