@@ -77,41 +77,41 @@ static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
7777 // / ```
7878 // / Blocked (#triton_gpu.blocked<{sizePerThread = [executionSize, 1, 1, 1, 1, 1, 1], threadsPerWarp = [1, 1, 1, executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], 1, 1, warpsPerCTA[1], 1], order = [3, 4, 5, 6, 0, 1, 2]}>):
7979 // / ```
80- // / warpsPerCTA[5]
81- // / <------------------------------------------------------------------------------->
82- // / getShape()[4]
83- // / <---------------------------------->
84- // / threadsPerWarp[3]
85- // / <---------------->
86- // / ^ ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^
87- // / | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
88- // / | sizePerThread[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
89- // / | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
90- // / | v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
91- // / | ..................................................................................|
92- // / | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[2]
93- // / | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
94- // / size [1] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
95- // / | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
96- // / v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
80+ // / warpsPerCTA[5]
81+ // / <------------------------------------------------------------------------------->
82+ // / getShape()[4]
83+ // / <---------------------------------->
84+ // / threadsPerWarp[3]
85+ // / <---------------->
86+ // / ^ ^ t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn ^
87+ // / | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
88+ // / | sizePerThread[0] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
89+ // / | | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
90+ // / | v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
91+ // / | ..................................................................................|
92+ // / | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn | warpsPerCTA[2]
93+ // / | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
94+ // / getShape() [1] | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
95+ // / | t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
96+ // / v t0 t1 t2 t3 ... tn t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn tn1 tn2 tn3 tn4 ... tnn |
9797 // / ```
9898 // / So we can reduce on dimensions 6 and 4 to get to:
9999 // / ```
100- // / warpsPerCTA[3]
101- // / <------------------------------------------ ------------------------------------->
102- // / threadsPerWarp[3]
103- // / <---------------->
104- // / ^ ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn ^
105- // / | | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
106- // / | sizePerThread[0] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
107- // / | | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
108- // / | v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
109- // / | .......................................|
110- // / | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | warpsPerCTA[2]
111- // / | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
112- // / size [1] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
113- // / | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
114- // / v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
100+ // / warpsPerCTA[3]
101+ // / < ------------------------------------->
102+ // / threadsPerWarp[3]
103+ // / <---------------->
104+ // / ^ ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn ^
105+ // / | | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
106+ // / | sizePerThread[0] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
107+ // / | | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
108+ // / | v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
109+ // / | .......................................|
110+ // / | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn | warpsPerCTA[2]
111+ // / | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
112+ // / getShape() [1] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
113+ // / | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
114+ // / v t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... tnn |
115115 // / ```
116116 // /
117117 // / Now on with step 2: After reshaping and layout conversion, we can get to
@@ -128,6 +128,8 @@ static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
128128 // / | t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4 |
129129 // / ```
130130 // / And on with step 3, after reducing on dimension 3, we'd get:
131+ // / Blocked (#triton_gpu.blocked<{sizePerThread = [1, 1, 1, executionSize], threadsPerWarp = [executionSize, 1, 1, 1], warpsPerCTA = [1, 1, warpsPerCTA[0], warpsPerCTA[1]], order = [3, 0, 1, 2]}>):
132+ // / Sliced (#triton_gpu.sliced<{dim = 3, parent = #blocked}>)
131133 // / ```
132134 // / ^ t0 ^
133135 // / | t1 |
0 commit comments