@@ -220,6 +220,46 @@ struct ReduceOpConversion
220220 rewriter.replaceOp (op, results);
221221 }
222222
223+ // For slice layout some ids are duplicated on multiple lanes, so we need to
224+ // handle the delinearization of laneId in a special way. We need to
225+ // generalize this part of the logic to work on any kind of linear layout
226+ // uniformely.
227+ SmallVector<Value>
228+ getMultiDimLaneId (ReduceOpHelper &helper, Value &laneId, Location &loc,
229+ ConversionPatternRewriter &rewriter) const {
230+ auto srcLayout = helper.getSrcLayout ();
231+ auto srcShape = helper.getSrcShape ();
232+ auto order = triton::gpu::getThreadOrder (srcLayout);
233+ SmallVector<Value> multiDimLaneId;
234+
235+ if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
236+ auto parentLayout = sliceLayout.getParent ();
237+ SmallVector<unsigned > dims = {sliceLayout.getDim ()};
238+ while (auto parentSliceLayout =
239+ mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
240+ dims.push_back (parentSliceLayout.getDim ());
241+ parentLayout = parentSliceLayout.getParent ();
242+ }
243+
244+ auto parentThreadsPerWarps = triton::gpu::getThreadsPerWarp (parentLayout);
245+ auto parentOrder = triton::gpu::getThreadOrder (parentLayout);
246+ multiDimLaneId = delinearize (rewriter, loc, laneId, parentThreadsPerWarps,
247+ parentOrder);
248+ for (unsigned dim : llvm::reverse (dims)) {
249+ multiDimLaneId.erase (multiDimLaneId.begin () + dim);
250+ }
251+ } else {
252+ SmallVector<unsigned > threadsPerWarps =
253+ triton::gpu::getThreadsPerWarp (srcLayout);
254+ threadsPerWarps[helper.getAxis ()] =
255+ triton::gpu::getThreadsPerWarpWithUniqueData (
256+ srcLayout, srcShape)[helper.getAxis ()];
257+ multiDimLaneId =
258+ delinearize (rewriter, loc, laneId, threadsPerWarps, order);
259+ }
260+ return multiDimLaneId;
261+ }
262+
223263 SmallVector<Value>
224264 getMultiDimWarpId (ReduceOpHelper &helper, Value &warpId, Location &loc,
225265 ConversionPatternRewriter &rewriter) const {
@@ -233,11 +273,20 @@ struct ReduceOpConversion
233273 // a way to properly delinearize warpId in the slice case
234274 if (auto sliceLayout = mlir::dyn_cast<SliceEncodingAttr>(srcLayout)) {
235275 auto parentLayout = sliceLayout.getParent ();
276+ SmallVector<unsigned > dims = {sliceLayout.getDim ()};
277+ while (auto parentSliceLayout =
278+ mlir::dyn_cast<SliceEncodingAttr>(parentLayout)) {
279+ dims.push_back (parentSliceLayout.getDim ());
280+ parentLayout = parentSliceLayout.getParent ();
281+ }
282+
236283 auto parentWarpsPerCTA = triton::gpu::getWarpsPerCTA (parentLayout);
237284 auto parentOrder = triton::gpu::getWarpOrder (parentLayout);
238285 multiDimWarpId =
239286 delinearize (rewriter, loc, warpId, parentWarpsPerCTA, parentOrder);
240- multiDimWarpId.erase (multiDimWarpId.begin () + sliceLayout.getDim ());
287+ for (unsigned dim : llvm::reverse (dims)) {
288+ multiDimWarpId.erase (multiDimWarpId.begin () + dim);
289+ }
241290 } else {
242291 SmallVector<unsigned > warpsPerCTA =
243292 triton::gpu::getWarpsPerCTA (srcLayout);
@@ -265,11 +314,8 @@ struct ReduceOpConversion
265314 unsigned axis = op.getAxis ();
266315 auto smemShape = helper.getScratchRepShape ();
267316
268- auto threadsPerWarp =
269- triton::gpu::getThreadsPerWarpWithUniqueData (srcLayout, srcShape);
270- auto order = getThreadOrder (srcLayout);
271317 SmallVector<Value> multiDimLaneId =
272- delinearize (rewriter, loc, laneId, threadsPerWarp, order );
318+ getMultiDimLaneId (helper, laneId, loc, rewriter );
273319 Value laneIdAxis = multiDimLaneId[axis];
274320 Value zero = i32_val (0 );
275321 Value laneZero = icmp_eq (laneIdAxis, zero);
0 commit comments