66#q = #xegpu.layout <sg_layout = [8 , 1 ], sg_data = [16 , 64 ], inst_data = [8 , 16 ]>
77#k = #xegpu.layout <sg_layout = [8 , 1 ], sg_data = [16 , 64 ], inst_data = [16 , 16 ]>
88#v = #k
9- #kt = #xegpu.layout <sg_layout = [1 , 8 ], sg_data = [64 , 16 ], inst_data = [16 , 16 ]>
9+ #kt = #xegpu.layout <sg_layout = [1 , 8 ], sg_data = [64 , 16 ], inst_data = [16 , 16 ], order = [ 0 , 1 ] >
1010#k_prefetch = #xegpu.layout <sg_layout = [4 , 2 ], sg_data = [16 , 32 ], inst_data = [16 , 16 ]>
1111#v_prefetch = #k_prefetch
1212#out = #q
13+ #out_t = #xegpu.layout <sg_layout = [1 , 8 ], sg_data = [64 , 16 ], inst_data = [16 , 8 ], order = [0 , 1 ]>
1314#layout_128x1 = #xegpu.layout <sg_layout = [8 , 1 ], sg_data = [16 , 1 ], inst_data = [8 , 1 ]>
14- #layout_128x16 = #xegpu.layout <sg_layout = [8 , 1 ], sg_data = [16 , 16 ], inst_data = [8 , 16 ]>
15+ #layout_128x16 = #xegpu.layout <sg_layout = [8 , 1 ], sg_data = [16 , 16 ], inst_data = [8 , 16 ] >
16+ #layout_128x16_t = #xegpu.layout <sg_layout = [1 , 8 ], sg_data = [16 , 16 ], inst_data = [16 , 8 ], order = [0 , 1 ]>
1517#layout_128 = #xegpu.layout <sg_layout = [8 ], sg_data = [16 ], inst_data = [8 ]>
1618module @flash_attention attributes {gpu.container_module } {
1719 gpu.module @flash_attention_fwd {
@@ -162,7 +164,6 @@ module @flash_attention attributes {gpu.container_module} {
162164 %qk_out_max_t3 = vector.multi_reduction <maximumf >, %qk_out_max_t2 , %minus_inf_128
163165 {layout_result_0 = #xegpu.slice <#layout_128x16 , dims = [1 ]>}
164166 [1 ] : vector <128 x16 xf32 > to vector <128 xf32 >
165- // %qk_out_max = vector.shape_cast %qk_out_max_t3 {layout_result_0 = #layout_128x1} : vector<128xf32> to vector<128x1xf32>
166167
167168 // Scale
168169 %qk_out_max_scaled = arith.mulf %qk_out_max_t3 , %qk_scale_128 {layout_result_0 = #layout_128 } : vector <128 xf32 >
@@ -174,8 +175,8 @@ module @flash_attention attributes {gpu.container_module} {
174175 %qk_out_2_scaled = arith.mulf %qk_out_2 , %qk_scale_128x16 {layout_result_0 = #layout_128x16 } : vector <128 x16 xf32 >
175176 %qk_out_3_scaled = arith.mulf %qk_out_3 , %qk_scale_128x16 {layout_result_0 = #layout_128x16 } : vector <128 x16 xf32 >
176177 // Broadcast m_ij_row to 128x16
177- %m_ij_row_broadcasted0 = vector.shape_cast %m_ij_row {layout_result_0 = #layout_128x1 , layout_operand_0 = #xegpu.slice < #layout_128x1 , dims =[ 1 ]> } : vector <128 xf32 > to vector <128 x 1 x f32 >
178- %m_ij_row_broadcasted = vector.broadcast %m_ij_row_broadcasted0 {layout_result_0 = #layout_128x16 } : vector <128 x 1 x f32 > to vector <128 x16 xf32 >
178+ %m_ij_row_broadcasted0 = vector.broadcast %m_ij_row {layout_result_0 = #layout_128x16_t } : vector <128 xf32 > to vector <16 x 128 x f32 >
179+ %m_ij_row_broadcasted = vector.transpose %m_ij_row_broadcasted0 , [ 1 , 0 ] {layout_result_0 = #layout_128x16 } : vector <16 x 128 x f32 > to vector <128 x16 xf32 >
179180 // Center qk_out by m_ij_row
180181 %qk_out_0_centered = arith.subf %qk_out_0_scaled , %m_ij_row_broadcasted {layout_result_0 = #layout_128x16 } : vector <128 x16 xf32 >
181182 %qk_out_1_centered = arith.subf %qk_out_1_scaled , %m_ij_row_broadcasted {layout_result_0 = #layout_128x16 } : vector <128 x16 xf32 >
@@ -193,16 +194,15 @@ module @flash_attention attributes {gpu.container_module} {
193194 %l_ij_row_t3 = vector.multi_reduction <add >, %l_ij_row_t2 , %zero_128
194195 {layout_result_0 = #xegpu.slice <#layout_128x16 , dims = [1 ]>}
195196 [1 ] : vector <128 x16 xf32 > to vector <128 xf32 >
196- // %l_ij_row = vector.shape_cast %l_ij_row_t3 {layout_result_0 = #layout_128x1} : vector<128xf32> to vector<128x1xf32>
197197 // Compute alpha
198198 %alpha_row_t1 = arith.subf %m_i_row , %m_ij_row {layout_result_0 = #layout_128 } : vector <128 xf32 >
199199 %alpha_row = math.exp %alpha_row_t1 fastmath <fast > {layout_result_0 = #layout_128 } : vector <128 xf32 >
200200 // Update l_i
201201 %l_i_row_new_t1 = arith.mulf %l_i_row , %alpha_row {layout_result_0 = #layout_128 } : vector <128 xf32 >
202202 %l_i_row_new = arith.addf %l_i_row_new_t1 , %l_ij_row_t3 {layout_result_0 = #layout_128 } : vector <128 xf32 >
203203 // Update acc
204- %alpha_row_broadcasted0 = vector.shape_cast %alpha_row {layout_result_0 = #layout_128x1 , layout_operand_0 = #xegpu.slice < #layout_128x1 , dims =[ 1 ]> } : vector <128 xf32 > to vector <128 x 1 x f32 >
205- %alpha_row_broadcasted = vector.broadcast %alpha_row_broadcasted0 {layout_result_0 = #out } : vector <128 x 1 x f32 > to vector <128 x64 xf32 >
204+ %alpha_row_broadcasted0 = vector.broadcast %alpha_row {layout_result_0 = #out_t } : vector <128 xf32 > to vector <64 x 128 x f32 >
205+ %alpha_row_broadcasted = vector.transpose %alpha_row_broadcasted0 , [ 1 , 0 ] {layout_result_0 = #out } : vector <64 x 128 x f32 > to vector <128 x64 xf32 >
206206 %acc_in_updated = arith.mulf %acc_in , %alpha_row_broadcasted {layout_result_0 = #out } : vector <128 x64 xf32 >
207207
208208 // Convert qk_out_tile to DPAS-A precision for P*V computation.
@@ -234,8 +234,8 @@ module @flash_attention attributes {gpu.container_module} {
234234 scf.yield %pv_out_iter3 , %m_ij_row , %l_i_row_new : vector <128 x64 xf32 >, vector <128 xf32 >, vector <128 xf32 >
235235 } {layout_result_0 = #out , layout_result_1 = #layout_128 , layout_result_2 = #layout_128 }// end of inner loop
236236 // Divide acc output by l_i
237- %l_i_row_broadcast0 = vector.shape_cast %result#2 {layout_result_0 = #layout_128x1 , layout_operand_0 = #xegpu.slice < #layout_128x1 , dims =[ 0 ]> } : vector <128 xf32 > to vector <128 x 1 x f32 >
238- %l_i_row_broadcast = vector.broadcast %l_i_row_broadcast0 {layout_result_0 = #out } : vector <128 x 1 x f32 > to vector <128 x64 xf32 >
237+ %l_i_row_broadcast0 = vector.broadcast %result#2 {layout_result_0 = #out_t } : vector <128 xf32 > to vector <64 x 128 x f32 >
238+ %l_i_row_broadcast = vector.transpose %l_i_row_broadcast0 , [ 1 , 0 ] {layout_result_0 = #out } : vector <64 x 128 x f32 > to vector <128 x64 xf32 >
239239 %o_val_final_t = arith.divf %result#0 , %l_i_row_broadcast {layout_result_0 = #out } : vector <128 x64 xf32 >
240240 // Store output tile.
241241 %o_val_final = arith.truncf %o_val_final_t {layout_result_0 = #out } : vector <128 x64 xf32 > to vector <128 x64 xf16 >
0 commit comments