|
14 | 14 | #include "triton/Dialect/Triton/IR/Utility.h" |
15 | 15 | #include "triton/Dialect/TritonGPU/IR/Dialect.h" |
16 | 16 | #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" |
| 17 | +#include "triton/Dialect/TritonGPU/IR/Types.h" |
17 | 18 | #include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" |
18 | 19 | #include "triton/Tools/LinearLayout.h" |
19 | 20 | #include "triton/Tools/StrUtil.h" |
@@ -1141,8 +1142,8 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, |
1141 | 1142 | // |
1142 | 1143 | // Returns true on success. |
1143 | 1144 | [[nodiscard]] bool emitTransferBetweenRegistersAndShared( |
1144 | | - RankedTensorType registerTy, MemDescType sharedTy, Type elemLlvmTy, |
1145 | | - std::optional<int32_t> maxVecElems, Value shmemBase, |
| 1145 | + RankedTensorType registerTy, triton::gpu::MemDescType sharedTy, |
| 1146 | + Type elemLlvmTy, std::optional<int32_t> maxVecElems, Value shmemBase, |
1146 | 1147 | ArrayRef<Value> shmemStrides, Location loc, RewriterBase &rewriter, |
1147 | 1148 | const TargetInfoBase &target, |
1148 | 1149 | std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback); |
@@ -1310,13 +1311,14 @@ inline DenseMap<unsigned, Value> getSwizzledSharedPtrs( |
1310 | 1311 | } |
1311 | 1312 |
|
1312 | 1313 | SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy, |
1313 | | - MemDescType srcTy, Type elemLlvmTy, |
| 1314 | + triton::gpu::MemDescType srcTy, |
| 1315 | + Type elemLlvmTy, |
1314 | 1316 | SharedMemoryObject smemObj, |
1315 | 1317 | Location loc, RewriterBase &rewriter, |
1316 | 1318 | const TargetInfoBase &target); |
1317 | 1319 |
|
1318 | 1320 | void storeDistributedToShared( |
1319 | | - MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, |
| 1321 | + triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy, |
1320 | 1322 | ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides, |
1321 | 1323 | Location loc, RewriterBase &rewriter, const TargetInfoBase &target, |
1322 | 1324 | std::pair<size_t, Type> *const llvmOpCount = nullptr); |
|
0 commit comments