Skip to content

Commit 6a22702

Browse files
committed
Move context tensor padding into dedicated method
1 parent 6f9b3b8 commit 6a22702

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1006,6 +1006,21 @@ def build_active_slices(self, batch_size: int):
10061006
graph_scratch_space = torch.cumsum(self.active_request_query_lengths[:batch_size], dim=0)
10071007
self.active_request_last_token_idxs[:batch_size].copy_(graph_scratch_space - 1)
10081008

1009+
def pad_active_slices(self):
    """Pad the active slices of specific tensors.

    Fills the slots between the real active counts and the padded active
    counts (``padded_active_token_count`` / ``padded_active_request_count``)
    so that padded entries hold harmless values: each padded token slot is
    pointed at the allocator's dummy KV block with local position 0 and
    request position 0.
    """
    # Some tensors need to be padded at the token level.
    padding_token_slice = slice(self.active_token_count, self.padded_active_token_count)

    # Padded token slots map into the dummy block so kernels reading the
    # padding region never touch a live request's KV cache.
    self.token_to_block_idx[padding_token_slice] = self.kv_block_allocator.dummy_block_idx
    self.token_to_local_position_within_kv_block[padding_token_slice] = 0
    self.token_to_position_in_request[padding_token_slice] = 0

    # Other tensors need to be padded at the request level.
    # NOTE(review): `padding_request_slice` is computed but never used in
    # this method — no request-level tensor is actually written here.
    # Either the request-level padding writes are missing, or this is
    # scaffolding for a follow-up change. Confirm intent before merging.
    padding_request_slice = slice(
        self.total_request_count - self.paused_request_count,
        self.padded_active_request_count,
    )
1023+
10091024
def append_key_value_cache(self, layer_number: int, key: Tensor, value: Tensor) -> None:
10101025
"""Append to KV cache.
10111026
@@ -1620,23 +1635,6 @@ def initialize_attention_state(
16201635
prefill_req_count=padded_prefill_req_count,
16211636
decode_req_count=padded_decode_req_count,
16221637
)
1623-
self.padded_active_token_count = self.padded_batch_dimensions.token_count
1624-
self.padded_active_request_count = self.padded_batch_dimensions.req_count
1625-
self.padding_slice = slice(self.active_token_count, self.padded_active_token_count)
1626-
1627-
self.build_active_slices(self.padded_active_request_count)
1628-
batch_size = self.total_request_count - self.paused_request_count
1629-
1630-
# Update token position indexes.
1631-
self.token_to_block_idx[self.active_token_count : self.padded_active_token_count] = (
1632-
self.kv_block_allocator.dummy_block_idx
1633-
)
1634-
self.token_to_local_position_within_kv_block[
1635-
self.active_token_count : self.padded_active_token_count
1636-
] = 0
1637-
self.token_to_position_in_request[
1638-
self.active_token_count : self.padded_active_token_count
1639-
] = 0
16401638

16411639
self.active_attn_metadata = (
16421640
self.graph_attn_metadata # type: ignore[assignment]
@@ -1657,6 +1655,14 @@ def initialize_attention_state(
16571655
decode_req_count=adjusted_decode_req_count,
16581656
)
16591657

1658+
self.padded_active_token_count = self.padded_batch_dimensions.token_count
1659+
self.padded_active_request_count = self.padded_batch_dimensions.req_count
1660+
self.padding_slice = slice(self.active_token_count, self.padded_active_token_count)
1661+
1662+
self.build_active_slices(self.padded_active_request_count)
1663+
self.pad_active_slices()
1664+
1665+
batch_size = self.total_request_count - self.paused_request_count
16601666
assert self.active_attn_metadata is not None
16611667
self.active_attn_metadata["mha_metadata"].update(
16621668
request_query_lengths=self.active_request_query_lengths[:batch_size],

0 commit comments

Comments (0)