diff --git a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py index 37d6989408..d5c05b5ebc 100644 --- a/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py +++ b/fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py @@ -3846,7 +3846,7 @@ def _should_skip_config(block_k, matrix_instr_nonkdim): """Skip config if BLOCK_K=64 and matrix_instr_nonkdim=16 on GFX95+""" try: return ( - block_k == 64 + block_k <= 64 and matrix_instr_nonkdim == 16 and torch.version.hip is not None and torch.cuda.get_device_capability() >= (9, 5)