@@ -114,190 +114,65 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
114114 return success ();
115115 }
116116
117- std::pair<int , ColumnAction> largestVectorisation (MLIRContext *ctx,
118- const LinearLayout &cvt,
119- int bitwidth) const {
120- // Find the largest vectorisation we can use:
121- StringAttr kReg = str_attr (" register" );
122- StringAttr kOffset = str_attr (" offset" );
123- LinearLayout quot;
124- LinearLayout tile;
125- ColumnAction permutation;
126- for (int v = 128 / bitwidth; v >= 1 ; v /= 2 ) {
127- tile = LinearLayout::identity1D (v, kReg , kOffset );
128- auto maybePerm = regPermForDivide (cvt, tile, /* left=*/ true );
129- if (!maybePerm) {
130- continue ;
131- }
132- permutation = *maybePerm;
133- auto newCvt = permutation.apply (cvt);
134- auto maybeQuot = divideLeft (newCvt, tile);
135- if (!maybeQuot) {
136- continue ;
137- }
138- return {v, permutation};
139- }
140- llvm_unreachable (" No vectorisation found" );
141- }
142-
143- // Close cousin of lowerLdStMatrix in MemoryOpToLLVM.cpp
144- // We might want to merge them at some point, but having to support
145- // ldmatrix.trans makes the code in lowerLdStMatrix a bit specific
146- // Lowers to st when valArrays is empty, and to ld when it is not,
147- // and returns the output values.
148- SmallVector<Value>
149- lowerLdStShared (Location loc, MLIRContext *ctx, LinearLayout cvt,
150- int elemsPerVec,
151- ArrayRef<Value> valsArray, // Input for store, output for load
152- Type llvmElemTy, Value smemBase,
153- ConversionPatternRewriter &rewriter) const {
154- auto vals = to_vector (valsArray);
155- bool isStore = !vals.empty ();
117+ SmallVector<Value> transferWithinBlockSwizzlingImpl (
118+ Location loc, ConversionPatternRewriter &rewriter,
119+ const LinearLayout &srcLayout, const LinearLayout &dstLayout,
120+ ArrayRef<Value> inVals, Type llvmElemTy, Value smemBase) const {
121+ auto *ctx = rewriter.getContext ();
156122 auto b = TritonLLVMOpBuilder (loc, rewriter);
157- auto smemPtrTy = ptr_ty (ctx, 3 );
158- auto kReg = str_attr (" register" );
159- auto kLane = str_attr (" lane" );
160- auto kWarp = str_attr (" warp" );
161- auto kOffset = str_attr (" offset" );
162- auto bitwidth = llvmElemTy.getIntOrFloatBitWidth ();
163-
164- auto [vec, permutation] = largestVectorisation (ctx, cvt, bitwidth);
165- assert (vec >= elemsPerVec);
166- elemsPerVec = vec;
167-
168- cvt = permutation.apply (cvt);
169- if (isStore) {
170- vals = permutation.apply (vals);
171- }
172-
173- auto tile = LinearLayout::identity1D (vec, kReg , kOffset );
174- auto quot = *divideLeft (cvt, tile);
175- LinearLayout reps = zerosLike (tile) * quot;
176-
177- auto [nAdditive, permStrides] = actionAdditiveStrides (reps);
178- reps = permStrides.apply (reps);
179- if (isStore) {
180- vals = permStrides.apply (vals);
181- }
182-
183- // PTX expects the address increments to be done in bytes
184- // If we don't perform the computations in i8, the compiler would
185- // have to divide the computation by bitwdith / 8 and then lift this
186- // shl, which often it's not able to do.
187- auto i8Tile =
188- zerosLike (LinearLayout::identity1D (bitwidth / 8 , kReg , kOffset ));
189- auto i8Reps = i8Tile * reps;
190-
191- auto [laneId, warpId] = getLaneAndWarpId (rewriter, loc);
192- auto regBaseI8 =
193- applyLinearLayout (
194- loc, rewriter, i8Reps,
195- {{kReg , b.i32_val (0 )}, {kLane , laneId}, {kWarp , warpId}})[0 ]
196- .second ;
197- SmallVector<Value> outVals;
198- for (int i = 0 ; i < cvt.getInDimSize (kReg ); i += nAdditive) {
199- auto regIdx = reps.apply ({{kReg , i}, {kLane , 0 }, {kWarp , 0 }})[0 ].second ;
200- auto regIdxI8 = regIdx * (bitwidth / 8 );
201- Value offset = b.xor_ (regBaseI8, b.i32_val (regIdxI8));
202- for (int j = 0 ; j < nAdditive; j += elemsPerVec) {
203- // all these constants will go as immediate values to LDS/STS
204- auto regIdxAdd =
205- reps.apply ({{kReg , j}, {kLane , 0 }, {kWarp , 0 }})[0 ].second ;
206- auto regIdxAddI8 = regIdxAdd * (bitwidth / 8 );
207- Value innerOffset = b.add (offset, b.i32_val (regIdxAddI8));
208- auto vecAddr = b.gep (smemPtrTy, i8_ty, smemBase, innerOffset,
209- LLVM::GEPNoWrapFlags::inbounds);
210- // Lezcano: Do we want to use getFreeVariableMasks for pred or nah?
211- if (isStore) {
212- Value valsVec = packLLVector (
213- loc, ArrayRef<Value>(vals).slice (i + j, elemsPerVec), rewriter);
214- targetInfo.storeDShared (rewriter, loc, vecAddr, std::nullopt , valsVec,
215- /* pred=*/ b.true_val ());
216- } else {
217- Value valsVec =
218- targetInfo.loadDShared (rewriter, loc, vecAddr, std::nullopt ,
219- vec_ty (llvmElemTy, elemsPerVec),
220- /* pred=*/ b.true_val ());
221- llvm::append_range (outVals, unpackLLVector (loc, valsVec, rewriter));
222- }
123+ // We handle transformations recursively as they all need a preprocessing
124+ // and a postprocessing step.
125+
126+ // Handle pointer types as 64-bit integers
127+ if (isa<LLVM::LLVMPointerType>(llvmElemTy)) {
128+ auto llvmElemTyPtr = i64_ty;
129+ auto newInVals = llvm::to_vector (llvm::map_range (inVals, [&](Value v) {
130+ return b.ptrtoint (llvmElemTyPtr, v).getResult ();
131+ }));
132+ auto outVals =
133+ transferWithinBlockSwizzlingImpl (loc, rewriter, srcLayout, dstLayout,
134+ newInVals, llvmElemTyPtr, smemBase);
135+ for (auto &v : outVals) {
136+ v = b.inttoptr (llvmElemTy, v);
223137 }
138+ return outVals;
224139 }
225140
226- // Permute the values back if we are loading
227- if (!isStore) {
228- auto invPermStrides = permStrides.inverse ();
229- outVals = invPermStrides.apply (outVals);
230- auto invPerm = permutation.inverse ();
231- outVals = invPerm.apply (outVals);
232- }
233- return outVals;
234- }
235-
236- LogicalResult
237- transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
238- ConversionPatternRewriter &rewriter) const {
239- // Fallback for now to standard lowering if it can use stmatrix
240- auto scratchConfig =
241- getScratchConfigForCvt (op.getSrc ().getType (), op.getType ());
242- bool isStMatrix = targetInfo.canUseStMatrix (
243- op.getSrc ().getType (), scratchConfig.repShape ,
244- scratchConfig.paddedRepShape , scratchConfig.order ,
245- /* swizzleByteSize=*/ 0 );
246- if (isStMatrix) {
247- return failure ();
248- }
249-
250- auto loc = op.getLoc ();
251- auto *ctx = op.getContext ();
252- auto b = TritonLLVMOpBuilder (loc, rewriter);
253- auto srcTy = op.getSrc ().getType ();
254- auto dstTy = op.getType ();
255- auto bitwidth = isa<PointerType>(srcTy.getElementType ())
256- ? kPtrBitWidth
257- : srcTy.getElementTypeBitWidth ();
258-
259- auto srcLayout = toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
260- auto dstLayout = toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
261- auto origDstLayout = dstLayout;
262-
263- // We remove the Block dimension from the layout as it's the identity in the
264- // cvt
265- auto kRegister = str_attr (" register" );
266- auto kLane = str_attr (" lane" );
267- auto kWarp = str_attr (" warp" );
268- srcLayout = srcLayout.sublayout ({kRegister , kLane , kWarp },
269- to_vector (srcLayout.getOutDimNames ()));
270- dstLayout = dstLayout.sublayout ({kRegister , kLane , kWarp },
271- to_vector (dstLayout.getOutDimNames ()));
272-
273141 // Handle sub-byte elements like i1
274- auto inVals = unpackLLElements (loc, src, rewriter);
275-
276- bool isSubByte = bitwidth < 8 ;
277- auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
278- if (isSubByte) {
142+ if (llvmElemTy.getIntOrFloatBitWidth () < 8 ) {
279143 // Upcast to i8
280- bitwidth = 8 ;
281- llvmElemTy = i8_ty;
282- for (auto &v : inVals) {
283- v = b.zext (llvmElemTy, v);
284- }
285- }
286- bool isPtr = isa<PointerType>(srcTy.getElementType ());
287- if (isPtr) {
288- llvmElemTy =
289- getTypeConverter ()->convertType (IntegerType::get (ctx, kPtrBitWidth ));
290- for (auto &v : inVals) {
291- v = b.ptrtoint (llvmElemTy, v);
144+ auto i8ElemTy = i8_ty;
145+ auto newInVals = llvm::to_vector (llvm::map_range (
146+ inVals, [&](Value v) { return b.zext (i8ElemTy, v).getResult (); }));
147+ auto outVals = transferWithinBlockSwizzlingImpl (
148+ loc, rewriter, srcLayout, dstLayout, newInVals, i8ElemTy, smemBase);
149+ for (auto &v : outVals) {
150+ v = b.trunc (llvmElemTy, v);
292151 }
152+ return outVals;
293153 }
294154
295- // Remove register broadcast from src and dst and input values
155+ // Remove broadcasting in src
296156 auto removeBroadcastSrc = actionRemoveBroadcastedRegs (srcLayout);
297- srcLayout = removeBroadcastSrc.apply (srcLayout);
298- inVals = removeBroadcastSrc.apply (inVals);
299- dstLayout = actionRemoveBroadcastedRegs (dstLayout).apply (dstLayout);
157+ if (!removeBroadcastSrc.isIdentity ()) {
158+ auto prmtSrc = removeBroadcastSrc.apply (srcLayout);
159+ auto newInVals = removeBroadcastSrc.apply (inVals);
160+ return transferWithinBlockSwizzlingImpl (loc, rewriter, prmtSrc, dstLayout,
161+ newInVals, llvmElemTy, smemBase);
162+ }
300163
164+ // Remove broadcasting in dst
165+ auto removeBroadcastDst = actionRemoveBroadcastedRegs (dstLayout);
166+ if (!removeBroadcastDst.isIdentity ()) {
167+ auto prmtDst = removeBroadcastDst.apply (dstLayout);
168+ auto outVals = transferWithinBlockSwizzlingImpl (
169+ loc, rewriter, srcLayout, prmtDst, inVals, llvmElemTy, smemBase);
170+ return broadcastAs (outVals, dstLayout);
171+ }
172+
173+ // At this point we have a type that's at least 8-bit
174+ // and we don't have broadcasting in the registers
175+ auto bitwidth = llvmElemTy.getIntOrFloatBitWidth ();
301176 auto smem = optimalSwizzling (srcLayout, dstLayout, bitwidth);
302177
303178 // Extract reps from smem
@@ -314,7 +189,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
314189 auto permStore =
315190 regPermForDivide (totalStoreCvt, reps, /* left=*/ false ).value ();
316191 totalStoreCvt = permStore.apply (totalStoreCvt);
317- inVals = permStore.apply (inVals);
192+ auto permutedInVals = permStore.apply (inVals);
318193 auto permLoad =
319194 regPermForDivide (totalLoadCvt, reps, /* left=*/ false ).value ();
320195 totalLoadCvt = permLoad.apply (totalLoadCvt);
@@ -326,49 +201,68 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
326201 storeCvt = storeCvt.reshapeOuts ({{kOffset , storeCvt.getTotalOutDimSize ()}});
327202 loadCvt = loadCvt.reshapeOuts ({{kOffset , loadCvt.getTotalOutDimSize ()}});
328203
329- Value smemBase =
330- LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
331-
332204 auto tileSize = storeCvt.getInDimSize (kReg );
333205
334- assert (inVals .size () == tileSize * nReps);
206+ assert (permutedInVals .size () == tileSize * nReps);
335207 SmallVector<Value> outVals;
336- auto elemsPerVec = smem.getInDimSize (str_attr (" vector" ));
337208 for (int i = 0 ; i < nReps; ++i) {
338209 if (i > 0 )
339210 b.barrier ();
340211
341- auto tileInVals = ArrayRef<Value>(inVals).slice (i * tileSize, tileSize);
212+ auto tileInVals =
213+ ArrayRef<Value>(permutedInVals).slice (i * tileSize, tileSize);
342214 // Store
343- lowerLdStShared (loc, ctx, storeCvt, elemsPerVec, tileInVals, llvmElemTy,
344- smemBase, rewriter );
215+ lowerLdStShared (loc, ctx, storeCvt, tileInVals, llvmElemTy, smemBase ,
216+ rewriter, targetInfo );
345217 b.barrier ();
346218 // Load
347219 SmallVector<Value> tileOutVals = lowerLdStShared (
348- loc, ctx, loadCvt, elemsPerVec, {}, llvmElemTy, smemBase, rewriter);
220+ loc, ctx, loadCvt, {}, llvmElemTy, smemBase, rewriter, targetInfo );
349221 llvm::append_range (outVals, tileOutVals);
350222 }
351223
352224 // Undo the permLoad used to divideRight
353225 outVals = permLoad.inverse ().apply (outVals);
226+ return outVals;
227+ }
354228
355- // Unwrap sub-byte elements if necessary
356- if (isSubByte) {
357- auto llvmElemTyOrig =
358- getTypeConverter ()->convertType (srcTy.getElementType ());
359- for (auto &v : outVals) {
360- v = b.trunc (llvmElemTyOrig, v);
361- }
362- } else if (isPtr) {
363- auto llvmElemTyOrig =
364- getTypeConverter ()->convertType (srcTy.getElementType ());
365- for (auto &v : outVals) {
366- v = b.inttoptr (llvmElemTyOrig, v);
367- }
229+ LogicalResult
230+ transferWithinBlockSwizzling (ConvertLayoutOp op, Value src,
231+ ConversionPatternRewriter &rewriter) const {
232+ // Fallback for now to standard lowering if it can use stmatrix
233+ auto scratchConfig =
234+ getScratchConfigForCvt (op.getSrc ().getType (), op.getType ());
235+ bool isStMatrix = targetInfo.canUseStMatrix (
236+ op.getSrc ().getType (), scratchConfig.repShape ,
237+ scratchConfig.paddedRepShape , scratchConfig.order ,
238+ /* swizzleByteSize=*/ 0 );
239+ if (isStMatrix) {
240+ return failure ();
368241 }
369242
370- // Undo the removeBroadcastSrc
371- outVals = broadcastAs (outVals, origDstLayout);
243+ auto loc = op.getLoc ();
244+ auto *ctx = op.getContext ();
245+ auto srcTy = op.getSrc ().getType ();
246+ auto dstTy = op.getType ();
247+
248+ // Remove the kBlock dimension from the layout as it's the identity in the
249+ // cvt
250+ auto srcLayout = toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
251+ auto dstLayout = toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
252+ auto kReg = str_attr (" register" );
253+ auto kLane = str_attr (" lane" );
254+ auto kWarp = str_attr (" warp" );
255+ srcLayout = srcLayout.sublayout ({kReg , kLane , kWarp },
256+ to_vector (srcLayout.getOutDimNames ()));
257+ dstLayout = dstLayout.sublayout ({kReg , kLane , kWarp },
258+ to_vector (dstLayout.getOutDimNames ()));
259+
260+ auto llvmElemTy = getTypeConverter ()->convertType (srcTy.getElementType ());
261+ auto smemBase =
262+ LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
263+ auto inVals = unpackLLElements (loc, src, rewriter);
264+ auto outVals = transferWithinBlockSwizzlingImpl (
265+ loc, rewriter, srcLayout, dstLayout, inVals, llvmElemTy, smemBase);
372266
373267 Value result =
374268 packLLElements (loc, getTypeConverter (), outVals, rewriter, dstTy);
0 commit comments