Commit 39bba63

[TRTLLM-4983] feat: enable overlap scheduler between draft forwards (NVIDIA#4802)
Signed-off-by: Fanrong Li <[email protected]>
1 parent 5a01ba5 commit 39bba63

File tree

9 files changed: +248 -237 lines

tensorrt_llm/_torch/models/modeling_speculative.py

Lines changed: 1 addition & 0 deletions

@@ -247,6 +247,7 @@ def forward(
         hidden_states: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
+        hidden_states = self.apply_eagle3_fc(spec_metadata.get_hidden_states())
         output, _ = self.model(
             input_ids=input_ids,
             attn_metadata=attn_metadata,

tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py

Lines changed: 1 addition & 33 deletions

@@ -57,7 +57,6 @@ def __init__(
                                        device=device,
                                        dtype=torch.int32)

-        self.extra_model_inputs = {}
         self.attn_metadata = attn_metadata
         self.spec_metadata = spec_metadata
         self._output = None
@@ -70,22 +69,7 @@ def capture(
         self,
         forward_fn: Callable[[Dict[str, Any]], torch.Tensor],
         pool: Optional[Tuple[int, int]] = None,
-        extra_model_inputs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Tuple[int, int]:
-        """
-        Captures a CUDA graph by calling forward_fn(inputs),
-        where inputs is extra_model_inputs + this graph runner's
-        input_ids, position_ids, spec_metadata and attn_metadata.
-
-        Extra model inputs have the following semantics if
-        the extra input is a tensor (or collection of
-        tensors). The CUDA graph runner will create a buffer
-        of the same shape/dtype/device, and subsequent calls to run() will
-        require this extra model input. Input tensors will be
-        copied into the buffer that this CUDA graph runner owns.
-        This implies that these buffers *must* have static shapes for
-        this CUDA graph's batch size.
-        """
         self._graph = torch.cuda.CUDAGraph()
         inputs = {
             "attn_metadata": self.attn_metadata,
@@ -94,11 +78,6 @@ def capture(
             "inputs_embeds": None,
             "spec_metadata": self.spec_metadata,
         }
-        if extra_model_inputs is not None:
-            for key, tensor in extra_model_inputs.items():
-                new_tensor = tensor.clone()
-                inputs[key] = new_tensor
-                self.extra_model_inputs[key] = new_tensor

         # We have to do warm up runs to initialize PyTorch's
         # internal states according to the docs:
@@ -119,11 +98,7 @@ def capture(
     def needs_capture(self) -> bool:
         return self._output is None

-    def run(
-        self,
-        inputs: Dict[str, Any],
-        extra_model_inputs: Optional[Dict[str, torch.Tensor]] = None
-    ) -> torch.Tensor:
+    def run(self, inputs: Dict[str, Any]) -> torch.Tensor:
         assert "input_ids" in inputs
         assert "position_ids" in inputs
         assert "attn_metadata" in inputs
@@ -145,13 +120,6 @@ def run(
         self.input_ids[:seqlen].copy_(input_ids)
         self.position_ids[:, :seqlen].copy_(position_ids)

-        if self.extra_model_inputs:
-            assert extra_model_inputs is not None, "Model was captured with extra model inputs, so extra_model_inputs must be provided to run()"
-            for key in self.extra_model_inputs:
-                assert key in extra_model_inputs, f"Graph runner is missing extra input {key}"
-                dst_tensor = self.extra_model_inputs[key]
-                dst_tensor.copy_(extra_model_inputs[key])
-
         assert self._output is not None and self._graph is not None
         self._graph.replay()
         return self._output
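The deleted docstring described why extra inputs had to live in runner-owned static buffers: a captured CUDA graph replays fixed kernel launches against fixed memory addresses, so every tensor it reads must keep a static shape and address, and run() can only copy fresh data into those buffers before replaying. Below is a minimal standalone sketch of that capture-and-replay pattern with a generic step function; it is an assumption-laden illustration, not the repository's graph runner class.

import torch

def make_graphed_step(step_fn, static_input: torch.Tensor, warmup_iters: int = 3):
    # Warm up on a side stream so lazy library initialization happens before capture.
    side_stream = torch.cuda.Stream()
    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        for _ in range(warmup_iters):
            step_fn(static_input)
    torch.cuda.current_stream().wait_stream(side_stream)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = step_fn(static_input)  # buffer addresses are recorded here

    def run(new_input: torch.Tensor) -> torch.Tensor:
        static_input.copy_(new_input)  # refresh the captured buffer in place
        graph.replay()                 # re-launch the recorded kernels
        return static_output

    return run

With the draft hidden states now read from spec_metadata inside the model, the runner's only remaining static buffers are input_ids and position_ids, which is why capture() and run() no longer take extra_model_inputs.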

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 1 addition & 0 deletions

@@ -281,6 +281,7 @@ def __init__(
         self.py_rewind_len = 0
         self.py_draft_tokens = [] if self.draft_tokens is None else self.draft_tokens
         self.py_last_draft_tokens = None
+        self.py_num_accepted_draft_tokens = 0
         self.py_decoding_iter = 0
         self.is_attention_dp_dummy = False
         self.is_cuda_graph_dummy = False
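py_num_accepted_draft_tokens records, per request, how many of the previously proposed draft tokens the target model accepted in the last verification step; keeping the count on the request lets the draft loop advance without a host-side recount. The helper below is only a sketch of how such a count is commonly derived in speculative decoding (greedy prefix match); the function and call site are hypothetical, not this file's API.

from typing import List

def count_accepted_draft_tokens(draft_tokens: List[int],
                                target_tokens: List[int]) -> int:
    # Accept draft tokens until the first mismatch with the target model's
    # sampled tokens, the standard verification rule.
    accepted = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break
        accepted += 1
    return accepted

# Hypothetical usage:
# request.py_num_accepted_draft_tokens = count_accepted_draft_tokens(
#     request.py_last_draft_tokens or [], verified_tokens)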

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 36 additions & 67 deletions

@@ -77,7 +77,6 @@ def forward(self,
                 scheduled_requests: ScheduledRequests,
                 resource_manager: ResourceManager,
                 new_tensors_device: Optional[SampleStateTensors],
-                extra_model_inputs: Optional[Dict[str, Any]],
                 gather_context_logits: bool = False):
         raise NotImplementedError

@@ -338,6 +337,7 @@ def __init__(
         spec_config: Optional[SpecConfig] = None,
         guided_decoding_config: Optional[GuidedDecodingConfig] = None,
         lora_config: Optional[LoraConfig] = None,
+        is_draft_model: bool = False,
     ):
         self.ub_buffers = None
         self.batch_size = batch_size
@@ -353,10 +353,7 @@ def __init__(
         self.pytorch_backend_config = pytorch_backend_config
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
-        # We keep a reference to the last used spec metadata to
-        # accommodate certain target/draft model use cases. See
-        # py_executor.py for how this is used.
-        self.last_spec_metadata = None
+        self.is_draft_model = is_draft_model

         self.in_warmup = False

@@ -530,6 +527,15 @@ def wrapper(self, *args, **kwargs):

         return wrapper

+    @contextlib.contextmanager
+    def no_cuda_graph(self):
+        _run_cuda_graphs = self._run_cuda_graphs
+        self._run_cuda_graphs = False
+        try:
+            yield
+        finally:
+            self._run_cuda_graphs = _run_cuda_graphs
+
     @with_warmup_flag
     def warmup(self, resource_manager: ResourceManager) -> None:
         kv_cache_manager = resource_manager.get_resource_manager(
@@ -654,7 +660,7 @@ def get_autotune_warmup_request():
             result.context_requests = requests
             result.generation_requests = []

-            return result, _create_extra_inputs(1, maximum_tunable_num_tokens)
+            return result

         @contextlib.contextmanager
         def release_batch(result):
@@ -668,29 +674,6 @@ def release_batch(result):
                 if spec_resource_manager is not None:
                     spec_resource_manager.free_resources(req)

-        @contextlib.contextmanager
-        def no_cuda_graph():
-            _run_cuda_graphs = self._run_cuda_graphs
-            self._run_cuda_graphs = False
-            try:
-                yield
-            finally:
-                self._run_cuda_graphs = _run_cuda_graphs
-
-        def _create_extra_inputs(bs, num_tokens_per_request):
-            if self.spec_config is None:
-                extra_model_inputs = None
-            else:
-                warmup_inputs_creator = getattr(self.model,
-                                                "get_warmup_extra_inputs", None)
-                if callable(warmup_inputs_creator):
-                    extra_model_inputs = warmup_inputs_creator(
-                        bs, num_tokens_per_request)
-                else:
-                    extra_model_inputs = None
-
-            return extra_model_inputs
-
         # TODO: current warmup_request is not suitable for star attention
         cp_type = self.mapping.cp_config.get('cp_type', None)
         if cp_type == 'star_attention':
@@ -712,7 +695,7 @@ def disable_optimization(backend: Backend):
                 set_enable_piecewise_cuda_graph_capture_flag(True)

         # Disable cuda graph capture here so that we can properly capture it later
-        with no_cuda_graph():
+        with self.no_cuda_graph():
             available_tokens = kv_cache_manager.get_num_available_tokens(
                 self.max_draft_len)
             warmup_batch_size = [1, self.batch_size // 2]
@@ -733,17 +716,14 @@ def disable_optimization(backend: Backend):
                     logger.info(
                         f"Run warmup for batch size={bs}, pure {'context' if num_tokens_per_request > 1 else 'generation'} phase"
                     )
-                    self.forward(
-                        batch,
-                        new_tensors_device=None,
-                        resource_manager=resource_manager,
-                        extra_model_inputs=_create_extra_inputs(
-                            bs, num_tokens_per_request))
+                    self.forward(batch,
+                                 new_tensors_device=None,
+                                 resource_manager=resource_manager)
                     torch.cuda.synchronize()

         if self.pytorch_backend_config.autotuner_enabled:
-            with no_cuda_graph(), autotune():
-                result, extra_model_inputs = get_autotune_warmup_request()
+            with self.no_cuda_graph(), autotune():
+                result = get_autotune_warmup_request()
             with release_batch(result) as batch:
                 if batch is None:
                     # No KV cache space!
@@ -753,8 +733,7 @@ def disable_optimization(backend: Backend):
                     f"Run autotuning warmup for batch size={1}")
                 self.forward(batch,
                              new_tensors_device=None,
-                             resource_manager=resource_manager,
-                             extra_model_inputs=extra_model_inputs)
+                             resource_manager=resource_manager)
                 torch.cuda.synchronize()

             logger.info(f"Autotuner Cache size after warmup " +
@@ -783,12 +762,11 @@ def disable_optimization(backend: Backend):
                     )
                     self.forward(batch,
                                  new_tensors_device=None,
-                                 resource_manager=resource_manager,
-                                 extra_model_inputs=_create_extra_inputs(bs, 1))
+                                 resource_manager=resource_manager)
                     torch.cuda.synchronize()

         if self._torch_compile_piecewise_cuda_graph:
-            with no_cuda_graph():
+            with self.no_cuda_graph():
                 with release_batch(
                         get_torch_compile_warmup_request(1,
                                                          bs)) as batch:
@@ -797,17 +775,12 @@ def disable_optimization(backend: Backend):
                     )

                     for _ in range(3):
-                        self.forward(
-                            batch,
-                            new_tensors_device=None,
-                            resource_manager=resource_manager,
-                            extra_model_inputs=_create_extra_inputs(
-                                1, bs))
-                        self.forward(
-                            batch,
-                            new_tensors_device=None,
-                            resource_manager=resource_manager,
-                            extra_model_inputs=_create_extra_inputs(1, bs))
+                        self.forward(batch,
+                                     new_tensors_device=None,
+                                     resource_manager=resource_manager)
+                        self.forward(batch,
+                                     new_tensors_device=None,
+                                     resource_manager=resource_manager)
                     torch.cuda.synchronize()
                     gc.collect()
                     torch.cuda.empty_cache()
@@ -851,15 +824,17 @@ def _set_up_spec_metadata(
                 self.spec_config,
                 self.batch_size,
                 max_num_tokens=self.max_num_tokens,
-                spec_resource_manager=spec_resource_manager)
+                spec_resource_manager=spec_resource_manager,
+                is_draft_model=self.is_draft_model)

         if self.spec_metadata is not None:
             return self.spec_metadata
         self.spec_metadata = get_spec_metadata(
             self.spec_config,
             self.batch_size,
             max_num_tokens=self.max_num_tokens,
-            spec_resource_manager=spec_resource_manager)
+            spec_resource_manager=spec_resource_manager,
+            is_draft_model=self.is_draft_model)
         return self.spec_metadata

     def _get_padded_batch(self, scheduled_requests: ScheduledRequests,
@@ -1155,7 +1130,6 @@ def _preprocess_inputs(self, inputs: Dict[str, Any]):
             inputs['attn_metadata'].kv_lens_cuda[
                 num_ctx_requests:num_seqs] += (
                     self.previous_kv_lens_offsets_cuda[:num_gen_requests])
-
         return inputs

     def _prepare_tp_inputs(
@@ -1476,6 +1450,7 @@ def _prepare_tp_inputs(
         lora_params = self._get_lora_params_from_requests(
             scheduled_requests, attn_metadata)

+        # Prepare inputs
         inputs = {
             'attn_metadata': attn_metadata,
             'input_ids': self.input_ids_cuda[:total_num_tokens],
@@ -2027,7 +2002,6 @@ def forward(self,
                 scheduled_requests: ScheduledRequests,
                 resource_manager: ResourceManager,
                 new_tensors_device: Optional[SampleStateTensors] = None,
-                extra_model_inputs: Optional[Dict[str, Any]] = None,
                 gather_context_logits: bool = False):

         kv_cache_manager = resource_manager.get_resource_manager(
@@ -2055,9 +2029,6 @@ def forward(self,
         if kv_cache_manager is None:
             inputs, gather_ids = self._prepare_tp_inputs_no_cache(
                 scheduled_requests, attn_metadata, spec_metadata)
-            if extra_model_inputs is not None:
-                inputs.update(extra_model_inputs)
-            self.last_spec_metadata = spec_metadata

             with MoeLoadBalancerIterContext(moe_load_balancer):
                 return self._forward_step(inputs, gather_ids,
@@ -2081,9 +2052,6 @@ def forward(self,
                                           attn_metadata,
                                           spec_metadata,
                                           new_tensors_device)
-        if extra_model_inputs is not None:
-            inputs.update(extra_model_inputs)
-        self.last_spec_metadata = spec_metadata

         self.iter_counter += 1

@@ -2104,16 +2072,15 @@ def capture_forward_fn(inputs: Dict[str, Any]):
                     pool = maybe_graph.capture(
                         capture_forward_fn,
                         self._cuda_graph_mem_pool,
-                        extra_model_inputs,
                     )
                     self._cuda_graph_mem_pool = pool

                     # here we don't need to use context since cuda graph capture didn't run kernel.
                     # maybe we need a cleaner way to do this.
-                    outputs = maybe_graph.run(inputs, extra_model_inputs)
+                    outputs = maybe_graph.run(inputs)
                 else:
                     with MoeLoadBalancerIterContext(moe_load_balancer):
-                        outputs = maybe_graph.run(inputs, extra_model_inputs)
+                        outputs = maybe_graph.run(inputs)

         # Note: To overlap the CPU and GPU computation as much as possible,
         # guided_decoder.build should be called immediately after the launch of the single step;
@@ -2148,6 +2115,8 @@ def _forward_step(self,
                       gather_ids: Optional[torch.Tensor],
                       gather_context_logits: bool = False) -> Dict[str, Any]:
         inputs = self._preprocess_inputs(inputs)
+        if inputs.get('spec_metadata', None):
+            gather_ids = inputs['spec_metadata'].gather_ids
         if self.without_logits:
             outputs = self.model_forward(**inputs)
             return outputs
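Taken together, these changes appear aimed at letting the executor enqueue consecutive draft-model forwards without a host synchronization in between: sampled tokens can stay on the GPU (new_tensors_device), the draft engine is flagged with is_draft_model so its spec metadata supplies its own gather_ids, and the extra_model_inputs plumbing that forced per-step buffer copies is gone. The loop below is a conceptual sketch of that overlap using hypothetical callables, not the executor's real API.

import torch

def overlapped_draft_loop(draft_step, prepare_inputs, first_inputs, num_draft_steps):
    # Each step's sampled tokens stay on the device and feed the next step's
    # input preparation, so the CPU can enqueue step i+1 while step i still runs.
    inputs = first_inputs
    new_tokens_device = None
    for _ in range(num_draft_steps):
        if new_tokens_device is not None:
            inputs = prepare_inputs(new_tokens_device)  # consumes device tensors, no .cpu()
        new_tokens_device = draft_step(inputs)          # async GPU work, no sync here
    torch.cuda.synchronize()  # one sync after the whole draft loop, not per step
    return new_tokens_device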
