File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed
src/plugins/intel_gpu/src/graph/impls/ocl_v2/sdpa Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change 1010#include " paged_attention_opt.hpp"
1111
1212#include " intel_gpu/graph/kernel_impl_params.hpp"
13+ #include " openvino/core/type/float16.hpp"
1314#include " intel_gpu/primitives/scaled_dot_product_attention.hpp"
1415#include " ocl_v2/utils/jitter.hpp"
1516#include " scaled_dot_product_attention_inst.h"
@@ -1153,10 +1154,15 @@ JitConstants SDPAMicroGenerator::get_jit_constants(const kernel_impl_params& par
11531154 jit.make (" KV_GROUP_SIZE" , Q_num_heads_dim / K_num_heads_dim);
11541155
11551156 if (d_full) {
1156- if (ldq % 4 == 0 )
1157+ const auto sg_size = get_subgroup_size (device_info.arch );
1158+ constexpr size_t packed_elems_per_uint = sizeof (uint32_t ) / sizeof (ov::float16);
1159+ constexpr size_t max_block_elems = 16 ; // max 16 elements per block load/store per item
1160+ const auto q_block_elems = (d_max / packed_elems_per_uint) / sg_size;
1161+ if (ldq % 4 == 0 && q_block_elems <= max_block_elems)
11571162 jit.make (" BLOCK_Q" , 1 );
11581163 // TODO: Causes an accuracy drop for the static SD model. Re-enable once the issue is resolved.
1159- // if (lda % 4 == 0 && v_full)
1164+ // const auto a_block_elems = static_cast<size_t>(gemm_vs.getSetting("sg_tile_m")) / sg_size;
1165+ // if (lda % 4 == 0 && v_full && a_block_elems <= max_block_elems)
11601166 // jit.make("BLOCK_A", 1);
11611167 jit.make (" REMAINDER_Q" , !q_full);
11621168 } else if (device_info.arch >= gpu_arch::xe_hpc) {
You can’t perform that action at this time.
0 commit comments