@@ -5,6 +5,7 @@
 from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
 import torch
 import fms.utils.spyre.paged  # noqa
+from aiu_fms_testing_utils.utils import get_pad_size


 def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
@@ -226,6 +227,12 @@ def generate(
     # left_padded_prompt_mask - empty_slots + context_lengths
     current_tkv_mask = torch.fill(context_lengths, input_ids.shape[1])

+    # if using chunked prefill, reserve a pad block
+    # reserving a pad block is required as writes to pad are done in parallel and could corrupt the real blocks
+    if prefill_chunk_size > 0:
+        pad_block_id = block_numbers.pop(0)
+        pad_slots = [(pad_block_id * BLOCK_SIZE) + pos_i for pos_i in range(BLOCK_SIZE)]
+
     slot_mapping = []
     block_table = []
     # each sequence has the possibility of a different tkv, so loop over that
@@ -244,6 +251,7 @@ def generate(
             slot_mapping_i.append(slot)
         slot_mapping.append(slot_mapping_i)
         block_table.append(block_table_i)
+
     kwargs["current_tkv_mask"] = None
     kwargs["left_padded_prompt_mask"] = None
     kwargs["use_cache"] = use_cache
@@ -300,64 +308,110 @@ def generate(
             last_n_tokens = kwargs.get("last_n_tokens", 0)

             if prefill_chunk_size > 0:
-                left_padded_prompt_mask_seq_chunk = None
+                required_extra_pads = (
+                    get_pad_size(current_tkv.item(), prefill_chunk_size)
+                    - current_tkv.item()
+                )
+                left_padded_prompt_mask_seq_chunk = (
+                    (kwargs["position_ids"][seq_i][-current_tkv.item() :] == 0).sum(
+                        dim=0
+                    )
+                    - 1
+                    + required_extra_pads
+                )
+                left_padded_prompt_mask_seq_chunk = (
+                    left_padded_prompt_mask_seq_chunk.unsqueeze(0)
+                )
+                block_seq_left_padding = required_extra_pads // BLOCK_SIZE
+
                 # Chunked prefill
                 for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)):
-                    chunk_start = -current_tkv + chunk_j * prefill_chunk_size
-                    chunk_end = -current_tkv + min(
-                        (chunk_j + 1) * prefill_chunk_size, current_tkv
-                    )
+                    # chunk_start and chunk_end are the index mappings from the original sequence
+                    if chunk_j == 0:
+                        chunk_start = 0
+                        chunk_end = prefill_chunk_size - required_extra_pads
+                    else:
+                        required_extra_pads = 0
+                        chunk_start = chunk_end
+                        chunk_end += prefill_chunk_size
+
+                    input_ids_seq_chunk = input_ids[seq_i][-current_tkv:][
+                        chunk_start:chunk_end
+                    ]
+                    slot_mapping_seq_chunk = slot_mapping[seq_i][-current_tkv:][
+                        chunk_start:chunk_end
+                    ]
+                    position_ids_seq_chunk = kwargs["position_ids"][seq_i][
+                        -current_tkv:
+                    ][chunk_start:chunk_end]
+
+                    # add the extra required padding to chunk
+                    if required_extra_pads > 0:
+                        input_ids_seq_chunk = torch.cat(
+                            (
+                                torch.zeros(
+                                    required_extra_pads,
+                                    dtype=torch.int64,
+                                    device=input_ids_seq_chunk.device,
+                                ),
+                                input_ids_seq_chunk,
+                            )
+                        )
+                        slot_mapping_seq_chunk = (
+                            pad_slots * (required_extra_pads // BLOCK_SIZE)
+                            + slot_mapping_seq_chunk
+                        )
+                        position_ids_seq_chunk = torch.cat(
+                            (
+                                torch.zeros(
+                                    required_extra_pads,
+                                    dtype=torch.int64,
+                                    device=position_ids_seq_chunk.device,
+                                ),
+                                position_ids_seq_chunk,
+                            )
+                        )
+
+                    input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone()

-                    ids_length = input_ids[seq_i].shape[0]
-                    input_ids_seq_chunk = (
-                        input_ids[seq_i][
-                            chunk_start + ids_length : chunk_end + ids_length
-                        ]
-                        .unsqueeze(0)
-                        .clone()
-                    )
-                    assert input_ids_seq_chunk.size(1) == prefill_chunk_size, (
-                        f"prefill chunk size was not equal to the chunk size. Found {input_ids_seq_chunk.size(0)}"
-                    )
-                    slots_length = len(slot_mapping[seq_i])
                     slot_mapping_seq_chunk = (
                         torch.tensor(
-                            slot_mapping[seq_i][
-                                chunk_start + slots_length : chunk_end
-                                + slots_length
-                            ],
+                            slot_mapping_seq_chunk,
                             dtype=torch.int64,
                         )
                         .unsqueeze(0)
                         .clone()
                     )
-                    pids_length = kwargs["position_ids"][seq_i].shape[0]
-                    position_ids_seq_chunk = (
-                        kwargs["position_ids"][seq_i][
-                            chunk_start + pids_length : chunk_end + pids_length
-                        ]
-                        .unsqueeze(0)
-                        .clone()
+
+                    position_ids_seq_chunk = position_ids_seq_chunk.unsqueeze(
+                        0
+                    ).clone()
+
+                    assert input_ids_seq_chunk.size(1) == prefill_chunk_size, (
+                        f"prefill chunk size was not equal to the chunk size for input_ids. Found {input_ids_seq_chunk.size(0)}"
                     )

-                    # This view will result in a discontiguous tensor (creates a new graph during compile)
-                    # For this reason, we must explicitly make contiguous
-                    if left_padded_prompt_mask_seq_chunk is None:
-                        left_padded_prompt_mask_seq_chunk = (
-                            position_ids_seq_chunk == 0
-                        ).sum(dim=1) - 1
-                    current_tkv_mask_seq_chunk = torch.min(
-                        torch.tensor(
-                            (chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
-                        ),
-                        current_tkv,
+                    assert slot_mapping_seq_chunk.size(1) == prefill_chunk_size, (
+                        f"prefill chunk size was not equal to the chunk size for slot_mapping. Found {slot_mapping_seq_chunk.size(0)}"
+                    )
+
+                    assert position_ids_seq_chunk.size(1) == prefill_chunk_size, (
+                        f"prefill chunk size was not equal to the chunk size for position_ids. Found {position_ids_seq_chunk.size(0)}"
+                    )
+
+                    current_tkv_mask_seq_chunk = torch.tensor(
+                        (chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
                     ).unsqueeze(0)

-                    table_length = len(block_table[seq_i])
-                    block_start = -current_tkv // BLOCK_SIZE + table_length
-                    block_end = chunk_end // BLOCK_SIZE + table_length
+                    block_end = chunk_end // BLOCK_SIZE
+                    # length of padding or index until padding has occurred in block table
+                    block_pad_len = (input_ids.shape[1] - current_tkv) // BLOCK_SIZE
                     block_table_seq_chunk = torch.tensor(
-                        block_table[seq_i][block_start:block_end], dtype=torch.int64
+                        [pad_block_id] * (block_seq_left_padding)
+                        + block_table[seq_i][
+                            block_pad_len : block_pad_len + block_end
+                        ],
+                        dtype=torch.int64,
                     ).unsqueeze(0)

                     chunked_kwargs = {
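
For readers following the padding logic in this diff, here is a small standalone sketch of the arithmetic: how many left-pads the first chunk needs, and how the reserved pad block becomes a list of pad slots. All concrete values (BLOCK_SIZE, prefill_chunk_size, current_tkv, pad_block_id) are illustrative assumptions, not taken from the repository, and get_pad_size is assumed to round its first argument up to a multiple of its second, matching how it is used above.

# Illustrative sketch only (not part of the diff): chunked-prefill padding arithmetic.
BLOCK_SIZE = 64           # assumed KV-cache block size
prefill_chunk_size = 128  # assumed tokens processed per prefill chunk
current_tkv = 192         # assumed real prompt length for this sequence (block-aligned)


def get_pad_size(length: int, multiple: int) -> int:
    # assumed behavior: round length up to the nearest multiple
    return -(-length // multiple) * multiple


# extra left-pads so the padded prompt is a whole number of chunks: 256 - 192 = 64
required_extra_pads = get_pad_size(current_tkv, prefill_chunk_size) - current_tkv

# one block is reserved purely for pads; every pad token writes into its slots,
# so parallel writes to pad positions can never touch a real block
pad_block_id = 7  # hypothetical id popped from the free block list
pad_slots = [pad_block_id * BLOCK_SIZE + pos_i for pos_i in range(BLOCK_SIZE)]

# number of pad blocks prepended to this sequence's block table: 64 // 64 = 1
block_seq_left_padding = required_extra_pads // BLOCK_SIZE

# chunk 0 carries the pads plus the start of the prompt; later chunks are all real tokens
chunk_0_real_tokens = prefill_chunk_size - required_extra_pads  # 64
chunk_1_real_tokens = prefill_chunk_size                        # 128
assert chunk_0_real_tokens + chunk_1_real_tokens == current_tkv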