Skip to content

Commit be86bed

Browse files
committed
Store logit output in static tensor
1 parent 6a22702 commit be86bed

File tree

1 file changed

+54
-30
lines changed

1 file changed

+54
-30
lines changed

megatron/core/inference/text_generation_controllers/text_generation_controller.py

Lines changed: 54 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ def _init_dynamic_sampling_tensors(self):
109109
"""Initialize tensors needed for dynamic sampling."""
110110
context = self.inference_wrapped_model.inference_context
111111
max_requests = context.max_requests
112+
if context.materialize_only_last_token_logits:
113+
max_logits = max_requests
114+
else:
115+
max_logits = context.max_tokens
112116

113117
# Callback to get request IDs that should be marked as finished due to stop words
114118
self._get_stop_word_finished_ids_callback = None
@@ -117,6 +121,15 @@ def _init_dynamic_sampling_tensors(self):
117121
logits_dtype = self.inference_wrapped_model.config.params_dtype
118122

119123
self._sampling_backend = "torch"
124+
self._enable_cuda_graph = False
125+
126+
# Initialize bookkeeping tensors.
127+
if self._enable_cuda_graph:
128+
self._all_logits_cuda = torch.empty(
129+
(1, max_logits, self.vocab_size), dtype=logits_dtype, device=device
130+
)
131+
else:
132+
self._all_logits_cuda = None
120133
self._sampled_tokens_cuda = torch.empty(max_requests, dtype=torch.int64, device=device)
121134
# Speculative tokens tensor will be allocated later when num_speculative_tokens is set by the engine
122135
self._accepted_tokens_per_request = None
@@ -596,7 +609,7 @@ def _dynamic_step_context_init(
596609
else:
597610
return context.current_input_and_position_ids()
598611

599-
def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor) -> Tensor:
612+
def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor):
600613
"""Forward step the model to get logits for dynamic batching.
601614
602615
This also handles logits-broadcasting for pipeline parallelism.
@@ -607,6 +620,11 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
607620
"""
608621
context = self.inference_wrapped_model.inference_context
609622
active_request_count = context.total_request_count - context.paused_request_count
623+
logits_seq_len = (
624+
active_request_count
625+
if context.materialize_only_last_token_logits
626+
else context.padded_active_token_count
627+
)
610628

611629
with torch.inference_mode():
612630
logits = self.inference_wrapped_model.run_one_forward_step(
@@ -619,6 +637,12 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
619637
# will be computed serially after verification to ensure they are
620638
# conditioned on verified tokens only.
621639

640+
assert logits_seq_len == (
641+
active_request_count
642+
if context.materialize_only_last_token_logits
643+
else input_ids.shape[1]
644+
)
645+
622646
if self.model_is_pipeline_parallel:
623647
if context.config.materialize_only_last_token_logits:
624648
logits_seq_len = active_request_count
@@ -636,7 +660,11 @@ def _dynamic_step_forward_logits(self, input_ids: Tensor, position_ids: Tensor)
636660
pp_group=self.pp_group,
637661
)
638662

639-
return logits
663+
# Copy logits to contiguous buffer.
664+
if self._enable_cuda_graph:
665+
self._all_logits_cuda[:, :logits_seq_len, :].copy_(logits)
666+
else:
667+
self._all_logits_cuda = logits
640668

641669
def _dynamic_step_sample_bookkeeping(self):
642670
"""Perform bookkeeping necessary to sample logits for dynamic batching."""
@@ -1053,7 +1081,7 @@ def _verify_speculative_tokens(
10531081

10541082
return last_one_indices, accepted_tokens_mask, input_tokens_required
10551083

1056-
def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_ids: Tensor):
1084+
def _dynamic_step_sample_logits_and_verify_tokens(self, input_ids: Tensor):
10571085
"""
10581086
Sample tokens from logits for dynamic batching with speculative tokens and verify the tokens.
10591087
"""
@@ -1069,6 +1097,7 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_id
10691097
num_decode_requests = active_request_count - num_prefill_requests
10701098

