Skip to content

Commit d21dbc9

Browse files
committed
add ck kernel invocations
1 parent 7953b38 commit d21dbc9

File tree

43 files changed

+3405
-2
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

43 files changed

+3405
-2
lines changed

CMakeLists.txt

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,64 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
539539
ARCHITECTURES ${VLLM_GPU_ARCHES}
540540
USE_SABI 3
541541
WITH_SOABI)
542+
543+
#
544+
# _fp8gemm_C extension
545+
#
546+
# Source list for the _fp8gemm_C extension; it is passed as SOURCES to the
# define_gpu_extension_target() call for _fp8gemm_C further down.
# All paths are relative and live under csrc/fbgemm_fp8_rowwise/.
# NOTE(review): five entries are commented out and therefore not compiled
# (the four *_16.split_k.cu variants and the 64x16x16x256 intrawave_v1 kernel).
# The reason is not recorded here -- document why they are excluded, or
# delete the lines.
set(VLLM_FP8GEMM_EXT_SRC
547+
"csrc/fbgemm_fp8_rowwise/torch_bindings.cpp"
548+
"csrc/fbgemm_fp8_rowwise/fp8_rowwise_gemm.cu"
549+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu"
550+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu"
551+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu"
552+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu"
553+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu"
554+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu"
555+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.cu"
556+
# "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2_16.split_k.cu"
557+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu"
558+
# "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2_16.split_k.cu"
559+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.cu"
560+
# "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_16.split_k.cu"
561+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu"
562+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.cu"
563+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
564+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.cu"
565+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu"
566+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.cu"
567+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
568+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
569+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x192x128x128_16x16_6x4_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3.cu"
570+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3.cu"
571+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu"
572+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
573+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.cu"
574+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu"
575+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu"
576+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu"
577+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
578+
# "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1_16.split_k.cu"
579+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu"
580+
# "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.cu"
581+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu"
582+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu"
583+
"csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu"
584+
)
585+
586+
# Locate the composable_kernel package. REQUIRED makes configuration stop with
# an error if the package cannot be found.
# NOTE(review): nothing found by this call is referenced in the visible code
# (the LIBRARIES "composable_kernel" argument below is commented out).
# Presumably the package is only needed for its headers -- confirm how the
# include paths reach the _fp8gemm_C target, or link the package's imported
# target explicitly.
find_package(composable_kernel REQUIRED)
587+
588+
# Fixed list of GPU architectures the _fp8gemm_C extension is built for.
# NOTE(review): this list is passed as ARCHITECTURES in place of the
# ${VLLM_GPU_ARCHES} value used by the preceding extension target, so the
# user's requested architectures are not consulted here -- confirm that is
# intended (e.g. when none of gfx940/gfx941/gfx942 was requested).
set(fp8_supported_archs "gfx940" "gfx941" "gfx942")
589+
# Declare the _fp8gemm_C extension target from VLLM_FP8GEMM_EXT_SRC, with
# DESTINATION vllm and the project's GPU language and compile flags.
define_gpu_extension_target(
590+
_fp8gemm_C
591+
DESTINATION vllm
592+
LANGUAGE ${VLLM_GPU_LANG}
593+
SOURCES ${VLLM_FP8GEMM_EXT_SRC}
594+
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
595+
ARCHITECTURES ${fp8_supported_archs}
596+
USE_SABI 3
597+
WITH_SOABI
598+
# LIBRARIES "composable_kernel"
599+
)
542600
endif()
543601

544602
# vllm-flash-attn currently only supported on CUDA

0 commit comments

Comments
 (0)