diff --git a/fbgemm_gpu/FbgemmGpu.cmake b/fbgemm_gpu/FbgemmGpu.cmake
index a033688232..0c1eb21385 100644
--- a/fbgemm_gpu/FbgemmGpu.cmake
+++ b/fbgemm_gpu/FbgemmGpu.cmake
@@ -184,5 +184,6 @@ gpu_cpp_library(
     fbgemm_gpu_tbe_cache
     fbgemm_gpu_tbe_optimizers
     fbgemm_gpu_tbe_utils
+    fbgemm_gpu_config
   DESTINATION
     fbgemm_gpu)
diff --git a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py
index 8f8fd6f495..4264919625 100644
--- a/fbgemm_gpu/fbgemm_gpu/config/feature_list.py
+++ b/fbgemm_gpu/fbgemm_gpu/config/feature_list.py
@@ -60,6 +60,9 @@ def foo():
     # Enable bounds_check_indices_v2
     BOUNDS_CHECK_INDICES_V2 = auto()
 
+    # Disable FP8 quant vectorization
+    DISABLE_FP8_QUANT_VECTORIZATION = auto()
+
     # Enable TBE input parameters extraction
     TBE_REPORT_INPUT_PARAMS = auto()
 
diff --git a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h
index 617df07c05..9c472ba620 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h
@@ -55,13 +55,14 @@ namespace fbgemm_gpu::config {
 /// UI.
 ///
 /// For OSS: The environment variable will be evaluated as f"FBGEMM_{ENUM}"
-#define ENUMERATE_ALL_FEATURE_FLAGS \
-  X(TBE_V2)                         \
-  X(TBE_ENSEMBLE_ROWWISE_ADAGRAD)   \
-  X(TBE_ANNOTATE_KINETO_TRACE)      \
-  X(TBE_ROCM_INFERENCE_PACKED_BAGS) \
-  X(TBE_ROCM_HIP_BACKWARD_KERNEL)   \
-  X(BOUNDS_CHECK_INDICES_V2)        \
+#define ENUMERATE_ALL_FEATURE_FLAGS  \
+  X(TBE_V2)                          \
+  X(TBE_ENSEMBLE_ROWWISE_ADAGRAD)    \
+  X(TBE_ANNOTATE_KINETO_TRACE)       \
+  X(TBE_ROCM_INFERENCE_PACKED_BAGS)  \
+  X(TBE_ROCM_HIP_BACKWARD_KERNEL)    \
+  X(BOUNDS_CHECK_INDICES_V2)         \
+  X(DISABLE_FP8_QUANT_VECTORIZATION) \
   X(TBE_REPORT_INPUT_PARAMS)
 
 // X(EXAMPLE_FEATURE_FLAG)
diff --git a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
index 24e4362bf7..36acc0c907 100644
--- a/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
+++ b/fbgemm_gpu/src/quantize_ops/quantize_fp8_rowwise.cu
@@ -7,6 +7,7 @@
  */
 
 #include "common.cuh"
+#include "fbgemm_gpu/config/feature_gates.h"
 
 using Tensor = at::Tensor;
 
@@ -157,6 +158,125 @@ __global__ inline void _compute_FP8_quantize_cuda_kernel(
   }
 }
 
+template <typename T>
+struct VectorSizeTraits {
+  // Default to 4 elements for most types (16 bytes for float)
+  static constexpr int value = 4;
+};
+
+// Specialization for half (float16)
+template <>
+struct VectorSizeTraits<c10::Half> {
+  // 8 elements for half precision (16 bytes total)
+  static constexpr int value = 8;
+};
+
+// Specialization for __nv_bfloat16
+template <>
+struct VectorSizeTraits<c10::BFloat16> {
+  // 8 elements for bfloat16 precision (16 bytes total)
+  static constexpr int value = 8;
+};
+
+// aligned vector generates vectorized load/store on CUDA (copy-pasted from
+// MemoryAccess.cuh)
+template <typename scalar_t, int vec_size = VectorSizeTraits<scalar_t>::value>
+struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
+  scalar_t val[vec_size];
+};
+
+template <typename input_t>
+#ifndef USE_ROCM
+__global__ __attribute__((maxrregcount(32))) inline void
+#else
+__global__ inline void
+#endif
+_compute_FP8_quantize_cuda_vectorized_kernel(
+    const pta::PackedTensorAccessor64<input_t, 1, at::RestrictPtrTraits> input,
+    const int64_t nrows,
+    const int64_t ncols,
+    pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> output,
+    const bool forward) {
+  // Calculate global row index with 2D thread blocks
+  const int64_t gx = blockIdx.x * blockDim.x + threadIdx.x;
+  const int64_t thread_idx = blockIdx.y * blockDim.y + threadIdx.y;
+  static constexpr int vec_size = VectorSizeTraits<input_t>::value;
+  // Early return if the row is out of bounds
+  if (gx >= nrows ||
+      (thread_idx * vec_size) >= ncols) {
+    return;
+  }
+
+  int ebit = forward ? 4 : 5;
+  int bias = forward ? 15 : 31;
+  float max_pos = forward ? 0.9375 : 0.875;
+
+  // Calculate output width
+  const auto ncols_aligned = (ncols + 4 - 1) / 4 * 4;
+  const auto output_columns = ncols_aligned + 2 * sizeof(float);
+
+  // Calculate base offsets for the current row
+  const int64_t input_row_offset = gx * ncols;
+  const int64_t output_row_offset = gx * output_columns;
+
+  // Calculate the position where the scale values are stored
+  const int64_t scale_offset = output_row_offset + ncols_aligned;
+  const float scale_value = reinterpret_cast<float*>(&output[scale_offset])[0];
+
+  const int64_t vector_blocks = ncols / vec_size;
+
+  using vec_t = aligned_vector<input_t, vec_size>;
+  using vec_i = aligned_vector<uint8_t, vec_size>;
+
+  const int64_t col_idx = thread_idx * vec_size;
+
+  // The if/else here guarantees that the kernel works for both the aligned
+  // and the misaligned case. When ncols is not a multiple of vec_size, we
+  // cannot dereference the vector pointer, so we access the elements one by
+  // one; this triggers multiple trips to global memory but is still faster
+  // than the original kernel.
+  if ((col_idx + (vec_size - 1) < ncols) && ((ncols % vec_size) == 0)) {
+    // Load vec_size elements with a single vectorized access
+    const vec_t input_row =
+        *reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]);
+
+    vec_i* output_row =
+        reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]);
+
+    // Create a temporary vector to enable a vectorized store
+    vec_i temp_output;
+#pragma unroll
+    for (int i = 0; i < vec_size; ++i) {
+      temp_output.val[i] = float_to_hfp8(
+          to_float(input_row.val[i]) * scale_value, ebit, bias, max_pos);
+    }
+    *output_row = temp_output;
+  } else if ((col_idx + (vec_size - 1) < ncols)) {
+    // ncols is not a multiple of vec_size: go through the vector pointers
+    // element by element instead of dereferencing them
+    const vec_t* input_row =
+        reinterpret_cast<const vec_t*>(&input[input_row_offset + col_idx]);
+
+    vec_i* output_row =
+        reinterpret_cast<vec_i*>(&output[output_row_offset + col_idx]);
+#pragma unroll
+    for (int i = 0; i < vec_size; ++i) {
+      output_row->val[i] = float_to_hfp8(
+          to_float(input_row->val[i]) * scale_value, ebit, bias, max_pos);
+    }
+  }
+
+  // Process any remaining elements (fewer than vec_size) with scalar
+  // operations
+  const int64_t remaining_start = vector_blocks * vec_size;
+  for (int64_t col = remaining_start + threadIdx.y; col < ncols;
+       col += blockDim.y) {
+    output[output_row_offset + col] = float_to_hfp8(
+        to_float(input[input_row_offset + col]) * scale_value,
+        ebit,
+        bias,
+        max_pos);
+  }
+}
+
 template <typename output_t>
 __global__ inline void _FP8rowwise_to_float_cuda_kernel(
     pta::PackedTensorAccessor64<uint8_t, 1, at::RestrictPtrTraits> input,
@@ -247,13 +367,6 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
               forward);
             });
   } else {
-    // range_tensor is used to store the range for each embedding row.
-    // We save max_pos/max_val(rowwise) as row scale to quantize
-    // unlike INT8, FP8 does not have zero shift
-    // This will guarantee the numerical match but bring some perf
-    // regression.
-    auto range_tensor = at::empty({nrows}, input.options().dtype(at::kFloat));
-
     {
       // we need a blockDim.x that is a power of 2 no larger than the warp size
       // of 32
@@ -289,27 +402,63 @@ Tensor _float_to_FP8rowwise_gpu_t(const Tensor& input, const bool forward) {
   }
 
   {
-    const int blockDim_x =
-        std::min(ncols, static_cast<int64_t>(threads_per_block));
-    dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
-    const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
-    const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
-    dim3 gridDim(gridDim_x, gridDim_y);
-
-    FBGEMM_DISPATCH_FLOATING_TYPES(
-        input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
-          FBGEMM_LAUNCH_KERNEL(
-              (_compute_FP8_quantize_cuda_kernel<scalar_t>),
-              gridDim,
-              blockDim,
-              0,
-              at::cuda::getCurrentCUDAStream(),
-              PTA_B(input_1D, scalar_t, 1, 64),
-              nrows,
-              ncols,
-              PTA_B(output_1D, uint8_t, 1, 64),
-              forward);
-        });
+    const uintptr_t addr = reinterpret_cast<uintptr_t>(&input);
+
+    const static bool use_vectorization =
+        ((addr % 16) == 0) &&
+        !config::is_feature_enabled(
+            config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION);
+
+    constexpr int vec_size = VectorSizeTraits<float>::value;
+    if (use_vectorization) {
+      const int block_y = 64;
+      const int blockDim_y = ncols > vec_size ? block_y : 1;
+
+      dim3 blockDim(threads_per_block / blockDim_y, blockDim_y);
+      const auto gridDim_x = cuda_calc_xblock_count(nrows, blockDim.x);
+      const auto gridDim_y = cuda_calc_block_count(
+          (ncols + vec_size - 1) / vec_size, blockDim.y);
+      dim3 gridDim(gridDim_x, gridDim_y);
+
+      FBGEMM_DISPATCH_FLOATING_TYPES(
+          input.scalar_type(),
+          "_compute_FP8_quantize_cuda_vectorized_kernel",
+          [&] {
+            FBGEMM_LAUNCH_KERNEL(
+                (_compute_FP8_quantize_cuda_vectorized_kernel<scalar_t>),
+                gridDim,
+                blockDim,
+                0,
+                at::cuda::getCurrentCUDAStream(),
+                PTA_B(input_1D, scalar_t, 1, 64),
+                nrows,
+                ncols,
+                PTA_B(output_1D, uint8_t, 1, 64),
+                forward);
+          });
+    } else {
+      const int blockDim_x =
+          std::min(ncols, static_cast<int64_t>(threads_per_block));
+      dim3 blockDim(blockDim_x, threads_per_block / blockDim_x);
+      const auto gridDim_x = cuda_calc_xblock_count(ncols, blockDim.x);
+      const auto gridDim_y = cuda_calc_block_count(nrows, blockDim.y);
+      dim3 gridDim(gridDim_x, gridDim_y);
+
+      FBGEMM_DISPATCH_FLOATING_TYPES(
+          input.scalar_type(), "_compute_FP8_quantize_cuda_kernel", [&] {
+            FBGEMM_LAUNCH_KERNEL(
+                (_compute_FP8_quantize_cuda_kernel<scalar_t>),
+                gridDim,
+                blockDim,
+                0,
+                at::cuda::getCurrentCUDAStream(),
+                PTA_B(input_1D, scalar_t, 1, 64),
+                nrows,
+                ncols,
+                PTA_B(output_1D, uint8_t, 1, 64),
+                forward);
+          });
+    }
   }
 }
 
@@ -358,8 +507,8 @@ Tensor _FP8rowwise_to_float_gpu_t(
   // to 1, 2, 4, 8, or 16 bytes. Any access (via a variable or a pointer) to
   // data residing in global memory compiles to a single global memory
   // instruction if and only if the size of the data type is 1, 2, 4, 8, or 16
-  // bytes and the data is naturally aligned (i.e., its address is a multiple of
-  // that size).
+  // bytes and the data is naturally aligned (i.e., its address is a multiple
+  // of that size).
   auto output_dims = input_sizes.vec();
   output_dims[last_dim] = output_columns;
   const auto output_sdtype = static_cast<SparseType>(output_dtype);
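
Note on the kill switch: per the feature_gates.h docstring above, each flag is evaluated in OSS from the environment as f"FBGEMM_{ENUM}", so exporting FBGEMM_DISABLE_FP8_QUANT_VECTORIZATION=1 before process start should route quantization to the scalar kernel. A minimal sketch of the C++-side check, mirroring the guard added in quantize_fp8_rowwise.cu (the helper name is illustrative, not part of this diff):

#include "fbgemm_gpu/config/feature_gates.h"

// Hypothetical helper (not in this diff): true when the vectorized FP8
// quantize path may be taken, i.e. the kill switch is not enabled.
inline bool fp8_quant_vectorization_allowed() {
  return !fbgemm_gpu::config::is_feature_enabled(
      fbgemm_gpu::config::FeatureGateName::DISABLE_FP8_QUANT_VECTORIZATION);
}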
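
For context on why the host code checks (addr % 16) == 0 before selecting the vectorized kernel: aligned_vector only yields a single wide memory transaction when the underlying address is aligned to the vector width (16 bytes here), per the CUDA alignment rule quoted in the last hunk. A self-contained sketch of the same pattern, using a hypothetical copy kernel instead of the quantize kernel and assuming float with vec_size = 4:

#include <cstdint>

// Same trick as aligned_vector above: alignas(16) lets the compiler emit
// one 16-byte load/store (e.g. ld.global.v4.f32) per thread.
template <typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
  scalar_t val[vec_size];
};

// Hypothetical demo kernel (not part of this diff): each thread copies one
// vec_size packet; the scalar branch handles the tail, analogous to the
// remainder loop in the quantize kernel.
__global__ void copy_vectorized(const float* in, float* out, int64_t n) {
  constexpr int vec_size = 4;  // 4 x float = 16 bytes
  using vec_t = aligned_vector<float, vec_size>;
  const int64_t base =
      (static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x) * vec_size;
  if (base + vec_size - 1 < n && (n % vec_size) == 0) {
    // Only safe if in/out are 16-byte aligned, which the host side
    // guarantees via the addr % 16 == 0 check.
    *reinterpret_cast<vec_t*>(out + base) =
        *reinterpret_cast<const vec_t*>(in + base);
  } else {
    for (int64_t i = base; i < n && i < base + vec_size; ++i) {
      out[i] = in[i];  // scalar fallback for the misaligned-count tail
    }
  }
}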
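
A worked example of the new launch geometry, assuming threads_per_block = 512 (that constant is defined earlier in the file, outside this diff): block_y = 64 gives blockDim = (512 / 64, 64) = (8, 64), i.e. each block covers 8 rows and 64 vector chunks of columns. For nrows = 10000, ncols = 1024, and float input (vec_size = 4), gridDim_x = ceil(10000 / 8) = 1250 and gridDim_y = ceil((1024 / 4) / 64) = 4, so each thread quantizes 4 contiguous columns of one row; the (ncols % vec_size) remainder, zero in this example, falls to the scalar tail loop.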