6 changes: 4 additions & 2 deletions benchmarks/kernels/benchmark_moe.py
@@ -155,10 +155,10 @@ def get_rocm_tuning_space(use_fp16):
# small search space, no pruning required
# bypassLDS: block_n/num_warps=16 for perf
block_m_range = [16, 32, 64, 128, 256]
-block_n_range = [128] if use_fp16 else [64]
+block_n_range = [128] if use_fp16 else [128]
block_k_range = [128] if use_fp16 else [256]

-num_warps_range = [8] if use_fp16 else [4]
+num_warps_range = [8] if use_fp16 else [8]
group_m_range = [1]
# For now we see better perf with num_stages=0 for all gemm configs we care
# But keep this explicit so that we do not forget we may need to set it to
@@ -211,6 +211,8 @@ def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
keys, values = zip(*param_ranges.items())
for config_values in product(*values):
config = dict(zip(keys, config_values))
+assert config['num_warps'] == config['BLOCK_SIZE_N'] // 16, \
+    "num_warps should be equal to BLOCK_SIZE_N divided by 16"
configs.append(config)
return configs

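For reference, a minimal, self-contained sketch of how the tuning space above expands into candidate kernel configs, and why the new assertion always holds: with both the fp16 and fp8 paths now using BLOCK_SIZE_N = 128 and num_warps = 8, every candidate satisfies the "bypass LDS" rule num_warps == BLOCK_SIZE_N // 16. The helper name and the reduced parameter set below are illustrative only; the real script tunes additional parameters as well:

    from itertools import product

    def rocm_moe_configs(use_fp16: bool):
        # Reduced ROCm search space, mirroring get_rocm_tuning_space() after
        # this change; extra tuned parameters are omitted for brevity.
        param_ranges = {
            "BLOCK_SIZE_M": [16, 32, 64, 128, 256],
            "BLOCK_SIZE_N": [128],                       # 128 for fp16 and fp8
            "BLOCK_SIZE_K": [128] if use_fp16 else [256],
            "GROUP_SIZE_M": [1],
            "num_warps": [8],                            # 8 for fp16 and fp8
        }
        keys, values = zip(*param_ranges.items())
        configs = []
        for config_values in product(*values):
            config = dict(zip(keys, config_values))
            # The invariant asserted above: 16 output columns per warp along N.
            assert config["num_warps"] == config["BLOCK_SIZE_N"] // 16
            configs.append(config)
        return configs

    print(len(rocm_moe_configs(use_fp16=True)))  # 5 candidates, one per BLOCK_SIZE_M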
10 changes: 6 additions & 4 deletions benchmarks/kernels/benchmark_paged_attention.py
@@ -10,7 +10,7 @@
create_kv_caches_with_random)

NUM_BLOCKS = 1024 * 1024
-PARTITION_SIZE = 512
+PARTITION_SIZE = 256


@torch.inference_mode()
@@ -101,7 +101,7 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
start_time = time.perf_counter()

# Using default kv_scale
-k_scale = v_scale = 1.0
+k_scale = v_scale = 0.1

for _ in range(num_iters):
if version == "v1":
@@ -161,6 +161,8 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
kv_cache_dtype,
k_scale,
v_scale,
+None,
+PARTITION_SIZE
)
else:
raise ValueError(f"Invalid version: {version}")
@@ -174,13 +176,13 @@ def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
# Warmup.
print("Warming up...")
run_benchmark = run_cuda_benchmark
-run_benchmark(num_iters=3, profile=False)
+run_benchmark(num_iters=500, profile=False)

# Benchmark.
if do_profile:
latency = run_benchmark(num_iters=1, profile=True)
else:
-latency = run_benchmark(num_iters=1000, profile=False)
+latency = run_benchmark(num_iters=10000, profile=False)
print(f"Kernel running time: {latency * 1000000:.3f} us")


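The PARTITION_SIZE change above feeds into the paged attention v2 path, which splits each sequence's KV cache into fixed-size partitions and reduces the partial results afterwards. The sketch below shows how the partition count and the intermediate buffers scale with it; the buffer shapes follow the upstream benchmark script, and the exact kernel argument order (including the trailing None and PARTITION_SIZE added above) should be treated as fork-specific:

    import torch

    num_seqs, num_query_heads, head_size = 8, 64, 128
    max_seq_len = 4096
    PARTITION_SIZE = 256  # previously 512

    # Each sequence is split into ceil(seq_len / PARTITION_SIZE) partitions, so
    # halving PARTITION_SIZE doubles the partition count and the size of the
    # intermediate buffers that the v2 reduction pass consumes.
    max_num_partitions = (max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE

    tmp_output = torch.empty(num_seqs, num_query_heads, max_num_partitions, head_size)
    exp_sums = torch.empty(num_seqs, num_query_heads, max_num_partitions,
                           dtype=torch.float32)
    max_logits = torch.empty_like(exp_sums)
    print(max_num_partitions)  # 16 with PARTITION_SIZE=256, 8 with 512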
1 change: 1 addition & 0 deletions benchmarks/kernels/tune_script.sh
@@ -4,6 +4,7 @@ export FUSED_MOE_PERSISTENT=1
export VLLM_MOE_PADDING=128
export VLLM_MOE_SHUFFLE=1
export TRITON_HIP_USE_NEW_STREAM_PIPELINE=1
+export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7

## ---- Mixtral fp8 tuning ---- ##

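The added HIP_VISIBLE_DEVICES export makes all eight GPUs visible to the tuning run. As a quick sanity check before launching the script (a hypothetical snippet, not part of tune_script.sh), the visible-device count can be verified from Python; the variable must be set before the HIP/CUDA runtime is initialized:

    import os
    os.environ.setdefault("HIP_VISIBLE_DEVICES", "0,1,2,3,4,5,6,7")

    import torch
    print(torch.cuda.device_count())  # expected 8 on a fully populated node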
177 changes: 173 additions & 4 deletions csrc/quantization/fp8/common.cu
@@ -1,16 +1,185 @@
#include "common.cuh"
#include "dispatch_utils.h"

#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>

#ifndef USE_ROCM
#include <cmath>

#include "cuda_compat.h"
#include "dispatch_utils.h"

#if defined(USE_CUDA_FP8_FORMAT)
#include <cub/util_type.cuh>
#include <cub/cub.cuh>
#else
#include <hipcub/util_type.hpp>
#include <hipcub/hipcub.hpp>
#endif

#if defined(USE_CUDA_FP8_FORMAT)
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX =
std::numeric_limits<FP8_TYPE>::max();
#else
#include "amd/hip_float8.h"
using FP8_TYPE = c10::Float8_e4m3fnuz;
// Using the default max value from pytorch (240.0) will cause accuracy
// issues when running dynamic quantization. Here use 224.0f for rocm.
constexpr auto FP8_E4M3_MAX = 224.0f;
#endif

namespace vllm {

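// Atomic max on a float address: for non-negative IEEE-754 floats the bit
// pattern, viewed as a signed int, preserves ordering, so atomicMax on the
// int view suffices; for negative values the unsigned bit pattern orders in
// reverse, so atomicMin on the unsigned view yields the maximum.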
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
float old;
old = (value >= 0)
? __int_as_float(atomicMax((int*)addr, __float_as_int(value)))
: __uint_as_float(
atomicMin((unsigned int*)addr, __float_as_uint(value)));

return old;
}

template <bool is_scale_inverted>
__device__ __forceinline__ FP8_TYPE scaled_fp8_conversion(float const val,
float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}

float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
#if defined(USE_CUDA_FP8_FORMAT)
return static_cast<c10::Float8_e4m3fn>(r);
#else
// Use hardware cvt instruction for fp8 on rocm
return c10::Float8_e4m3fnuz(hip_fp8(r).data,
c10::Float8_e4m3fnuz::from_bits());
#endif
}

// Compute the absolute maximum m of the input tensor and store
// m / float8_e4m3::max() in *scale. Each thread block performs a
// reduction tree and the memory in scale is atomically updated.
// So to get the right answer, *scale needs to be initialized to
// a value <= 0.0 and we need to wait for all thread blocks to
// finish before consuming *scale.
template <typename scalar_t>
__global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;

// First store the maximum of all values processed by
// the current thread in cache[threadIdx.x]
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;

__syncthreads();

// Now perform parallel reduction within the thread block
int ib = blockDim.x / 2;
while (ib != 0) {
if (threadIdx.x < ib && cache[threadIdx.x + ib] > cache[threadIdx.x]) {
cache[threadIdx.x] = cache[threadIdx.x + ib];
}
__syncthreads();
ib /= 2;
}
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if (threadIdx.x == 0) {
atomicMaxFloat(scale, cache[0] / FP8_E4M3_MAX);
}
}

template <typename scalar_t>
struct __align__(8) vec4_t {
scalar_t x;
scalar_t y;
scalar_t z;
scalar_t w;
};

typedef struct __align__(4) {
FP8_TYPE x;
FP8_TYPE y;
FP8_TYPE z;
FP8_TYPE w;
}
float8x4_t;

template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);

int64_t const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;

#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}

return absmax_val;
}

template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(FP8_TYPE* __restrict__ out,
scalar_t const* __restrict__ input,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);

int64_t const num_vec_elems = num_elems >> 2;

#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;

out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}

// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}

template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(FP8_TYPE* __restrict__ out,
const scalar_t* __restrict__ input,
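Taken together, segmented_max_reduction and the (truncated) scaled_fp8_quant_kernel implement dynamic per-tensor FP8 quantization: reduce the tensor to its absolute maximum, derive scale = absmax / FP8_E4M3_MAX, then clamp and convert each element. A minimal NumPy sketch of the same arithmetic, with the e4m3 cast only emulated (NumPy has no fp8 dtype) and FP8_E4M3_MAX set to the ROCm value used above:

    import numpy as np

    FP8_E4M3_MAX = 224.0  # e4m3fnuz limit used for ROCm above; 448.0 for CUDA e4m3fn

    def dynamic_scaled_fp8_quant(x: np.ndarray):
        # segmented_max_reduction: per-tensor absolute maximum -> scale.
        scale = np.abs(x).max() / FP8_E4M3_MAX
        # scaled_fp8_conversion (is_scale_inverted = false): divide by scale and
        # clamp to the representable range; the final cast to fp8 is emulated by
        # staying in float32 here, whereas the kernel uses the hardware cvt.
        q = np.clip(x / scale, -FP8_E4M3_MAX, FP8_E4M3_MAX).astype(np.float32)
        return q, np.float32(scale)

    x = np.random.randn(1 << 12).astype(np.float32)
    q, scale = dynamic_scaled_fp8_quant(x)
    # Dequantizing recovers the input up to fp32 rounding (and, on real
    # hardware, fp8 rounding as well).
    assert np.allclose(q * scale, x, rtol=1e-5, atol=1e-6)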