
Commit 4c40380

semi-working cudagraphs
Signed-off-by: Bill Nell <[email protected]>
1 parent 48ba146 commit 4c40380

File tree

14 files changed: +135 -69 lines changed


csrc/dispatch_utils.h

Lines changed: 13 additions & 0 deletions
@@ -65,5 +65,18 @@
   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

+#define VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(...) \
+  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)      \
+  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)        \
+  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)       \
+  AT_DISPATCH_CASE(at::ScalarType::UInt16, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt32, __VA_ARGS__)     \
+  AT_DISPATCH_CASE(at::ScalarType::UInt64, __VA_ARGS__)
+
 #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
   AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
+
+#define VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(TYPE, NAME, ...) \
+  AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_AND_UNSIGNED_TYPES(__VA_ARGS__))
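For reference, the new dispatch cases correspond to the following torch dtypes on the Python side. The sketch below is illustrative and not part of the commit; it assumes a PyTorch build (2.3 or newer) that exposes the limited-support torch.uint16/uint32/uint64 dtypes.

import torch

# Byte, Char, Short, Int, Long -> already covered by VLLM_DISPATCH_CASE_INTEGRAL_TYPES.
integral = [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]
# UInt16, UInt32, UInt64 -> the cases the new macro adds.
unsigned = [torch.uint16, torch.uint32, torch.uint64]

for dtype in integral + unsigned:
    # e.g. top-k expert ids for 4 tokens with k=2
    topk_ids = torch.zeros((4, 2), dtype=dtype)
    print(topk_ids.dtype)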

csrc/moe/moe_align_sum_kernels.cu

Lines changed: 4 additions & 4 deletions
@@ -326,7 +326,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   }

   if (use_global_memory) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_global_mem_kernel", [&] {
           // calc needed amount of shared mem for `tokens_cnts` and `cumsum`
           // tensors
@@ -351,7 +351,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
           cumsum_buffer.data_ptr<int32_t>());
         });
   } else if (use_i16) {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
           // set dynamic shared mem
           auto kernel =
@@ -366,7 +366,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
               topk_ids.numel());
         });
   } else {
-    VLLM_DISPATCH_INTEGRAL_TYPES(
+    VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
           auto kernel =
               vllm::moe::moe_align_block_size_kernel<scalar_t, int32_t>;
@@ -391,7 +391,7 @@ void sgl_moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
   TORCH_CHECK(num_experts == 256,
               "sgl_moe_align_block_size kernel only supports deepseek v3.");

-  VLLM_DISPATCH_INTEGRAL_TYPES(
+  VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
       topk_ids.scalar_type(), "sgl_moe_align_block_size_kernel", [&] {
         // calc needed amount of shared mem for `cumsum` tensors
         auto options_int =
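With the dispatch widened, a uint32 topk_ids tensor can be fed to moe_align_block_size while the output buffers stay int32, as the hunks above show. Below is a hedged host-side sketch; the vllm._custom_ops.moe_align_block_size wrapper and its argument order are assumptions based on surrounding vLLM code, not something this diff shows.

import torch
from vllm import _custom_ops as ops  # assumed binding; not part of this diff

num_tokens, topk, num_experts, block_size = 16, 2, 8, 16
topk_ids = torch.randint(0, num_experts, (num_tokens, topk),
                         dtype=torch.int32, device="cuda").to(torch.uint32)

max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty(max_num_tokens_padded, dtype=torch.int32, device="cuda")
expert_ids = torch.empty((max_num_tokens_padded + block_size - 1) // block_size,
                         dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

# Assumed argument order: (topk_ids, num_experts, block_size, sorted_ids,
# expert_ids, num_tokens_post_pad); check vllm/_custom_ops.py for your version.
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
                         expert_ids, num_tokens_post_pad)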

csrc/moe/topk_softmax_kernels.cu

Lines changed: 45 additions & 18 deletions
@@ -108,9 +108,17 @@ __launch_bounds__(TPB) __global__
     }
 }

-template <int TPB>
-__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
-    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
+template <int TPB, typename IndType>
+__launch_bounds__(TPB) __global__ void moeTopK(
+    const float* inputs_after_softmax,
+    const bool* finished,
+    float* output,
+    IndType* indices,
+    int* source_rows,
+    const int num_experts,
+    const int k,
+    const int start_expert,
+    const int end_expert)
 {

     using cub_kvp = cub::KeyValuePair<int, float>;
@@ -182,9 +190,9 @@ __launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax
   2) This implementation assumes k is small, but will work for any k.
 */

-template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
+template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG, typename IndType>
 __launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
-    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
+    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, IndType* indices,
         int* source_rows, const int k, const int start_expert, const int end_expert)
 {
     // We begin by enforcing compile time assertions and setting up compile time constants.
@@ -397,8 +405,8 @@ struct TopkConstants
 };
 } // namespace detail

-template <int EXPERTS, int WARPS_PER_TB>
-void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
+template <int EXPERTS, int WARPS_PER_TB, typename IndType>
+void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, IndType* indices,
     int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
 {
     static constexpr std::size_t MAX_BYTES_PER_LDG = 16;
@@ -421,10 +429,11 @@ void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, f
         token_expert_indices, num_tokens, topk, 0, num_experts, \
         stream);

+template <typename IndType>
 void topkGatingSoftmaxKernelLauncher(
     const float* gating_output,
     float* topk_weights,
-    int* topk_indicies,
+    IndType* topk_indicies,
     int* token_expert_indices,
     float* softmax_workspace,
     const int num_tokens,
@@ -493,14 +502,32 @@ void topk_softmax(
     const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
     const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
     torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
-    vllm::moe::topkGatingSoftmaxKernelLauncher(
-        gating_output.data_ptr<float>(),
-        topk_weights.data_ptr<float>(),
-        topk_indices.data_ptr<int>(),
-        token_expert_indices.data_ptr<int>(),
-        softmax_workspace.data_ptr<float>(),
-        num_tokens,
-        num_experts,
-        topk,
-        stream);
+
+    if (topk_indices.scalar_type() == at::ScalarType::Int)
+    {
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<int>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
+    else
+    {
+        assert(topk_indices.scalar_type() == at::ScalarType::UInt32);
+        vllm::moe::topkGatingSoftmaxKernelLauncher(
+            gating_output.data_ptr<float>(),
+            topk_weights.data_ptr<float>(),
+            topk_indices.data_ptr<uint32_t>(),
+            token_expert_indices.data_ptr<int>(),
+            softmax_workspace.data_ptr<float>(),
+            num_tokens,
+            num_experts,
+            topk,
+            stream);
+    }
 }
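The host wrapper now branches on topk_indices.scalar_type(), taking the original path for Int and the new one for UInt32. A minimal usage sketch follows; it assumes a CUDA device and that the vllm._custom_ops.topk_softmax wrapper keeps its (weights, ids, token_expert_indices, gating_output) argument order.

import torch
from vllm import _custom_ops as ops  # assumed wrapper around torch.ops._moe_C.topk_softmax

num_tokens, num_experts, topk = 16, 8, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
# torch.int32 takes the original branch; torch.uint32 exercises the new one.
topk_ids = torch.empty(num_tokens, topk, device="cuda", dtype=torch.uint32)
token_expert_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

ops.topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output)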

examples/offline_inference/data_parallel.py

Lines changed: 12 additions & 4 deletions
@@ -31,6 +31,7 @@
 from time import sleep

 from vllm import LLM, SamplingParams
+from vllm.config import CompilationConfig
 from vllm.utils import get_open_port


@@ -65,11 +66,14 @@ def parse_args():
                         type=int,
                         default=0,
                         help="Master node port")
+    parser.add_argument("--enforce-eager",
+                        action='store_true',
+                        help="Enforce eager mode execution.")
     return parser.parse_args()


 def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
-         dp_master_port, GPUs_per_dp_rank):
+         dp_master_port, GPUs_per_dp_rank, enforce_eager):
     os.environ["VLLM_DP_RANK"] = str(global_dp_rank)
     os.environ["VLLM_DP_RANK_LOCAL"] = str(local_dp_rank)
     os.environ["VLLM_DP_SIZE"] = str(dp_size)
@@ -109,10 +113,14 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
                                      max_tokens=[16, 20][global_dp_rank % 2])

     # Create an LLM.
+    cconfig = CompilationConfig(
+        level=0,
+    )
     llm = LLM(model=model,
               tensor_parallel_size=GPUs_per_dp_rank,
-              enforce_eager=True,
-              enable_expert_parallel=True)
+              enforce_eager=enforce_eager,
+              enable_expert_parallel=True,
+              compilation_config=cconfig)
     outputs = llm.generate(prompts, sampling_params)
     # Print the outputs.
     for i, output in enumerate(outputs):
@@ -155,7 +163,7 @@ def main(model, dp_size, local_dp_rank, global_dp_rank, dp_master_ip,
         proc = Process(target=main,
                        args=(args.model, dp_size, local_dp_rank,
                              global_dp_rank, dp_master_ip, dp_master_port,
-                             tp_size))
+                             tp_size, args.enforce_eager))
         proc.start()
         procs.append(proc)
     exit_code = 0
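Net effect on the example: passing --enforce-eager reproduces the old behaviour, while omitting it runs the engine with CUDA graphs but compilation level 0 (no torch.compile). A standalone sketch of the same configuration; the model name is a placeholder and not taken from this diff.

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig

llm = LLM(
    model="some-org/some-moe-model",                # placeholder MoE checkpoint
    enforce_eager=False,                            # allow CUDA graph capture
    enable_expert_parallel=True,
    compilation_config=CompilationConfig(level=0),  # skip torch.compile, as in the diff
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
print(outputs[0].outputs[0].text)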

pyproject.toml

Lines changed: 2 additions & 2 deletions
@@ -15,8 +15,8 @@ build-backend = "setuptools.build_meta"
 [project]
 name = "vllm"
 authors = [{name = "vLLM Team"}]
-license = "Apache-2.0"
-license-files = ["LICENSE"]
+#license = "Apache-2.0"
+#license-files = ["LICENSE"]
 readme = "README.md"
 description = "A high-throughput and memory-efficient inference and serving engine for LLMs"
 classifiers = [

vllm/compilation/compiler_interface.py

Lines changed: 2 additions & 2 deletions
@@ -326,9 +326,9 @@ def _get_shape_env() -> AlwaysHitShapeEnv:
         # compilation cache.
         if not envs.VLLM_DISABLE_COMPILE_CACHE:
             assert hash_str is not None, (
-                "failed to get the hash of the compiled graph")
+                f"failed to get the hash of the compiled graph: {file_path}")
             assert file_path is not None, (
-                "failed to get the file path of the compiled graph")
+                f"failed to get the file path of the compiled graph: {file_path}")
         return compiled_graph, (hash_str, file_path)

     def load(self,

vllm/distributed/utils.py

Lines changed: 7 additions & 3 deletions
@@ -360,7 +360,11 @@ def stateless_destroy_torch_distributed_process_group(
     Destroy ProcessGroup returned by
     stateless_init_torch_distributed_process_group().
     """
-    # Lazy import for non-CUDA backends.
-    from torch.distributed.distributed_c10d import _shutdown_backend
-    _shutdown_backend(pg)
+    # TODO: pytorch < 2.7?
+    if False:
+        # Lazy import for non-CUDA backends.
+        from torch.distributed.distributed_c10d import _shutdown_backend
+        _shutdown_backend(pg)
+    else:
+        pg.shutdown()
     _unregister_process_group(pg.group_name)
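The dead if False: branch and the TODO suggest the intended final shape is a version gate rather than an unconditional pg.shutdown(). A hedged sketch of that gate; the 2.7 cutoff comes from the TODO and is an assumption, as is the helper name destroy_pg.

import torch
from packaging import version
from torch.distributed.distributed_c10d import _unregister_process_group


def destroy_pg(pg) -> None:  # hypothetical helper, not part of the diff
    if version.parse(torch.__version__).release < (2, 7):  # assumed cutover point
        # Older torch: use the private backend-shutdown helper.
        from torch.distributed.distributed_c10d import _shutdown_backend
        _shutdown_backend(pg)
    else:
        # Newer torch exposes ProcessGroup.shutdown() directly.
        pg.shutdown()
    _unregister_process_group(pg.group_name)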

vllm/model_executor/layers/fused_moe/fused_batched_moe.py

Lines changed: 11 additions & 11 deletions
@@ -577,11 +577,11 @@ def workspace_shapes(
         topk: int,
         num_experts: int,
     ) -> Tuple[int, int, torch.dtype]:
+        assert a.dim() == 2
         max_num_tokens = a.shape[
-            1] if self.max_num_tokens is None else self.max_num_tokens
-        # TODO: *2 is a hack
-        workspace13 = num_experts * max_num_tokens * K * topk * 2
-        workspace2 = max_num_tokens * N
+            0] if self.max_num_tokens is None else self.max_num_tokens
+        workspace13 = num_experts * max_num_tokens * max(K, N)
+        workspace2 = max_num_tokens * (N // 2)
         return (workspace13, workspace2, a.dtype)

     def apply(
@@ -605,6 +605,7 @@ def apply(
     ) -> torch.Tensor:
         assert hidden_states.dim() == 3
         assert expert_num_tokens is not None
+        hidden_dim = hidden_states.shape[-1]

         if self.max_num_tokens is None:
             max_num_tokens = hidden_states.shape[1]
@@ -613,13 +614,13 @@ def apply(

         num_experts = global_num_experts
         out = _resize_cache(workspace13,
-                            (num_experts, max_num_tokens, w2.shape[1]))
+                            (num_experts, max_num_tokens, hidden_dim))
         num_local_experts = expert_num_tokens.numel()

         for expert in range(num_local_experts):
             num = expert_num_tokens[expert]
-            assert num <= max_num_tokens, f"{num}, {max_num_tokens}"
-            if num > 0:
+            #assert num <= max_num_tokens, f"{num}, {max_num_tokens}"
+            if True or num > 0:  # CUDAGRAPH unfriendly?
                 tmp = _resize_cache(workspace2, (num, w1.shape[1] // 2))
                 self.activation(
                     activation, tmp,
@@ -660,8 +661,9 @@ def workspace_shapes(
         topk: int,
         num_experts: int,
     ) -> Tuple[int, int, torch.dtype]:
+        assert a.dim() == 2
         max_num_tokens = a.shape[
-            1] if self.max_num_tokens is None else self.max_num_tokens
+            0] if self.max_num_tokens is None else self.max_num_tokens
         workspace13 = num_experts * max_num_tokens * max(K, N)
         workspace2 = num_experts * max_num_tokens * (N // 2)
         return (workspace13, workspace2, a.dtype)
@@ -685,9 +687,6 @@ def apply(
         workspace2: torch.Tensor,
         expert_num_tokens: Optional[torch.Tensor],
     ) -> torch.Tensor:
-
-        num_tokens = topk_ids.size(0)
-
         # Check constraints.
         if self.use_int4_w4a16:
             assert hidden_states.shape[-1] // 2 == w1.shape[
@@ -705,6 +704,7 @@ def apply(
             torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn
         ]

+        # TODO: num_tokens -> max_num_tokens?
        E, num_tokens, N, K, top_k_num = mk._moe_problem_size(
             hidden_states, w1, w2, topk_ids)

vllm/model_executor/layers/fused_moe/fused_moe.py

Lines changed: 2 additions & 1 deletion
@@ -870,6 +870,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    indices_type: torch.dtype = torch.int32,
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert hidden_states.shape[0] == gating_output.shape[0], (
         "Number of tokens mismatch")
@@ -882,7 +883,7 @@ def fused_topk(
                                device=hidden_states.device)
     topk_ids = torch.empty(M,
                            topk,
-                           dtype=torch.int32,
+                           dtype=indices_type,
                            device=hidden_states.device)
     token_expert_indicies = torch.empty(M,
                                         topk,
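The new keyword flows directly into the topk_ids allocation, so callers can now request unsigned 32-bit ids (matching the uint32 branch added to topk_softmax_kernels.cu above). A minimal sketch, assuming a CUDA device and that fused_topk keeps the signature shown in this hunk:

import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk

num_tokens, hidden_size, num_experts, topk = 16, 128, 8, 2
hidden_states = torch.randn(num_tokens, hidden_size, device="cuda", dtype=torch.float16)
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

# Default indices_type=torch.int32 keeps the old behaviour; uint32 takes the new path.
topk_weights, topk_ids, token_expert_indices = fused_topk(
    hidden_states, gating_output, topk, renormalize=True,
    indices_type=torch.uint32)
assert topk_ids.dtype == torch.uint32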
