@@ -539,6 +539,64 @@ if(VLLM_GPU_LANG STREQUAL "HIP")
539539 ARCHITECTURES ${VLLM_GPU_ARCHES}
540540 USE_SABI 3
541541 WITH_SOABI)
542+
543+ #
544+ # _fp8gemm_C extension
545+ #
546+ set (VLLM_FP8GEMM_EXT_SRC
547+ "csrc/fbgemm_fp8_rowwise/torch_bindings.cpp"
548+ "csrc/fbgemm_fp8_rowwise/fp8_rowwise_gemm.cu"
549+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu"
550+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu"
551+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu"
552+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu"
553+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu"
554+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu"
555+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.cu"
556+ # "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2_16.split_k.cu"
557+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu"
558+ # "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2_16.split_k.cu"
559+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.cu"
560+ # "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_16.split_k.cu"
561+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu"
562+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.cu"
563+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
564+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.cu"
565+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu"
566+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3.cu"
567+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
568+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
569+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x192x128x128_16x16_6x4_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3.cu"
570+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x192x256x128_16x16_6x8_8x32x1_8x32x1_1x32x1x8_8x8x1_2x2_intrawave_v3.cu"
571+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu"
572+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
573+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.cu"
574+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu"
575+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu"
576+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu"
577+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu"
578+ # "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v1_16.split_k.cu"
579+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu"
580+ # "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.cu"
581+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu"
582+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu"
583+ "csrc/fbgemm_fp8_rowwise/kernels/fp8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu"
584+ )
585+
586+ find_package (composable_kernel REQUIRED)
587+
588+ set (fp8_supported_archs "gfx940" "gfx941" "gfx942" )
589+ define_gpu_extension_target(
590+ _fp8gemm_C
591+ DESTINATION vllm
592+ LANGUAGE ${VLLM_GPU_LANG}
593+ SOURCES ${VLLM_FP8GEMM_EXT_SRC}
594+ COMPILE_FLAGS ${VLLM_GPU_FLAGS}
595+ ARCHITECTURES ${fp8_supported_archs}
596+ USE_SABI 3
597+ WITH_SOABI
598+ # LIBRARIES "composable_kernel"
599+ )
542600endif ()
543601
544602# vllm-flash-attn currently only supported on CUDA
0 commit comments