|
25 | 25 | #include "PatternTritonGPUOpToLLVM.h" |
26 | 26 | #include "mlir/Conversion/LLVMCommon/Pattern.h" |
27 | 27 | #include "mlir/Dialect/LLVMIR/NVVMDialect.h" |
| 28 | +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" |
28 | 29 | #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" |
29 | 30 |
|
30 | 31 | using namespace mlir; |
@@ -63,12 +64,55 @@ struct ClusterWaitOpConversion |
63 | 64 | return success(); |
64 | 65 | } |
65 | 66 | }; |
| 67 | + |
| 68 | +// lower MapToRemoteBufferOp |
| 69 | +struct MapToRemoteBufferOpConversion |
| 70 | + : public ConvertOpToLLVMPattern<triton::nvidia_gpu::MapToRemoteBufferOp> { |
| 71 | + using ConvertOpToLLVMPattern< |
| 72 | + triton::nvidia_gpu::MapToRemoteBufferOp>::ConvertOpToLLVMPattern; |
| 73 | + |
| 74 | + LogicalResult |
| 75 | + matchAndRewrite(triton::nvidia_gpu::MapToRemoteBufferOp op, OpAdaptor adaptor, |
| 76 | + ConversionPatternRewriter &rewriter) const override { |
| 77 | + auto loc = op.getLoc(); |
| 78 | + auto srcSmemObj = LLVM::getSharedMemoryObjectFromStruct( |
| 79 | + loc, adaptor.getSrc(), |
| 80 | + typeConverter->convertType(op.getSrc().getType().getElementType()), |
| 81 | + rewriter); |
| 82 | + auto srcSmemPtr = srcSmemObj.getBase(); |
| 83 | + |
| 84 | + auto ptrTy = cast<LLVM::LLVMPointerType>(srcSmemPtr.getType()); |
| 85 | + assert(ptrTy.getAddressSpace() == 3 && |
| 86 | + "Invalid src llvm addr space for MapToRemoteBufferOp"); |
| 87 | + |
| 88 | + // The result pointer is referring to a memory buffer living in a CTA |
| 89 | + // cluster, so it has a different memory space. NVVM::MapaOp verifies its |
| 90 | + // src and result ptr type, so we need to construct the result ptr type |
| 91 | + // from typeConverter output here |
| 92 | + LLVM::LLVMStructType convertedRetTy = |
| 93 | + cast<LLVM::LLVMStructType>(typeConverter->convertType(op.getType())); |
| 94 | + Type convertedPtrTy = convertedRetTy.getBody()[0]; |
| 95 | + |
| 96 | + // map an SMEM ptr in mem space 3 to a ptr in mem space 7 |
| 97 | + auto remotePtr = rewriter.create<NVVM::MapaOp>( |
| 98 | + loc, convertedPtrTy, srcSmemPtr, adaptor.getCtaRank()); |
| 99 | + |
| 100 | + // everything stays the same except base ptr comparing to srcSmemObj |
| 101 | + auto dstSmemObj = SharedMemoryObject( |
| 102 | + remotePtr, srcSmemObj.getBaseElemType(), srcSmemObj.getOffsets()); |
| 103 | + auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter); |
| 104 | + rewriter.replaceOp(op, retVal); |
| 105 | + return success(); |
| 106 | + } |
| 107 | +}; |
| 108 | + |
66 | 109 | } // namespace |
67 | 110 |
|
68 | 111 | void mlir::triton::NVIDIA::populateClusterOpsToLLVMPatterns( |
69 | 112 | LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, |
70 | 113 | PatternBenefit benefit) { |
71 | 114 | patterns.add<ClusterArriveOpConversion>(typeConverter, benefit); |
72 | 115 | patterns.add<ClusterWaitOpConversion>(typeConverter, benefit); |
| 116 | + patterns.add<MapToRemoteBufferOpConversion>(typeConverter, benefit); |
73 | 117 | return; |
74 | 118 | } |
0 commit comments