|
32 | 32 |
|
33 | 33 | #include <numeric> |
34 | 34 |
|
| 35 | +#define DEBUG_TYPE "xegpu-to-xevm" |
| 36 | + |
35 | 37 | namespace mlir { |
36 | 38 | #define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS |
37 | 39 | #include "mlir/Conversion/Passes.h.inc" |
@@ -60,6 +62,9 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) { |
60 | 62 | return static_cast<int>(xevm::AddrSpace::GLOBAL); |
61 | 63 | case xegpu::MemorySpace::SLM: |
62 | 64 | return static_cast<int>(xevm::AddrSpace::SHARED); |
| 65 | + default: |
| 66 | + llvm_unreachable("Unknown XeGPU memory space"); |
| 67 | + return static_cast<int>(xevm::AddrSpace::GLOBAL); |
63 | 68 | } |
64 | 69 | } |
65 | 70 |
|
@@ -366,6 +371,7 @@ static Value addOffset(ConversionPatternRewriter &rewriter, Location loc, |
366 | 371 | Value baseAddr, Value offset, int64_t elemByteSize) { |
367 | 372 | Value byteSize = arith::ConstantIntOp::create( |
368 | 373 | rewriter, loc, rewriter.getI64Type(), elemByteSize); |
| 374 | + offset = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), offset); |
369 | 375 | Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize); |
370 | 376 | Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset); |
371 | 377 | return newAddr; |
@@ -503,6 +509,113 @@ class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> { |
503 | 509 | } |
504 | 510 | }; |
505 | 511 |
|
| 512 | +// Lower xegpu::CreateMemDescOp to memref::ViewOp. Since SLM access instructions |
| 513 | +// on Xe2 and Xe3 operate on 32-bit or 64-bit units, all data types smaller than |
| 514 | +// 32 bits will be converted to 32 bits. |
| 515 | +class CreateMemDescOpPattern final |
| 516 | + : public OpConversionPattern<xegpu::CreateMemDescOp> { |
| 517 | +public: |
| 518 | + using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern; |
| 519 | + LogicalResult |
| 520 | + matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor, |
| 521 | + ConversionPatternRewriter &rewriter) const override { |
| 522 | + // DEBUG: Print operation and types |
| 523 | + LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Lowering CreateMemDescOp: " << op << "\n"); |
| 524 | + TypedValue<MemRefType> src = op.getSource(); |
| 525 | + auto resTy = cast<xegpu::MemDescType>(op.getResult().getType()); |
| 526 | + |
| 527 | + // Create the result MemRefType with the same shape, element type, and memory space |
| 528 | + auto newResTy = getTypeConverter()->convertType<MemRefType>(resTy); |
| 529 | + |
| 530 | + LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Source MemRefType: " << src.getType() << "\n"); |
| 531 | + LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Result MemDescType: " << resTy << "\n"); |
| 532 | + LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Converted MemRefType: " << newResTy << "\n"); |
| 533 | + Value zero = arith::ConstantIndexOp::create(rewriter, op.getLoc(), 0); |
| 534 | + auto viewOp = memref::ViewOp::create(rewriter, op.getLoc(), newResTy, Value(src), zero, |
| 535 | + ValueRange()); |
| 536 | + rewriter.replaceOp(op, viewOp); |
| 537 | + LLVM_DEBUG(llvm::dbgs() << "[XeGPUToXeVM] Replaced CreateMemDescOp with memref::ViewOp\n"); |
| 538 | + return success(); |
| 539 | + } |
| 540 | +}; |
| 541 | + |
| 542 | +class MemDescSubviewOpPattern final |
| 543 | + : public OpConversionPattern<xegpu::MemDescSubviewOp> { |
| 544 | +public: |
| 545 | + using OpConversionPattern<xegpu::MemDescSubviewOp>::OpConversionPattern; |
| 546 | + LogicalResult |
| 547 | + matchAndRewrite(xegpu::MemDescSubviewOp op, OpAdaptor adaptor, |
| 548 | + ConversionPatternRewriter &rewriter) const override { |
| 549 | + return rewriter.notifyMatchFailure( |
| 550 | + op, "MemDescSubviewOp are not supported on Xe2/Xe3 architecture."); |
| 551 | + } |
| 552 | +}; |
| 553 | + |
| 554 | + |
| 555 | +template <typename OpType, |
| 556 | + typename = std::enable_if_t<llvm::is_one_of< |
| 557 | + OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>> |
| 558 | +class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> { |
| 559 | + using OpConversionPattern<OpType>::OpConversionPattern; |
| 560 | + LogicalResult |
| 561 | + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, |
| 562 | + ConversionPatternRewriter &rewriter) const override { |
| 563 | + |
| 564 | + SmallVector<OpFoldResult> offsets = op.getMixedOffsets(); |
| 565 | + if (offsets.empty()) |
| 566 | + return rewriter.notifyMatchFailure(op, "Expected offset to be provided."); |
| 567 | + |
| 568 | + auto loc = op.getLoc(); |
| 569 | + auto ctxt = rewriter.getContext(); |
| 570 | + Value basePtrStruct = adaptor.getMemDesc(); |
| 571 | + Value mdescVal = op.getMemDesc(); |
| 572 | + // Load result or Store value Type can be vector or scalar. |
| 573 | + Value data; |
| 574 | + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) |
| 575 | + data = op.getResult(); |
| 576 | + else |
| 577 | + data = adaptor.getData(); |
| 578 | + VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType()); |
| 579 | + |
| 580 | + int64_t elemBitWidth = valOrResVecTy.getElementType().getIntOrFloatBitWidth(); |
| 581 | + // Element type must be multiple of 8 bits. |
| 582 | + if (elemBitWidth % 8 != 0) |
| 583 | + return rewriter.notifyMatchFailure( |
| 584 | + op, "Expected element type bit width to be multiple of 8."); |
| 585 | + int64_t elemByteSize = elemBitWidth / 8; |
| 586 | + |
| 587 | + // Default memory space is SLM. |
| 588 | + LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get( |
| 589 | + ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM)); |
| 590 | + |
| 591 | + auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType()); |
| 592 | + |
| 593 | + Value basePtrLLVM = memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, basePtrStruct); |
| 594 | + |
| 595 | + // Convert base pointer (ptr) to i64 |
| 596 | + Value basePtrI64 = arith::IndexCastUIOp::create(rewriter, loc, rewriter.getI64Type(), basePtrLLVM); |
| 597 | + |
| 598 | + Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets); |
| 599 | + basePtrI64 = addOffset(rewriter, loc, basePtrI64, linearOffset, elemByteSize); |
| 600 | + |
| 601 | + // convert base pointer (i64) to LLVM pointer type |
| 602 | + basePtrLLVM = |
| 603 | + LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64); |
| 604 | + |
| 605 | + if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) { |
| 606 | + |
| 607 | + Value loadOp = |
| 608 | + LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM); |
| 609 | + rewriter.replaceOp(op, loadOp); |
| 610 | + } else { |
| 611 | + auto storeOp = |
| 612 | + LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM); |
| 613 | + rewriter.eraseOp(op); |
| 614 | + } |
| 615 | + return success(); |
| 616 | + } |
| 617 | +}; |
| 618 | + |
506 | 619 | class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> { |
507 | 620 | using OpConversionPattern::OpConversionPattern; |
508 | 621 | LogicalResult |
@@ -785,6 +898,13 @@ struct ConvertXeGPUToXeVMPass |
785 | 898 | auto i32Type = IntegerType::get(&getContext(), 32); |
786 | 899 | return VectorType::get(8, i32Type); |
787 | 900 | }); |
| 901 | + // Convert MemDescType into flattened MemRefType for SLM |
| 902 | + typeConverter.addConversion([&](xegpu::MemDescType type) -> Type { |
| 903 | + Type elemTy = type.getElementType(); |
| 904 | + int numElems = type.getNumElements(); |
| 905 | + return MemRefType::get(numElems, elemTy, AffineMap(), 3); |
| 906 | + }); |
| 907 | + |
788 | 908 | typeConverter.addConversion([&](MemRefType type) -> Type { |
789 | 909 | // Convert MemRefType to i64 type. |
790 | 910 | return IntegerType::get(&getContext(), 64); |
@@ -919,6 +1039,10 @@ void mlir::populateXeGPUToXeVMConversionPatterns( |
919 | 1039 | LoadStoreToXeVMPattern<xegpu::LoadGatherOp>, |
920 | 1040 | LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>( |
921 | 1041 | typeConverter, patterns.getContext()); |
| 1042 | + patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>, |
| 1043 | + LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>, |
| 1044 | + CreateMemDescOpPattern, MemDescSubviewOpPattern>( |
| 1045 | + typeConverter, patterns.getContext()); |
922 | 1046 | patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter, |
923 | 1047 | patterns.getContext()); |
924 | 1048 | } |
0 commit comments