@@ -218,6 +218,46 @@ struct ReduceOpConversion
218218 rewriter.replaceOp (op, results);
219219 }
220220
221+ // For slice layout some ids are duplicated on multiple lanes, so we need to
222+ // handle the delinearization of laneId in a special way. We need to
223+ // generalize this part of the logic to work on any kind of linear layout
224+ // uniformely.
225+ SmallVector<Value>
226+ getMultiDimLaneId (ReduceOpHelper &helper, Value &laneId, Location &loc,
227+ ConversionPatternRewriter &rewriter) const {
228+ auto srcLayout = helper.getSrcLayout ();
229+ auto srcShape = helper.getSrcShape ();
230+ auto order = triton::gpu::getThreadOrder (srcLayout);
231+ SmallVector<Value> multiDimLaneId;
232+
233+ if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
234+ auto parentLayout = sliceLayout.getParent ();
235+ SmallVector<unsigned > dims = {sliceLayout.getDim ()};
236+ while (auto parentSliceLayout =
237+ mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
238+ dims.push_back (parentSliceLayout.getDim ());
239+ parentLayout = parentSliceLayout.getParent ();
240+ }
241+
242+ auto parentThreadsPerWarps = triton::gpu::getThreadsPerWarp (parentLayout);
243+ auto parentOrder = triton::gpu::getThreadOrder (parentLayout);
244+ multiDimLaneId = delinearize (rewriter, loc, laneId, parentThreadsPerWarps,
245+ parentOrder);
246+ for (unsigned dim : llvm::reverse (dims)) {
247+ multiDimLaneId.erase (multiDimLaneId.begin () + dim);
248+ }
249+ } else {
250+ SmallVector<unsigned > threadsPerWarps =
251+ triton::gpu::getThreadsPerWarp (srcLayout);
252+ threadsPerWarps[helper.getAxis ()] =
253+ triton::gpu::getThreadsPerWarpWithUniqueData (
254+ srcLayout, srcShape)[helper.getAxis ()];
255+ multiDimLaneId =
256+ delinearize (rewriter, loc, laneId, threadsPerWarps, order);
257+ }
258+ return multiDimLaneId;
259+ }
260+
221261 SmallVector<Value>
222262 getMultiDimWarpId (ReduceOpHelper &helper, Value &warpId, Location &loc,
223263 ConversionPatternRewriter &rewriter) const {
@@ -231,11 +271,20 @@ struct ReduceOpConversion
231271 // a way to properly delinearize warpId in the slice case
232272 if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
233273 auto parentLayout = sliceLayout.getParent ();
274+ SmallVector<unsigned > dims = {sliceLayout.getDim ()};
275+ while (auto parentSliceLayout =
276+ mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
277+ dims.push_back (parentSliceLayout.getDim ());
278+ parentLayout = parentSliceLayout.getParent ();
279+ }
280+
234281 auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA (parentLayout);
235282 auto parentOrder = triton::gpu::getWarpOrder (parentLayout);
236283 multiDimWarpId =
237284 delinearize (rewriter, loc, warpId, parentWarpsPerCTA, parentOrder);
238- multiDimWarpId.erase (multiDimWarpId.begin () + sliceLayout.getDim ());
285+ for (unsigned dim : llvm::reverse (dims)) {
286+ multiDimWarpId.erase (multiDimWarpId.begin () + dim);
287+ }
239288 } else {
240289 SmallVector<unsigned > warpsPerCTA =
241290 triton::gpu::getWarpsPerCTA (srcLayout);
@@ -263,11 +312,8 @@ struct ReduceOpConversion
263312 unsigned axis = op.getAxis ();
264313 auto smemShape = helper.getScratchRepShape ();
265314
266- auto threadsPerWarp =
267- triton::gpu::getThreadsPerWarpWithUniqueData (srcLayout, srcShape);
268- auto order = getThreadOrder (srcLayout);
269315 SmallVector<Value> multiDimLaneId =
270- delinearize (rewriter, loc, laneId, threadsPerWarp, order );
316+ getMultiDimLaneId (helper, laneId, loc, rewriter );
271317 Value laneIdAxis = multiDimLaneId[axis];
272318 Value zero = i32_val (0 );
273319 Value laneZero = icmp_eq (laneIdAxis, zero);
0 commit comments