@@ -63,7 +63,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
63
63
} else if (llvm::is_contained (dims, kWarp )) {
64
64
// Case 2: Transfer between values in the same CTA, in which case we move
65
65
// values through shared memory.
66
- return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
66
+ transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
67
+ return success ();
67
68
} else if (llvm::is_contained (dims, kLane )) {
68
69
// Case 3. Transfer between values in the same warp, in which case we try
69
70
// to move values using warp shuffles, though if the pattern is
@@ -74,7 +75,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
74
75
// TODO: Since data is only transferred within a warp over shared memory,
75
76
// we should use `bar.warp.sync` instead of `barrier`, which will improve
76
77
// latency when warps issue barriers on different cycles.
77
- return transferWithinBlock (op, srcLayout, dstLayout, adaptor, rewriter);
78
+ transferWithinBlockSwizzling (op, adaptor.getSrc (), rewriter);
79
+ return success ();
78
80
} else if (llvm::is_contained (dims, kRegister )) {
79
81
// Case 4. Transfer between values in the same thread, in which case we
80
82
// simply reorder the elements of adaptor.getSrc().
@@ -110,24 +112,152 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
110
112
return success ();
111
113
}
112
114
113
- LogicalResult transferWithinBlock (ConvertLayoutOp op,
114
- const LinearLayout &srcLayout,
115
- const LinearLayout &dstLayout,
116
- OpAdaptor adaptor,
117
- ConversionPatternRewriter &rewriter) const {
118
- assert (cvtNeedsSharedMemory (op.getSrc ().getType (), op.getType ()));
115
+ SmallVector<Value> transferWithinBlockSwizzlingImpl (
116
+ Location loc, ConversionPatternRewriter &rewriter,
117
+ const LinearLayout &srcLayout, const LinearLayout &dstLayout,
118
+ ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
119
+ auto *ctx = rewriter.getContext ();
120
+ auto b = TritonLLVMOpBuilder (loc, rewriter);
121
+ // We handle transformations recursively as they all need a preprocessing
122
+ // and a postprocessing step.
123
+
124
+ // Handle pointer types as 64-bit integers
125
+ if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
126
+ auto llvmElemTyPtr = i64_ty;
127
+ auto newInVals = llvm::to_vector (llvm::map_range (inVals, [&](Value v) {
128
+ return b.ptrtoint (llvmElemTyPtr, v).getResult ();
129
+ }));
130
+ auto outVals =
131
+ transferWithinBlockSwizzlingImpl (loc, rewriter, srcLayout, dstLayout,
132
+ newInVals, llvmElemTyPtr, smemBase);
133
+ for (auto &v : outVals) {
134
+ v = b.inttoptr (llvmElemTy, v);
135
+ }
136
+ return outVals;
137
+ }
119
138
120
- // Try to use swizzling to implement the conversion
121
- if (succeeded (transferWithinBlockSwizzling (op, adaptor.getSrc (), targetInfo,
122
- getTypeConverter (), rewriter))) {
123
- return success ();
139
+ // Handle sub-byte elements like i1
140
+ if (llvmElemTy.getIntOrFloatBitWidth () < 8 ) {
141
+ // Upcast to i8
142
+ auto i8ElemTy = i8_ty;
143
+ auto newInVals = llvm::to_vector (llvm::map_range (
144
+ inVals, [&](Value v) { return b.zext (i8ElemTy, v).getResult (); }));
145
+ auto outVals = transferWithinBlockSwizzlingImpl (
146
+ loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
147
+ for (auto &v : outVals) {
148
+ v = b.trunc (llvmElemTy, v);
149
+ }
150
+ return outVals;
124
151
}
125
152
126
- Value result = transferWithinBlockPadding (op, adaptor.getSrc (), targetInfo,
127
- getTypeConverter (), rewriter);
153
+ // Remove broadcasting in src
154
+ auto removeBroadcastSrc = actionRemoveBroadcastedRegs (srcLayout);
155
+ if (!removeBroadcastSrc.isIdentity ()) {
156
+ auto prmtSrc = removeBroadcastSrc.apply (srcLayout);
157
+ auto newInVals = removeBroadcastSrc.apply (inVals);
158
+ return transferWithinBlockSwizzlingImpl (loc, rewriter, prmtSrc, dstLayout,
159
+ newInVals, llvmElemTy, smemBase);
160
+ }
128
161
162
+ // Remove broadcasting in dst
163
+ auto removeBroadcastDst = actionRemoveBroadcastedRegs (dstLayout);
164
+ if (!removeBroadcastDst.isIdentity ()) {
165
+ auto prmtDst = removeBroadcastDst.apply (dstLayout);
166
+ auto outVals = transferWithinBlockSwizzlingImpl (
167
+ loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
168
+ return broadcastAs (outVals, dstLayout);
169
+ }
170
+
171
+ // At this point we have a type that's at least 8-bit
172
+ // and we don't have broadcasting in the registers
173
+ auto bitwidth = llvmElemTy.getIntOrFloatBitWidth ();
174
+ auto smem = optimalSwizzlingLdSt (srcLayout, dstLayout, bitwidth);
175
+
176
+ // Extract reps from smem
177
+ auto kReg = str_attr (" register" );
178
+ auto kReps = str_attr (" reps" );
179
+ auto nReps = smem.getInDimSize (kReps );
180
+ auto reps = LinearLayout::identity1D (nReps, kReg , kReps );
181
+
182
+ auto totalStoreCvt = srcLayout.invertAndCompose (smem);
183
+ auto totalLoadCvt = dstLayout.invertAndCompose (smem);
184
+
185
+ // The permutation exists by construction of the reps dimension in
186
+ // optimalSwizzling
187
+ auto permStore =
188
+ regPermForDivide (totalStoreCvt, reps, /* left=*/ false ).value ();
189
+ totalStoreCvt = permStore.apply (totalStoreCvt);
190
+ auto permutedInVals = permStore.apply (inVals);
191
+ auto permLoad =
192
+ regPermForDivide (totalLoadCvt, reps, /* left=*/ false ).value ();
193
+ totalLoadCvt = permLoad.apply (totalLoadCvt);
194
+
195
+ // Remove the reps and flatten into offset
196
+ auto storeCvt = *divideRight (totalStoreCvt, reps);
197
+ auto loadCvt = *divideRight (totalLoadCvt, reps);
198
+ auto kOffset = str_attr (" offset" );
199
+ storeCvt = storeCvt.reshapeOuts ({{kOffset , storeCvt.getTotalOutDimSize ()}});
200
+ loadCvt = loadCvt.reshapeOuts ({{kOffset , loadCvt.getTotalOutDimSize ()}});
201
+
202
+ auto tileSize = storeCvt.getInDimSize (kReg );
203
+
204
+ assert (permutedInVals.size () == tileSize * nReps);
205
+ SmallVector<Value> outVals;
206
+ auto affineOffset = b.i32_val (0 );
207
+ auto maskSpanAffineOffset = 0 ;
208
+ auto noPaddingOffset = [](Value v) { return v; };
209
+ for (int i = 0 ; i < nReps; ++i) {
210
+ if (i > 0 )
211
+ b.barrier ();
212
+
213
+ auto tileInVals =
214
+ ArrayRef<Value>(permutedInVals).slice (i * tileSize, tileSize);
215
+ // Store
216
+ lowerLdStShared (loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
217
+ noPaddingOffset, affineOffset, maskSpanAffineOffset,
218
+ rewriter, targetInfo);
219
+ b.barrier ();
220
+ // Load
221
+ SmallVector<Value> tileOutVals = lowerLdStShared (
222
+ loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
223
+ affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
224
+ llvm::append_range (outVals, tileOutVals);
225
+ }
226
+
227
+ // Undo the permLoad used to divideRight
228
+ outVals = permLoad.inverse ().apply (outVals);
229
+ return outVals;
230
+ }
231
+
232
+ void transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
233
+ ConversionPatternRewriter &rewriter) const {
234
+ auto loc = op.getLoc ();
235
+ auto *ctx = op.getContext ();
236
+ auto srcTy = op.getSrc ().getType ();
237
+ auto dstTy = op.getType ();
238
+
239
+ // Remove the kBlock dimension from the layout as it's the identity in the
240
+ // cvt
241
+ auto srcLayout = toLinearLayout (srcTy);
242
+ auto dstLayout = toLinearLayout (dstTy);
243
+ auto kReg = str_attr (" register" );
244
+ auto kLane = str_attr (" lane" );
245
+ auto kWarp = str_attr (" warp" );
246
+ srcLayout = srcLayout.sublayout ({kReg , kLane , kWarp },
247
+ to_vector (srcLayout.getOutDimNames ()));
248
+ dstLayout = dstLayout.sublayout ({kReg , kLane , kWarp },
249
+ to_vector (dstLayout.getOutDimNames ()));
250
+
251
+ auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
252
+ auto smemBase =
253
+ LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
254
+ auto inVals = unpackLLElements (loc, src, rewriter);
255
+ auto outVals = transferWithinBlockSwizzlingImpl (
256
+ loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
257
+
258
+ Value result =
259
+ packLLElements (loc, getTypeConverter (), outVals, rewriter, dstTy);
129
260
rewriter.replaceOp (op, result);
130
- return success ();
131
261
}
132
262
133
263
// Use warp shuffles to implement a layout conversion where data only needs to
0 commit comments