@@ -110,167 +110,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
110110 return success ();
111111 }
112112
113- SmallVector<Value> transferWithinBlockSwizzlingImpl (
114- Location loc, ConversionPatternRewriter &rewriter,
115- const LinearLayout &srcLayout, const LinearLayout &dstLayout,
116- ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
117- auto *ctx = rewriter.getContext ();
118- auto b = TritonLLVMOpBuilder (loc, rewriter);
119- // We handle transformations recursively as they all need a preprocessing
120- // and a postprocessing step.
121-
122- // Handle pointer types as 64-bit integers
123- if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
124- auto llvmElemTyPtr = i64_ty;
125- auto newInVals = llvm::to_vector (llvm::map_range (inVals, [&](Value v) {
126- return b.ptrtoint (llvmElemTyPtr, v).getResult ();
127- }));
128- auto outVals =
129- transferWithinBlockSwizzlingImpl (loc, rewriter, srcLayout, dstLayout,
130- newInVals, llvmElemTyPtr, smemBase);
131- for (auto &v : outVals) {
132- v = b.inttoptr (llvmElemTy, v);
133- }
134- return outVals;
135- }
136-
137- // Handle sub-byte elements like i1
138- if (llvmElemTy.getIntOrFloatBitWidth () < 8 ) {
139- // Upcast to i8
140- auto i8ElemTy = i8_ty;
141- auto newInVals = llvm::to_vector (llvm::map_range (
142- inVals, [&](Value v) { return b.zext (i8ElemTy, v).getResult (); }));
143- auto outVals = transferWithinBlockSwizzlingImpl (
144- loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
145- for (auto &v : outVals) {
146- v = b.trunc (llvmElemTy, v);
147- }
148- return outVals;
149- }
150-
151- // Remove broadcasting in src
152- auto removeBroadcastSrc = actionRemoveBroadcastedRegs (srcLayout);
153- if (!removeBroadcastSrc.isIdentity ()) {
154- auto prmtSrc = removeBroadcastSrc.apply (srcLayout);
155- auto newInVals = removeBroadcastSrc.apply (inVals);
156- return transferWithinBlockSwizzlingImpl (loc, rewriter, prmtSrc, dstLayout,
157- newInVals, llvmElemTy, smemBase);
158- }
159-
160- // Remove broadcasting in dst
161- auto removeBroadcastDst = actionRemoveBroadcastedRegs (dstLayout);
162- if (!removeBroadcastDst.isIdentity ()) {
163- auto prmtDst = removeBroadcastDst.apply (dstLayout);
164- auto outVals = transferWithinBlockSwizzlingImpl (
165- loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
166- return broadcastAs (outVals, dstLayout);
167- }
168-
169- // At this point we have a type that's at least 8-bit
170- // and we don't have broadcasting in the registers
171- auto bitwidth = llvmElemTy.getIntOrFloatBitWidth ();
172- auto smem = optimalSwizzling (srcLayout, dstLayout, bitwidth);
173-
174- // Extract reps from smem
175- auto kReg = str_attr (" register" );
176- auto kReps = str_attr (" reps" );
177- auto nReps = smem.getInDimSize (kReps );
178- auto reps = LinearLayout::identity1D (nReps, kReg , kReps );
179-
180- auto totalStoreCvt = srcLayout.invertAndCompose (smem);
181- auto totalLoadCvt = dstLayout.invertAndCompose (smem);
182-
183- // The permutation exists by construction of the reps dimension in
184- // optimalSwizzling
185- auto permStore =
186- regPermForDivide (totalStoreCvt, reps, /* left=*/ false ).value ();
187- totalStoreCvt = permStore.apply (totalStoreCvt);
188- auto permutedInVals = permStore.apply (inVals);
189- auto permLoad =
190- regPermForDivide (totalLoadCvt, reps, /* left=*/ false ).value ();
191- totalLoadCvt = permLoad.apply (totalLoadCvt);
192-
193- // Remove the reps and flatten into offset
194- auto storeCvt = *divideRight (totalStoreCvt, reps);
195- auto loadCvt = *divideRight (totalLoadCvt, reps);
196- auto kOffset = str_attr (" offset" );
197- storeCvt = storeCvt.reshapeOuts ({{kOffset , storeCvt.getTotalOutDimSize ()}});
198- loadCvt = loadCvt.reshapeOuts ({{kOffset , loadCvt.getTotalOutDimSize ()}});
199-
200- auto tileSize = storeCvt.getInDimSize (kReg );
201-
202- assert (permutedInVals.size () == tileSize * nReps);
203- SmallVector<Value> outVals;
204- auto noPaddingOffset = [](Value v) { return v; };
205- auto affineOffset = b.i32_val (0 );
206- auto maskSpanAffineOffset = 0 ;
207- for (int i = 0 ; i < nReps; ++i) {
208- if (i > 0 )
209- b.barrier ();
210-
211- auto tileInVals =
212- ArrayRef<Value>(permutedInVals).slice (i * tileSize, tileSize);
213- // Store
214- lowerLdStShared (loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase,
215- noPaddingOffset, affineOffset, maskSpanAffineOffset,
216- rewriter, targetInfo);
217- b.barrier ();
218- // Load
219- SmallVector<Value> tileOutVals = lowerLdStShared (
220- loc, ctx, loadCvt, {}, llvmElemTy, smemBase, noPaddingOffset,
221- affineOffset, maskSpanAffineOffset, rewriter, targetInfo);
222- llvm::append_range (outVals, tileOutVals);
223- }
224-
225- // Undo the permLoad used to divideRight
226- outVals = permLoad.inverse ().apply (outVals);
227- return outVals;
228- }
229-
230- LogicalResult
231- transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
232- ConversionPatternRewriter &rewriter) const {
233- // Fallback for now to standard lowering if it can use stmatrix
234- auto scratchConfig =
235- getScratchConfigForCvt (op.getSrc ().getType (), op.getType ());
236- bool isStMatrix = targetInfo.canUseStMatrix (
237- op.getSrc ().getType (), scratchConfig.repShape ,
238- scratchConfig.paddedRepShape , scratchConfig.order ,
239- /* swizzleByteSize=*/ 0 );
240- if (isStMatrix) {
241- return failure ();
242- }
243-
244- auto loc = op.getLoc ();
245- auto *ctx = op.getContext ();
246- auto srcTy = op.getSrc ().getType ();
247- auto dstTy = op.getType ();
248-
249- // Remove the kBlock dimension from the layout as it's the identity in the
250- // cvt
251- auto srcLayout = toLinearLayout (srcTy);
252- auto dstLayout = toLinearLayout (dstTy);
253- auto kReg = str_attr (" register" );
254- auto kLane = str_attr (" lane" );
255- auto kWarp = str_attr (" warp" );
256- srcLayout = srcLayout.sublayout ({kReg , kLane , kWarp },
257- to_vector (srcLayout.getOutDimNames ()));
258- dstLayout = dstLayout.sublayout ({kReg , kLane , kWarp },
259- to_vector (dstLayout.getOutDimNames ()));
260-
261- auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
262- auto smemBase =
263- LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
264- auto inVals = unpackLLElements (loc, src, rewriter);
265- auto outVals = transferWithinBlockSwizzlingImpl (
266- loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
267-
268- Value result =
269- packLLElements (loc, getTypeConverter (), outVals, rewriter, dstTy);
270- rewriter.replaceOp (op, result);
271- return success ();
272- }
273-
274113 LogicalResult transferWithinBlock (ConvertLayoutOp op,
275114 const LinearLayout &srcLayout,
276115 const LinearLayout &dstLayout,
@@ -280,8 +119,9 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
280119
281120 // Try to use swizzling to implement the conversion
282121 // HACK Remove once XPU tests pass for the swizzling path
283- if (!targetInfo.isXpu () && succeeded (transferWithinBlockSwizzling (
284- op, adaptor.getSrc (), rewriter))) {
122+ if (!targetInfo.isXpu () &&
123+ succeeded (transferWithinBlockSwizzling (op, adaptor.getSrc (), targetInfo,
124+ getTypeConverter (), rewriter))) {
285125 return success ();
286126 }
287127
0 commit comments