Skip to content

Commit 75774ce

Browse files
committed
Address comments
1 parent fc22c4a commit 75774ce

File tree

1 file changed

+34
-32
lines changed

1 file changed

+34
-32
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -77,41 +77,41 @@ static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
7777
/// ```
7878
/// Blocked (#triton_gpu.blocked<{sizePerThread = [executionSize, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], 1, 1, warpsPerCTA[1], 1], order = [3, 4, 5, 6, 0, 1, 2]}>):
7979
/// ```
80-
/// warpsPerCTA[5]
81-
/// <------------------------------------------------------------------------------->
82-
/// getShape()[4]
83-
/// <---------------------------------->
84-
/// threadsPerWarp[3]
85-
/// <---------------->
86-
/// ^ ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^
87-
/// | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
88-
/// | sizePerThread[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
89-
/// | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
90-
/// | v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
91-
/// | ..................................................................................|
92-
/// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[2]
93-
/// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
94-
/// size[1] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
95-
/// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
96-
/// v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
80+
/// warpsPerCTA[5]
81+
/// <------------------------------------------------------------------------------->
82+
/// getShape()[4]
83+
/// <---------------------------------->
84+
/// threadsPerWarp[3]
85+
/// <---------------->
86+
/// ^ ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^
87+
/// | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
88+
/// | sizePerThread[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
89+
/// | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
90+
/// | v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
91+
/// | ..................................................................................|
92+
/// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[2]
93+
/// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
94+
/// getShape()[1] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
95+
/// | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
96+
/// v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
9797
/// ```
9898
/// So we can reduce on dimensions 6 and 4 to get to:
9999
/// ```
100-
/// warpsPerCTA[3]
101-
/// <------------------------------------------------------------------------------->
102-
/// threadsPerWarp[3]
103-
/// <---------------->
104-
/// ^ ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn ^
105-
/// | | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
106-
/// | sizePerThread[0] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
107-
/// | | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
108-
/// | v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
109-
/// | .......................................|
110-
/// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | warpsPerCTA[2]
111-
/// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
112-
/// size[1] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
113-
/// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
114-
/// v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
100+
/// warpsPerCTA[4]
101+
/// <------------------------------------->
102+
/// threadsPerWarp[3]
103+
/// <---------------->
104+
/// ^ ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn ^
105+
/// | | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
106+
/// | sizePerThread[0] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
107+
/// | | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
108+
/// | v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
109+
/// | .......................................|
110+
/// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | warpsPerCTA[2]
111+
/// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
112+
/// getShape()[1] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
113+
/// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
114+
/// v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
115115
/// ```
116116
///
117117
/// Now on with step 2: After reshaping and layout conversion, we can get to
@@ -128,6 +128,8 @@ static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
128128
/// | t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4 |
129129
/// ```
130130
/// And on with step 3, after reducing on dimension 3, we'd get:
131+
/// Blocked (#triton_gpu.blocked<{sizePerThread = [1, 1, 1, executionSize], threadsPerWarp = [executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], warpsPerCTA[1]], order = [3, 0, 1, 2]}>):
132+
/// Sliced (#triton_gpu.sliced<{dim = 3, parent = #blocked}>):
131133
/// ```
132134
/// ^ t0 ^
133135
/// | t1 |

0 commit comments

Comments
 (0)