10711099
# Get the logit indices for tokens that need sampling.
1100+
logits = self._all_logits_cuda
10721101
required_logit_indices = self._get_required_logit_indices(
10731102
request_in_prefill_status_tensor,
10741103
request_query_lengths,
@@ -1132,24 +1161,22 @@ def _dynamic_step_sample_logits_and_verify_tokens(self, logits: Tensor, input_id
11321161
dim=1
11331162
)
11341163

1135-
def _dynamic_step_sample_logits(self, logits: Tensor):
1136-
"""Sample tokens from logits for dynamic batching.
1137-
1138-
Args:
1139-
logits (Tensor): The logits from the forward pass.
1140-
"""
1164+
def _dynamic_step_sample_logits(self):
1165+
"""Sample tokens from logits for dynamic batching."""
11411166
# TODO(ksanthanam): Evaluate whether it makes more sense to sample on 1 rank
11421167
# and then broadcast the sampled tokens rather than broadcasting the raw logits.
11431168

11441169
# Last token logits.
11451170
context = self.inference_wrapped_model.inference_context
1171+
active_request_count = context.total_request_count - context.paused_request_count
1172+
11461173
if context.config.materialize_only_last_token_logits:
11471174
# When materialize_only_last_token_logits is true, last_token_logits is
11481175
# already called in the forward pass of GPT.
1149-
required_token_logits = logits.squeeze(0)
1176+
required_token_logits = self._all_logits_cuda.squeeze(0)[:active_request_count, :]
11501177
else:
11511178
# TODO: Should do verification here and get appropriate last token logits
1152-
required_token_logits = context.last_token_logits(logits)
1179+
required_token_logits = context.last_token_logits(self._all_logits_cuda)
11531180

11541181
if self._sampling_backend == "torch":
11551182
# Concatenate the outputs once to prevent repeated small writes.
@@ -1247,19 +1274,24 @@ def _router_record_bookkeeping(self) -> Optional[Dict[int, Tensor]]:
12471274

12481275
return routing_indices_per_request
12491276

1250-
def _dynamic_step_calculate_log_probs(self, logits: Tensor) -> Optional[Tensor]:
1277+
def _dynamic_step_calculate_log_probs(self) -> Optional[Tensor]:
12511278
"""Calculate log probs from logits."""
12521279
context = self.inference_wrapped_model.inference_context
12531280
active_request_count = context.total_request_count - context.paused_request_count
1281+
logits_seq_len = (
1282+
active_request_count
1283+
if context.materialize_only_last_token_logits
1284+
else context.padded_active_token_count
1285+
)
12541286

12551287
return context.calculate_log_probs(
1256-
logits,
1288+
self._all_logits_cuda[:, :logits_seq_len, :],
12571289
self._sampled_tokens_cuda[:active_request_count],
12581290
only_last_token_logits=context.config.materialize_only_last_token_logits,
12591291
)
12601292

12611293
def _dynamic_step_calculate_log_probs_speculative(
1262-
self, logits: Tensor
1294+
self,
12631295
) -> Tuple[List[List[float]], Tensor]:
12641296
"""Calculate log probs from logits for speculative decoding.
12651297
@@ -1271,9 +1303,6 @@ def _dynamic_step_calculate_log_probs_speculative(
12711303
- log_prob(accepted_token[j]) comes from logits at position j
12721304
- log_prob(newly_sampled_token) comes from logits at position accepted_count
12731305
1274-
Args:
1275-
logits (Tensor): The main model logits [1, seq_len, vocab_size].
1276-
12771306
Returns:
12781307
Tuple of (log_probs_list, log_probs_tensor):
12791308
log_probs_list: List of lists, one per active request, containing
@@ -1291,7 +1320,7 @@ def _dynamic_step_calculate_log_probs_speculative(
12911320
num_prefill_requests = request_in_prefill_status_tensor.sum().item()
12921321
num_decode_requests = active_request_count - num_prefill_requests
12931322

1294-
logits_squeezed = logits.squeeze(0).float()
1323+
logits_squeezed = self._all_logits_cuda.squeeze(0).float()
12951324
log_probs_tensor = F.log_softmax(logits_squeezed[: context.active_token_count], dim=-1)
12961325

12971326
log_probs_list_decode = []
@@ -1449,12 +1478,11 @@ def _dynamic_step_calculate_top_n_logprobs_speculative(
14491478
return top_n_results if top_n_results else None
14501479

14511480
def _dynamic_step_calculate_top_n_logprobs(
1452-
self, logits: Tensor, log_probs_tensor: Optional[Tensor] = None
1481+
self, log_probs_tensor: Optional[Tensor] = None
14531482
) -> Optional[Dict[int, List[Tuple[Tensor, Tensor]]]]:
14541483
"""Calculate top-n log probs from logits for dynamic batching.
14551484
14561485
Args:
1457-
logits (Tensor): The logits to compute top-n log probs from.
14581486
log_probs_tensor (Optional[Tensor]): Pre-computed log probabilities tensor.
14591487
If provided, avoids recomputing log_softmax. Should be the tensor
14601488
returned by calculate_log_probs.
@@ -1743,7 +1771,7 @@ async def async_generate_output_tokens_dynamic_batch(
17431771

17441772
# Forward pass produces only base logits. When speculative decoding is
17451773
# active, MTP logits are computed serially after verification.
1746-
logits = self._dynamic_step_forward_logits(input_ids, position_ids)
1774+
self._dynamic_step_forward_logits(input_ids, position_ids)
17471775

17481776
# Commit Mamba intermediate states before update_requests, which
17491777
# may swap request indices. The Python lists tracking EOS block IDs
@@ -1769,7 +1797,7 @@ async def async_generate_output_tokens_dynamic_batch(
17691797

17701798
if self.num_speculative_tokens > 0:
17711799
# Phase 1: Verify speculative tokens using base logits only.
1772-
self._dynamic_step_sample_logits_and_verify_tokens(logits, input_ids)
1800+
self._dynamic_step_sample_logits_and_verify_tokens(input_ids)
17731801
# Phase 2: Rewind KV cache for rejected tokens.
17741802
self._rewind_kv_cache()
17751803

@@ -1781,25 +1809,21 @@ async def async_generate_output_tokens_dynamic_batch(
17811809
# Phase 3: Compute MTP serially with correct (verified) inputs.
17821810
self._compute_serial_mtp_and_sample()
17831811
else:
1784-
self._dynamic_step_sample_logits(logits)
1812+
self._dynamic_step_sample_logits()
17851813

17861814
log_probs = None
17871815
top_n_logprobs = None
17881816
if return_log_probs or return_top_n_logprobs:
17891817
if self.num_speculative_tokens > 0:
1790-
log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs_speculative(
1791-
logits
1792-
)
1818+
log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs_speculative()
17931819
if return_top_n_logprobs:
17941820
top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs_speculative(
17951821
log_probs_tensor
17961822
)
17971823
else:
1798-
log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs(logits)
1824+
log_probs, log_probs_tensor = self._dynamic_step_calculate_log_probs()
17991825
if return_top_n_logprobs:
1800-
top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs(
1801-
logits, log_probs_tensor
1802-
)
1826+
top_n_logprobs = self._dynamic_step_calculate_top_n_logprobs(log_probs_tensor)
18031827

18041828
if skip_bookkeeping:
18051829
request_bookkeeping = {}

0 commit comments

Comments
 (0)