@@ -43,12 +43,26 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 /// Optimize reduction with DPAS-encoded input.
 ///
 /// This optimization reshapes and converts input tensor layouts to split the
-/// reduction in three equivalent ones:
+/// reduction into three equivalent ones.
 ///
 /// This only works if the number of items for a given thread across dimension
 /// 0 and the execution size are equal to the sub-group size.
 ///
-/// First, we go from a DPAS layout to an equivalent blocked layout as follows:
+/// We first want to reshape the input tensor to obtain a tensor with an
+/// equivalent encoding in terms of how elements are distributed across the
+/// device, but with more dimensions across the reduction axis. This way, we
+/// will be able to split the reduction into three steps:
+///
+/// 1. Reduce within the work-item
+/// 2. Convert layout for better locality
+/// 3. Reduce within the sub-group and work-group
+///
+/// Step 1 may involve more than one dimension depending on the input encoding
+/// (2 in this case). After step 1, each thread will hold a single element
+/// across the reduction axis dimension, so step 2 will be cheaper.
+///
+/// For step 1, we first go from a DPAS layout to an equivalent blocked layout
+/// as follows:
 ///
 /// DPAS:
 /// ```
@@ -105,8 +119,9 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
 /// v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
 /// ```
-/// After reshaping and layout conversion, we can get to the actual layout
-/// optimization we wanted to achieve:
+///
+/// Now on to step 2: after reshaping and layout conversion, we can get to
+/// the actual layout optimization we wanted to achieve:
 /// Blocked (#triton_gpu.blocked<{sizePerThread = [1, repCluster[0]*repeatCount], threadsPerWarp = [executionSize, 1], warpsPerCTA = [warpsPerCTA[0], warpsPerCTA[1]], order = [1, 0]}>):
 /// ```
 ///                   warpsPerCTA[1]
@@ -118,8 +133,8 @@ static Value createReshapeForReduction(PatternRewriter &rewriter, Location loc,
 /// threadsPerWarp[0] | t2 t2 t2 t2 ... t2 tn3 tn3 tn3 ... tn3 | warpsPerCTA[0]
 ///                   | t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4 |
 /// ```
-/// And reducing on dimension 1 and converting the layout to the original one
-/// leads to the same output as the original operation.
+/// And on to step 3: reducing on dimension 1 and converting the layout to
+/// the original one leads to the same output as the original operation.
 // clang-format on
 struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   using OpRewritePattern<ReduceOp>::OpRewritePattern;
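For reference, here is a rough sketch of the IR this rewrite aims to produce, for a hypothetical reduction over dimension 1 of a 32x64 `f32` operand with a DPAS encoding. Everything in it is illustrative: the shapes, the `#dpas`/`#blocked*`/`#s*` encoding aliases (definitions omitted), and the choice of axes 3 and 1 for the per-work-item reduces are placeholders rather than what the pass literally emits, and the exact `tt`/`triton_gpu` assembly syntax varies across Triton versions.

```mlir
// Reshape so the elements each work-item owns along the reduction axis
// get dedicated dimensions; the element-to-thread mapping is unchanged,
// so no data movement happens here.
%r = tt.reshape %operand allow_reorder
    : tensor<32x64xf32, #dpas> -> tensor<32x4x8x2xf32, #blocked4>

// Step 1: reduce the per-work-item dimensions (two of them in this
// sketch), entirely within registers.
%a = "tt.reduce"(%r) <{axis = 3 : i32}> ({
^bb0(%x: f32, %y: f32):
  %s = arith.addf %x, %y : f32
  tt.reduce.return %s : f32
}) : (tensor<32x4x8x2xf32, #blocked4>) -> tensor<32x4x8xf32, #s3>
%b = "tt.reduce"(%a) <{axis = 1 : i32}> ({
^bb0(%x: f32, %y: f32):
  %s = arith.addf %x, %y : f32
  tt.reduce.return %s : f32
}) : (tensor<32x4x8xf32, #s3>) -> tensor<32x8xf32, #s1>

// Step 2: each thread now holds a single element along the reduction
// axis, so this conversion moves far less data than converting the
// original operand would.
%c = triton_gpu.convert_layout %b
    : tensor<32x8xf32, #s1> -> tensor<32x8xf32, #blocked2>

// Step 3: the remaining sub-group/work-group reduction, then convert
// back to the encoding the original tt.reduce produced.
%d = "tt.reduce"(%c) <{axis = 1 : i32}> ({
^bb0(%x: f32, %y: f32):
  %s = arith.addf %x, %y : f32
  tt.reduce.return %s : f32
}) : (tensor<32x8xf32, #blocked2>) -> tensor<32xf32, #s1d>
%res = triton_gpu.convert_layout %d
    : tensor<32xf32, #s1d> -> tensor<32xf32, #orig>
```

The ordering is the design point: performing the in-register reduces of step 1 first shrinks the tensor before the `convert_layout`, which is what makes step 2 cheap.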