
Commit 711e2a9

Authored by AlexAUT, makslevental, yiqian1, antiagainst, and zhanglx13
[Cherry-Pick] gfx950 improvements and bug fixes (#846)
* [AMD] Avoid async load to pipeline for less than 32-bit load (triton-lang#7250)

  We can only use AsyncCopy if the final load width is >= 4 bytes. `triton::canBeConvertedToAsyncLoad` checks that the vecSize of the source is large enough. Additionally we need to ensure the register-to-shared layout (blocked+shared) has enough contiguous elements, since we cannot scatter into LDS. Before this PR we would abort compilation instead of falling back to pipelining through registers.

* [AMD] Pipeline small tensors w/ registers only on GFX950 (triton-lang#7171)

  Fixes a perf regression on gfx942 but preserves functionality for gfx950 (and above).

* Reland "[AMD] Optimize reduction with v_permlane intrinsics in GFX950" (triton-lang#7321)

  triton-lang#7291 fixed the LLVM issue that caused correctness problems, so we can now reland this patch.

* [Pipeliner] Expose core pipeliner utilities to reuse in AMD backend (triton-lang#7222)

  This PR exposes (via a header) core pipelining utilities/helpers:

  ```c++
  bool hasGpuBarriers(scf::ForOp forOp);
  bool isSafeToPipeline(scf::ForOp forOp);
  llvm::MapVector<Operation *, std::pair<int, Operation *>>
  loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
                            triton::ModuleAxisInfoAnalysis &axisInfoAnalysis,
                            int numStages, bool filterSmall = true);
  void scheduleDistanceOneDependencies(scf::ForOp forOp, CoarseSchedule &schedule);
  void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule,
                                    CoarseSchedule::Cluster afterPrologue);
  ```

  They are directly usable by AMD's pipeliner. Note that this is basically NFC for AMD, because AMD's pipeliner simply carried copy-pasted versions of the same functions from roughly a year ago. Small API changes:

  1. On NV we do not pipeline small loads (vec width < 32b); on AMD we do. The choice is made inside `isPipeliningBeneficial` inside `loadOpsToIndirectionLevel`. To support AMD I have added a flag `filterSmall`.
  2. On AMD the load `use`s (computed as a matter of course in `loadOpsToIndirectionLevel`) are used (no pun intended), whereas on NV they are not. To support AMD I keep those `use`s in the `llvm::MapVector<Operation *, std::pair<int, Operation *>>` returned from `loadOpsToIndirectionLevel`.

  These two small changes are the only "non-NFC" changes.

* [AMD] Retire local prefetch schedule hint variant (triton-lang#7395)

  This variant was from some prior experiments; a better way to implement it will come later.

* [AMD] Retire TritonAMDGPU_OpIdxAttr and TritonAMDGPU_InstCounter (triton-lang#7476)

  triton-lang#7395 retired the local prefetch schedule variant. This left `TritonAMDGPU_OpIdxAttr` and `TritonAMDGPU_InstCounter` unused, so they are removed by this PR.

* [AMD][NFC] Split createAndSchedule* in stream pipeliner (triton-lang#7514)

  Splits `createAndScheduleAsyncCopy` and `createAndScheduleStreamCopy` to make them reusable if we want to schedule the ops differently in a future PR.

* [AMD] Refactor StreamPipeliner to use more common functions (triton-lang#7526)

  Further refactoring of StreamPipeliner.cpp to use more common pipeliner functionality: `triton::createAllocation`, `triton::createSingleBufferView`, `triton::replaceWithSharedLoad`, plus a bit of general cleanup. Overall NFC except:
  - The order of LocalDealloc is reversed now.
  - The memdesc of the subview additionally includes the allocSize.

  Also, we had no lit test checking that the LocalLoad consumes the AsyncToken, so I adjusted one to include the check.

* [AMD] NFC: Refactor stream pipeliner to better encapsulate functions (triton-lang#7540)

  Mostly moves code around to reduce the dependencies between functions and further splits up functions doing more than one thing (`createAndSchedule*`, `preprocessAndBuildSchedule`). This will also allow us to use more common pipeliner functionality in a future PR, e.g. `createAsyncCopy`.

* [FA] Set vecSize=nonKDim for V shared layout to avoid bank conflicts

  I'll submit a PR upstream later.

* [GEMM] Add combine dot_scaled and addF

* [AMD][NFC] Consolidate initialization in initSchedule for pipeliner (triton-lang#7556)

  Moves all initializations of stages to `initSchedule`. Missed this one in the last PRs.

* [AMD] NFC: Drop version minor for AMD MFMA layout (triton-lang#7285)

  AMD's MFMA layout does not need version-minor information like NVIDIA's; it always defaults to 0 in the current codebase. The PR drops the version minor and changes to a single `version` parameter for the MFMA layout.

* [AMD] Add tilesPerWarp parameter to mfma layout (triton-lang#7283)

  This PR introduces the tilesPerWarp parameter to the MFMA layout. Previously, the MFMA layout assumed that each warp within a CTA tile computed a single MFMA tile. When the tensor was larger than a single CTA tile, these tiles were repeated across the tensor. In this setup, the output tiles computed by each wave were strided by the number of warps per CTA in both row and column dimensions. For instance, with 16 MFMA tiles and warpsPerCTA = [2, 2], the distribution of warps across the MFMA tiles looked like:

  ```
  w0 w1 w0 w1
  w2 w3 w2 w3
  w0 w1 w0 w1
  w2 w3 w2 w3
  ```

  The new tilesPerWarp parameter allows each warp to compute contiguous MFMA tiles in the row and/or column dimensions. Using the same example with tilesPerWarp = [2, 2], the layout becomes:

  ```
  w0 w0 w1 w1
  w0 w0 w1 w1
  w2 w2 w3 w3
  w2 w2 w3 w3
  ```

  While this is a general enhancement, the main motivation for introducing this parameter is to improve memory access efficiency for scale tensors in scaled dot operations. Specific patterns and use cases will be implemented in follow-up PRs.

  ---------

  Co-authored-by: Ognjen Plavsic <[email protected]>
  Co-authored-by: Lei Zhang <[email protected]>

* [AMD] Add support for pingpong GEMM using async_copy

* [BACKEND] combineRedundantWaitOps should not combine across loops/branches (triton-lang#7593)

  `combineRedundantWaitOps` skipped over branches/loops, so if we end up with something like:

  ```mlir
  ttg.async_wait
  scf.for
    ....
    scf.yield
  ttg.async_wait
  ```

  we merge the async_waits in the prologue and epilogue because we do not find a `ttg.commit_group` in between. This PR stops the forward search if we encounter a branch/loop. I can also walk through all successor blocks if we think this is worth the effort.

  This problem was not triggered before because the `ttg.async_wait` was scheduled in the same stage as its user(s), so we either ended up with no `ttg.async_wait` in the prologue or there was another prefetch after it in the prologue. Since triton-lang#7458 we might place the `ttg.async_wait` in the previous stage compared to its user(s), so we might end up with the problematic IR.

* [AMD][NFC] Group scheduling functions in StreamPipeliner (triton-lang#7607)

  NFC: Groups all scheduling-related functions into a namespace to prepare for additional scheduling variants.

* [AMD] Add pingpong transformation for chained dot schedule (triton-lang#7638)

  Adds support to enable pingpong for loops scheduled with the new `ChainedDotSchedule` introduced by triton-lang#7601. The schedule already places the ops in the correct order, so we just have to insert the sync ops to ensure proper pingpong'ing.

* [AMD] Fix pingpong ChainedDot for empty second memory cluster (triton-lang#7694)

  triton-lang#7638 introduced a null pointer access (during review adjustments) if the second memory cluster is empty or if there are no memory clusters at all. Added a lit test to catch it and reverted to the old logic.

* [AMD] Remove bypass permute optimization for AsyncCopy (triton-lang#7704)

  We can only bypass ds_bpermute to apply the swizzling if lanes loading the same row read a contiguous chunk of memory from HBM, which we cannot infer when lowering to LLVM. The current selection only checks whether the elements for each lane are contiguous, which is not strict enough.

* [AMD] Add ChainedDotSchedule to StreamPipeliner (triton-lang#7601)

  Adds a new scheduling variant which kicks in for loops which have 2 chained dots and `num_stages==4`. It places the two dots in consecutive stages so we can interleave operations using the result of the first dot with both dots in the loop; a pseudo example IR:

  ```
  %1 = tt.dot ...
  %2 = arith.addf %1, %arg1
  %3 = arith.subf %2, %arg2
  %4 = tt.dot %X, %Y, %3
  ```

  This could result in the following pseudo schedule (ignoring mem ops) to interleave with both dots:

  ```
  stage N,   Cluster0: [%1 = tt.dot, %3 = arith.subf]
  stage N+1, Cluster1: [%4 = tt.dot, %2 = arith.addf]
  ```

  As a first step the schedule splits the op chain between dot1 and dot2 when it encounters an operation which has more than 2 users. This aims to avoid adding too many loop-carried dependencies but does not guarantee a good work balance between the two clusters. In future PRs we might make this more sophisticated.

* [AMD] Add scale preshuffling and opSel implementation (triton-lang#7603)

  This PR implements a test kernel for efficient scale packing for the CDNA4 arch as well as opSel for scaled MFMA instructions. Scaled MFMA instructions expect scale operands as 32-bit values, even though each individual scale is only 8 bits. To reduce register usage, we pack 4 scales into a single 32-bit value and use the opSel field to select the appropriate byte during execution. Packing is done along the K dimension first. If there aren't enough values in K, we continue along the non-K dimension. (A small packing sketch follows this list.)

  ---------

  Co-authored-by: Ognjen Plavsic <[email protected]>

* [AMD] Enable Pingpong by default on gfx950 arch (triton-lang#7697)

  List of enabling conditions:
  - FP/BF16 GEMM with M,N > 64 tile size when num_stages=3 and num_warps=8
  - GEMM using `dot_scaled` with M=N=256 tile size when num_stages=2 and num_warps=8
  - FA with num_stages=4

  All of the above only when using async_copy.

* [Backend] Bump to llvm/llvm-project@570885128351 (triton-lang#7291)

  This picks up a bug fix for AMDGPU v_permlane_swap: llvm/llvm-project#144423. Without this fix, the v_permlane_swap is wrongly sunk. Along the way we need to fix API changes:
  - Add a header file for the class IRBuilder
  - Add a missing default parameter in convertFuncOpToLLVMFuncOp

  ---------

  Co-authored-by: Maksim Levental <[email protected]>
  Co-authored-by: Yi Qian <[email protected]>
  Co-authored-by: Lei Zhang <[email protected]>
  Co-authored-by: Lixun Zhang <[email protected]>
  Co-authored-by: Jungwook Park <[email protected]>
  Co-authored-by: Pengzhan Zhao <[email protected]>
  Co-authored-by: plognjen <[email protected]>
  Co-authored-by: Ognjen Plavsic <[email protected]>
  Co-authored-by: Zeng Wu <[email protected]>
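To make the opSel scale-packing step concrete, here is a minimal, hedged sketch of the packing logic only. The helper name `packScalesAlongK`, the row-major `[nonK, K]` indexing, and the shapes are illustrative assumptions; they are not the CDNA4 kernel from the PR.

```c++
#include <cstdint>
#include <vector>

// Sketch: pack four 8-bit scales into one 32-bit word, walking the K
// dimension first and continuing into the non-K dimension when K has fewer
// than four remaining entries, as described in the commit message above.
// The byte position (`lane`) is what opSel later selects at execution time.
std::vector<uint32_t> packScalesAlongK(const std::vector<uint8_t> &scales,
                                       int nonK, int K) {
  std::vector<uint32_t> packed;
  packed.reserve((static_cast<size_t>(nonK) * K + 3) / 4);
  uint32_t word = 0;
  int lane = 0;
  for (int i = 0; i < nonK; ++i) {
    for (int k = 0; k < K; ++k) {
      // Row-major [nonK, K] indexing; place the scale in byte `lane`.
      word |= static_cast<uint32_t>(scales[i * K + k]) << (8 * lane);
      if (++lane == 4) {
        packed.push_back(word);
        word = 0;
        lane = 0;
      }
    }
  }
  if (lane != 0) // Flush a partially filled word (zero-padded).
    packed.push_back(word);
  return packed;
}
```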
1 parent 5e56853 commit 711e2a9

File tree

80 files changed · +3301 / -2063 lines changed


cmake/llvm-hash.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -1 +1 @@
-8957e64a20fc7f4277565c6cfe3e555c119783ce
+570885128351868c1308bb22e8ca351d318bc4a1
```

include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h

Lines changed: 0 additions & 5 deletions
```diff
@@ -97,11 +97,6 @@ class TargetInfoBase {
   virtual bool supportLdMatrix() const { return false; }
   virtual bool supportStMatrix() const { return false; }

-  // Annotate target specific information to local store operations during
-  // lowering to LLVM.
-  virtual void localStoreOpAnnotation(triton::gpu::LocalStoreOp op,
-                                      size_t localStoreOpCount,
-                                      Type type) const {}
   // Annotate target specific information to local load operations during
   // lowering to LLVM. `llLoadOp` is the generated LLVM load op.
   virtual void localLoadOpAnnotation(triton::gpu::LocalLoadOp localLoadOp,
```

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 6 additions & 5 deletions
```diff
@@ -556,11 +556,12 @@ SmallVector<Value> loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp,
                                            Location loc, RewriterBase &rewriter,
                                            const TargetInfoBase &target);

-void storeDistributedToShared(
-    triton::gpu::MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
-    ArrayRef<Value> srcVals, const SharedMemoryObject &smemObj, Location loc,
-    RewriterBase &rewriter, const TargetInfoBase &target,
-    std::pair<size_t, Type> *const llvmOpCount = nullptr);
+void storeDistributedToShared(triton::gpu::MemDescType dstTy,
+                              RankedTensorType srcTy, Type elemLlvmTy,
+                              ArrayRef<Value> srcVals,
+                              const SharedMemoryObject &smemObj, Location loc,
+                              RewriterBase &rewriter,
+                              const TargetInfoBase &target);

 SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
                                     RewriterBase &rewriter);
```

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 5 additions & 4 deletions
```diff
@@ -282,10 +282,11 @@ LinearLayout getTmemLoadLayoutSplitLongM(int M, int N, RankedTensorType oldType,
                                          int numWarps);

 // Create LinearLayout for scale in scaled mfma.
-LinearLayout chooseScaledMfmaScaleLayout(
-    MLIRContext *ctx, int dotOperandIdx,
-    const std::vector<std::vector<int32_t>> &dotOperandWarpBasis,
-    ArrayRef<int64_t> dotOperandShape, unsigned mfmaMDim);
+LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                         ArrayRef<int64_t> dotOperandShape,
+                                         unsigned mfmaMDim,
+                                         ArrayRef<unsigned> tilesPerWarp,
+                                         ArrayRef<unsigned> warpsPerCTA);

 // Create LinearLayout for nvidia mma tile.
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
```
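A hedged sketch of a call site against the new signature: the old `dotOperandWarpBasis` argument is replaced by `tilesPerWarp`/`warpsPerCTA` taken from the MFMA encoding. The concrete values and the in-scope `MLIRContext *ctx` are illustrative assumptions.

```c++
// Sketch only: build the scale layout for operand A of a scaled MFMA dot.
SmallVector<unsigned> tilesPerWarp{2, 2};
SmallVector<unsigned> warpsPerCTA{2, 2};
SmallVector<int64_t> scaleShape{256, 8}; // placeholder scale-tensor shape
LinearLayout scaleLayout = chooseScaledMfmaScaleLayout(
    ctx, /*dotOperandIdx=*/0, scaleShape, /*mfmaMDim=*/32, tilesPerWarp,
    warpsPerCTA);
```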

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 44 additions & 7 deletions
```diff
@@ -919,11 +919,11 @@ An encoding for tensors that have been produced by MFMA matrix core instructions
 available on AMD Instinct GPUs of CDNA architectures.

 It is characterized by the following parameters:
-- `versionMajor` and `versionMinor` indicates the GPU architecture:
-  - 1.0: gfx908, i.e. CDNA1
-  - 2.0: gfx90a: i.e. CDNA2
-  - 3.0: gfx942: CDNA3
-  - 4.0: gfx950: CDNA4
+- `version` indicates the GPU architecture:
+  - 1: gfx908: CDNA1
+  - 2: gfx90a: CDNA2
+  - 3: gfx942: CDNA3
+  - 4: gfx950: CDNA4
 - `warpsPerCTA` indicates the warp layout in the block.
 - `MDim` and `NDim` indicate the dimension of the output of the mfma instruction.
 - `isTransposed` indicates the result tensor is transposed so that it can be converted to dotOperand layout
@@ -1009,24 +1009,61 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
   [ 64,68...124 65,69...125 66,70...126 67,71...127 ]  [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
   [ 64,68...124 65,69...125 66,70...126 67,71...127 ]  [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
   [ 64,68...124 65,69...125 66,70...126 67,71...127 ]  [ 192,196...252 193,197...253 194,198...254 195,199...255 ]
+
+Example 4:
+This example demonstrates semantics of tilesPerWarp parameter. The MFMA layout (with tilesPerWarp=[1,1])
+assumes that each warp within a CTA tile computes a single MFMA tile. When the tensor is larger than
+a single CTA tile, these tiles are repeated across the tensor. In this setup, the output tiles computed
+by each wave were strided by the number of warps per CTA tile in both row and column dimensions.
+
+For instance, with 16 MFMA tiles and warpsPerCTA = [2, 2], the distribution of warps across the MFMA
+tiles looked like:
+
+  w0 w1 w0 w1
+  w2 w3 w2 w3
+  w0 w1 w0 w1
+  w2 w3 w2 w3
+
+tilesPerWarp parameter allows each warp to compute contiguous MFMA tiles in the row and/or column dimensions.
+Using the same example with tilesPerWarp = [2, 2], the layout becomes:
+
+  w0 w0 w1 w1
+  w0 w0 w1 w1
+  w2 w2 w3 w3
+  w2 w2 w3 w3
 }];

 let parameters = (
   ins
-  "unsigned": $versionMajor,
-  "unsigned": $versionMinor,
+  "unsigned": $version,
   ArrayRefParameter<"unsigned">:$warpsPerCTA,
+  ArrayRefParameter<"unsigned">:$tilesPerWarp,
   "unsigned":$MDim,
   "unsigned":$NDim,
   "bool":$isTransposed,
   "CTALayoutAttr":$CTALayout
 );

+let builders = [
+  AttrBuilder<(ins "unsigned":$version,
+                   "ArrayRef<unsigned>":$warpsPerCTA,
+                   "unsigned":$MDim,
+                   "unsigned":$NDim,
+                   "bool":$isTransposed,
+                   "CTALayoutAttr":$CTALayout), [{
+    SmallVector<unsigned> tilesPerWarp(warpsPerCTA.size(), 1);
+    return $_get(context, version, warpsPerCTA, tilesPerWarp, MDim, NDim, isTransposed, CTALayout);
+  }]>
+];
+
 let extraClassDeclaration = extraDistributedDeclaration # [{
   SmallVector<int64_t> getInstrShapeForOperand(int kWidth, int opIdx) const;
   SmallVector<int64_t> getRepForOperand(ArrayRef<int64_t> operandShape, int kWidth, int opIdx) const;
   SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;

+  // Check if tilesPerWarp is 1 in every dimension.
+  bool hasUnitTilesPerWarp() const;
+
   // Returns a swizzled shared layout matching this MFMA layout for the
   // dot operand at the given |operandIdx| with |operandShape|.
   SwizzledSharedEncodingAttr composeSharedLayoutForOperand(
```
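For illustration, a hedged sketch of how the two entry points differ after this change. It assumes the generated attribute class is named `AMDMfmaEncodingAttr` and that an `MLIRContext *ctx` and a `CTALayoutAttr ctaLayout` are in scope; the values are placeholders.

```c++
// Sketch: the new 6-parameter builder defaults tilesPerWarp to all-ones,
// matching the old behaviour of one MFMA tile per warp.
SmallVector<unsigned> warpsPerCTA{2, 2};
auto mfmaDefault = AMDMfmaEncodingAttr::get(ctx, /*version=*/4, warpsPerCTA,
                                            /*MDim=*/32, /*NDim=*/32,
                                            /*isTransposed=*/true, ctaLayout);

// Sketch: passing tilesPerWarp explicitly gives each warp a contiguous
// 2x2 block of MFMA tiles, as in Example 4 above.
SmallVector<unsigned> tilesPerWarp{2, 2};
auto mfmaTiled = AMDMfmaEncodingAttr::get(ctx, /*version=*/4, warpsPerCTA,
                                          tilesPerWarp, /*MDim=*/32,
                                          /*NDim=*/32, /*isTransposed=*/true,
                                          ctaLayout);
```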

include/triton/Dialect/TritonGPU/Transforms/Schedule.h

Lines changed: 15 additions & 0 deletions
```diff
@@ -4,6 +4,7 @@
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/Support/LLVM.h"
+#include "triton/Analysis/AxisInfo.h"
 #include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h"
 #include "llvm/ADT/ArrayRef.h"
 #include <list>
@@ -17,6 +18,13 @@ namespace gpu {
 /// Lower the loops to prepare them for pipeline expansion.
 void lowerLoops(ModuleOp moduleOp);

+bool hasGpuBarriers(scf::ForOp forOp);
+bool isSafeToPipeline(scf::ForOp forOp);
+llvm::MapVector<Operation *, std::pair<int, Operation *>>
+loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
+                          triton::ModuleAxisInfoAnalysis &axisInfoAnalysis,
+                          int numStages, bool filterSmall = true);
+
 }; // namespace gpu

 /// Pipeline the TMA stores in the loop.
@@ -191,6 +199,13 @@ class OpBuilderForStage : public mlir::ImplicitLocOpBuilder,
   CoarseSchedule &schedule;
 };

+namespace gpu {
+void scheduleDistanceOneDependencies(scf::ForOp forOp,
+                                     CoarseSchedule &schedule);
+void scheduleRemainingToLastStage(scf::ForOp forOp, CoarseSchedule &schedule,
+                                  CoarseSchedule::Cluster afterPrologue);
+} // namespace gpu
+
 } // namespace triton
 } // namespace mlir
 #endif // TRITON_TRITONGPU_TRANSFORM_PIPELINE_SCHEDULE_H_
```
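A minimal sketch of how an AMD pipeliner pass could consume these now-public helpers. The function name `buildAmdCoarseSchedule` and the surrounding pass plumbing are illustrative assumptions; only the declarations above are from the commit. AMD keeps sub-32-bit loads, hence `filterSmall=false`.

```c++
#include "triton/Dialect/TritonGPU/Transforms/Schedule.h"

using namespace mlir;
using namespace mlir::triton;

// Sketch: gate pipelining on the shared safety checks, then build a coarse
// schedule with the common helpers instead of AMD-local copies.
static void buildAmdCoarseSchedule(scf::ForOp forOp,
                                   ModuleAxisInfoAnalysis &axisInfoAnalysis,
                                   int numStages, CoarseSchedule &schedule,
                                   CoarseSchedule::Cluster afterPrologue) {
  if (gpu::hasGpuBarriers(forOp) || !gpu::isSafeToPipeline(forOp))
    return;

  // filterSmall=false keeps loads narrower than 32 bits, which the AMD
  // pipeliner can still pipeline through registers.
  auto loadToIndLevel = gpu::loadOpsToIndirectionLevel(
      forOp, /*pipelineWithoutDot=*/false, axisInfoAnalysis, numStages,
      /*filterSmall=*/false);

  for (auto &[loadOp, levelAndUse] : loadToIndLevel) {
    auto [indirectionLevel, use] = levelAndUse;
    // ... assign loadOp to a stage/cluster based on its indirection level
    // and the dot (or other op) that uses it.
    (void)loadOp;
    (void)indirectionLevel;
    (void)use;
  }

  // Finish with the shared scheduling passes.
  gpu::scheduleDistanceOneDependencies(forOp, schedule);
  gpu::scheduleRemainingToLastStage(forOp, schedule, afterPrologue);
}
```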

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 7 additions & 6 deletions
```diff
@@ -206,7 +206,7 @@ bool isPureUnaryInlineAsm(Operation *op);
 int getNVIDIAComputeCapability(Operation *module);

 // Read the amd target from the module attributes
-StringRef getAMDArch(Operation *module);
+std::optional<StringRef> getAMDArch(Operation *module);

 std::optional<mlir::triton::gpu::SwizzledSharedEncodingAttr>
 getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible);
@@ -258,11 +258,12 @@ void replaceUsesAndPropagateType(OpBuilder &builder, Operation *oldUse,

 /// Replace all uses of `old` with a local load from `alloc` unless the use is a
 /// `ttg.local_alloc` with a matching shared encoding, in which case the shared
-/// memory is forwarded directly into the use.
-void replaceUsesWithLocalLoad(
-    OpBuilder &builder, OpResult old,
-    TypedValue<triton::gpu::MemDescType> alloc,
-    TypedValue<triton::gpu::AsyncTokenType> token = {});
+/// memory is forwarded directly into the use. Returns the `ttg.local_load` if
+/// it created one.
+triton::gpu::LocalLoadOp
+replaceUsesWithLocalLoad(OpBuilder &builder, OpResult old,
+                         TypedValue<triton::gpu::MemDescType> alloc,
+                         TypedValue<triton::gpu::AsyncTokenType> token = {});

 // Return true if the value comes from a load or a block argument.
 // This will skip convert layouts and memdesc views.
```
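Callers now have to handle the missing-arch case of `getAMDArch` and may use the op returned by `replaceUsesWithLocalLoad`. A hedged sketch follows; the function name, the `gfx9` check, and the surrounding variables (`moduleOp`, `builder`, `oldResult`, `alloc`, `token`) are illustrative assumptions.

```c++
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"

using namespace mlir;

// Sketch: only run an AMD-specific rewrite when the module carries an AMD
// target attribute; getAMDArch now returns std::nullopt otherwise.
static LogicalResult
runAmdOnlyRewrite(ModuleOp moduleOp, OpBuilder &builder, OpResult oldResult,
                  TypedValue<triton::gpu::MemDescType> alloc,
                  TypedValue<triton::gpu::AsyncTokenType> token) {
  std::optional<StringRef> arch = triton::getAMDArch(moduleOp);
  if (!arch || !arch->starts_with("gfx9"))
    return failure(); // No AMD (CDNA) target: nothing to do.

  // replaceUsesWithLocalLoad now hands back the ttg.local_load it created
  // (if any), so follow-up scheduling/annotation can target it directly.
  triton::gpu::LocalLoadOp localLoad =
      triton::replaceUsesWithLocalLoad(builder, oldResult, alloc, token);
  if (localLoad) {
    // e.g. tag or schedule the freshly created local_load here.
  }
  return success();
}
```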

lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -18,6 +18,7 @@ namespace {
 using namespace mlir;
 using namespace mlir::triton::gpu;

+constexpr int kPtrBitWidth = 64;
 struct ConvertLayoutOpUsingLinearLayoutsConversion
     : public ConvertOpToLLVMPattern<ConvertLayoutOp> {
   const TargetInfoBase &targetInfo;
```

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 1 addition & 7 deletions
```diff
@@ -1,15 +1,9 @@
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
 #include "triton/Conversion/TritonGPUToLLVM/Utility.h"

-namespace mlir {
-FailureOr<LLVM::LLVMFuncOp>
-convertFuncOpToLLVMFuncOp(FunctionOpInterface funcOp,
-                          ConversionPatternRewriter &rewriter,
-                          const LLVMTypeConverter &converter);
-}
-
 namespace {

 using namespace mlir;
```

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 8 additions & 11 deletions
```diff
@@ -15,18 +15,19 @@ using namespace mlir::triton::gpu;
 // blocked -> shared.
 // Swizzling in shared memory to avoid bank conflict. Normally used for
 // A/B operands of dots.
-void lowerDistributedToShared(
-    Location loc, Value src, Value dst, Value adaptorSrc,
-    const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
-    ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
-    std::pair<size_t, Type> *const llvmOpCount = nullptr) {
+void lowerDistributedToShared(Location loc, Value src, Value dst,
+                              Value adaptorSrc,
+                              const SharedMemoryObject &smemObj,
+                              const LLVMTypeConverter *typeConverter,
+                              ConversionPatternRewriter &rewriter,
+                              const TargetInfoBase &targetInfo) {
   auto srcTy = cast<RankedTensorType>(src.getType());
   auto dstTy = cast<MemDescType>(dst.getType());
   auto elemTy = typeConverter->convertType(srcTy.getElementType());

   auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
   storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
-                           targetInfo, llvmOpCount);
+                           targetInfo);
 }

 struct GlobalScratchAllocOpConversion
@@ -173,13 +174,9 @@ struct LocalStoreOpConversion
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);

-    std::pair<size_t, Type> llvmOpCount;
     lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
                              adaptor.getSrc(), smemObj, getTypeConverter(),
-                             rewriter, targetInfo, &llvmOpCount);
-
-    targetInfo.localStoreOpAnnotation(op, llvmOpCount.first,
-                                      llvmOpCount.second);
+                             rewriter, targetInfo);

     rewriter.eraseOp(op);
     return success();
```
