@@ -147,8 +147,11 @@ module @flash_attention attributes {gpu.container_module} {
147
147
%zero_dpas = vector.shape_cast %zero : vector <128 xf32 > to vector <8 x16 xf32 >
148
148
149
149
// softmax scaling
150
- %qk_scale_8 = spirv.CompositeConstruct %sm_scale , %sm_scale , %sm_scale , %sm_scale , %sm_scale , %sm_scale , %sm_scale , %sm_scale : (f32 , f32 , f32 , f32 , f32 , f32 , f32 , f32 ) -> vector <8 xf32 >
151
- %qk_scale_16 = spirv.CompositeConstruct %sm_scale , %sm_scale , %sm_scale , %sm_scale , %sm_scale , %sm_scale , %sm_scale , %sm_scale ,%sm_scale , %sm_scale , %sm_scale , %sm_scale ,%sm_scale , %sm_scale , %sm_scale , %sm_scale : (f32 , f32 , f32 , f32 ,f32 , f32 , f32 , f32 ,f32 , f32 , f32 , f32 ,f32 , f32 , f32 , f32 ) -> vector <16 xf32 >
150
+ // %qk_scale_8 = spirv.CompositeConstruct %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale : (f32, f32, f32, f32, f32, f32, f32, f32) -> vector<8xf32>
151
+ // %qk_scale_16 = spirv.CompositeConstruct %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale, %sm_scale,%sm_scale, %sm_scale, %sm_scale, %sm_scale,%sm_scale, %sm_scale, %sm_scale, %sm_scale : (f32, f32, f32, f32,f32, f32, f32, f32,f32, f32, f32, f32,f32, f32, f32, f32 ) -> vector<16xf32>
152
+ // FIXME: the value 0.5 is hard-coded; the scale vectors should be built from %sm_scale instead.
153
+ %qk_scale_8 = arith.constant dense <0.5 > : vector <8 xf32 >
154
+ %qk_scale_16 = arith.constant dense <0.5 > : vector <16 xf32 >
152
155
%qk_scale_8x1 = vector.shape_cast %qk_scale_8 : vector <8 xf32 > to vector <8 x1 xf32 >
153
156
%qk_scale_1x16 = vector.shape_cast %qk_scale_16 : vector <16 xf32 > to vector <1 x16 xf32 >
154
157
%qk_scale_8x16 = vector.shuffle %qk_scale_1x16 , %qk_scale_1x16 [0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 ] : vector <1 x16 xf32 >, vector <1 x16 xf32 >
0 commit comments