@@ -443,6 +443,27 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh
443443 ENABLE_DROPOUT : tl .constexpr , RETURN_ENCODED_SOFTMAX : tl .constexpr , USE_ALIBI : tl .constexpr ,
444444 INT8 : tl .constexpr , USE_P_SCALE : tl .constexpr , INT8_KV : tl .constexpr ):
445445
446+ tl .assume (stride_qz >= 0 )
447+ tl .assume (stride_qh >= 0 )
448+ tl .assume (stride_qm >= 0 )
449+ tl .assume (stride_qk >= 0 )
450+ tl .assume (stride_kz >= 0 )
451+ tl .assume (stride_kh >= 0 )
452+ tl .assume (stride_kn >= 0 )
453+ tl .assume (stride_kk >= 0 )
454+ tl .assume (stride_bz >= 0 )
455+ tl .assume (stride_bh >= 0 )
456+ tl .assume (stride_bm >= 0 )
457+ tl .assume (stride_bn >= 0 )
458+ tl .assume (stride_vz >= 0 )
459+ tl .assume (stride_vh >= 0 )
460+ tl .assume (stride_vk >= 0 )
461+ tl .assume (stride_vn >= 0 )
462+ tl .assume (stride_oz >= 0 )
463+ tl .assume (stride_oh >= 0 )
464+ tl .assume (stride_om >= 0 )
465+ tl .assume (stride_on >= 0 )
466+
446467 if PERSISTENT : # if persistent, kernel loops over multiple tiles
447468 NUM_WG = NUM_CU * GRID_CU_MULTIP # number of workgroups launched
448469 num_tiles_per_head = tl .cdiv (MAX_SEQLENS_Q , BLOCK_M ) # the number of work units (tiles) of a single head
0 commit comments