
Commit 2473e7a

Resolve conflicts
Signed-off-by: Rafael Vasquez <[email protected]>
2 parents c2a25ee + fc3a30f

File tree

2 files changed: +122 −52 lines changed


aiu_fms_testing_utils/utils/paged.py

Lines changed: 97 additions & 43 deletions
@@ -5,6 +5,7 @@
 from typing import Any, Callable, List, MutableMapping, Optional, Tuple, Union
 import torch
 import fms.utils.spyre.paged # noqa
+from aiu_fms_testing_utils.utils import get_pad_size


 def adjust_inputs_to_batch(input_ids: torch.Tensor, **extra_kwargs):
@@ -226,6 +227,12 @@ def generate(
     # left_padded_prompt_mask - empty_slots + context_lengths
     current_tkv_mask = torch.fill(context_lengths, input_ids.shape[1])

+    # if using chunked prefill, reserve a pad block
+    # reserving a pad block is required as writes to pad are done in parallel and could corrupt the real blocks
+    if prefill_chunk_size > 0:
+        pad_block_id = block_numbers.pop(0)
+        pad_slots = [(pad_block_id * BLOCK_SIZE) + pos_i for pos_i in range(BLOCK_SIZE)]
+
     slot_mapping = []
     block_table = []
     # each sequence has the possibility of a different tkv, so loop over that
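For reference, a minimal standalone sketch of what the reserved pad block provides, assuming an illustrative BLOCK_SIZE of 64 and a hypothetical free-block list (the real values come from the paged KV-cache setup): the pad block contributes BLOCK_SIZE throwaway slot indices that chunked prefill can point pad tokens at, so their parallel KV writes never land in blocks that hold real tokens.

BLOCK_SIZE = 64  # illustrative block size; the real value comes from the paged KV-cache config
block_numbers = [7, 8, 9, 10]  # hypothetical free-block list

# reserve the first free block purely for pad tokens
pad_block_id = block_numbers.pop(0)

# every slot index inside the reserved block; pad-token KV writes are routed here
pad_slots = [(pad_block_id * BLOCK_SIZE) + pos_i for pos_i in range(BLOCK_SIZE)]

assert len(pad_slots) == BLOCK_SIZE
assert pad_slots[0] == 7 * BLOCK_SIZE  # slots 448..511 for block 7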
@@ -244,6 +251,7 @@ def generate(
             slot_mapping_i.append(slot)
         slot_mapping.append(slot_mapping_i)
         block_table.append(block_table_i)
+
     kwargs["current_tkv_mask"] = None
     kwargs["left_padded_prompt_mask"] = None
     kwargs["use_cache"] = use_cache
@@ -300,64 +308,110 @@ def generate(
         last_n_tokens = kwargs.get("last_n_tokens", 0)

         if prefill_chunk_size > 0:
-            left_padded_prompt_mask_seq_chunk = None
+            required_extra_pads = (
+                get_pad_size(current_tkv.item(), prefill_chunk_size)
+                - current_tkv.item()
+            )
+            left_padded_prompt_mask_seq_chunk = (
+                (kwargs["position_ids"][seq_i][-current_tkv.item() :] == 0).sum(
+                    dim=0
+                )
+                - 1
+                + required_extra_pads
+            )
+            left_padded_prompt_mask_seq_chunk = (
+                left_padded_prompt_mask_seq_chunk.unsqueeze(0)
+            )
+            block_seq_left_padding = required_extra_pads // BLOCK_SIZE
+
             # Chunked prefill
             for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)):
-                chunk_start = -current_tkv + chunk_j * prefill_chunk_size
-                chunk_end = -current_tkv + min(
-                    (chunk_j + 1) * prefill_chunk_size, current_tkv
-                )
+                # chunk_start and chunk_end are the index mappings from the original sequence
+                if chunk_j == 0:
+                    chunk_start = 0
+                    chunk_end = prefill_chunk_size - required_extra_pads
+                else:
+                    required_extra_pads = 0
+                    chunk_start = chunk_end
+                    chunk_end += prefill_chunk_size
+
+                input_ids_seq_chunk = input_ids[seq_i][-current_tkv:][
+                    chunk_start:chunk_end
+                ]
+                slot_mapping_seq_chunk = slot_mapping[seq_i][-current_tkv:][
+                    chunk_start:chunk_end
+                ]
+                position_ids_seq_chunk = kwargs["position_ids"][seq_i][
+                    -current_tkv:
+                ][chunk_start:chunk_end]
+
+                # add the extra required padding to chunk
+                if required_extra_pads > 0:
+                    input_ids_seq_chunk = torch.cat(
+                        (
+                            torch.zeros(
+                                required_extra_pads,
+                                dtype=torch.int64,
+                                device=input_ids_seq_chunk.device,
+                            ),
+                            input_ids_seq_chunk,
+                        )
+                    )
+                    slot_mapping_seq_chunk = (
+                        pad_slots * (required_extra_pads // BLOCK_SIZE)
+                        + slot_mapping_seq_chunk
+                    )
+                    position_ids_seq_chunk = torch.cat(
+                        (
+                            torch.zeros(
+                                required_extra_pads,
+                                dtype=torch.int64,
+                                device=position_ids_seq_chunk.device,
+                            ),
+                            position_ids_seq_chunk,
+                        )
+                    )
+
+                input_ids_seq_chunk = input_ids_seq_chunk.unsqueeze(0).clone()

-                ids_length = input_ids[seq_i].shape[0]
-                input_ids_seq_chunk = (
-                    input_ids[seq_i][
-                        chunk_start + ids_length : chunk_end + ids_length
-                    ]
-                    .unsqueeze(0)
-                    .clone()
-                )
-                assert input_ids_seq_chunk.size(1) == prefill_chunk_size, (
-                    f"prefill chunk size was not equal to the chunk size. Found {input_ids_seq_chunk.size(0)}"
-                )
-                slots_length = len(slot_mapping[seq_i])
                 slot_mapping_seq_chunk = (
                     torch.tensor(
-                        slot_mapping[seq_i][
-                            chunk_start + slots_length : chunk_end
-                            + slots_length
-                        ],
+                        slot_mapping_seq_chunk,
                         dtype=torch.int64,
                     )
                     .unsqueeze(0)
                     .clone()
                 )
-                pids_length = kwargs["position_ids"][seq_i].shape[0]
-                position_ids_seq_chunk = (
-                    kwargs["position_ids"][seq_i][
-                        chunk_start + pids_length : chunk_end + pids_length
-                    ]
-                    .unsqueeze(0)
-                    .clone()
+
+                position_ids_seq_chunk = position_ids_seq_chunk.unsqueeze(
+                    0
+                ).clone()
+
+                assert input_ids_seq_chunk.size(1) == prefill_chunk_size, (
+                    f"prefill chunk size was not equal to the chunk size for input_ids. Found {input_ids_seq_chunk.size(0)}"
                 )

-                # This view will result in a discontiguous tensor (creates a new graph during compile)
-                # For this reason, we must explicitly make contiguous
-                if left_padded_prompt_mask_seq_chunk is None:
-                    left_padded_prompt_mask_seq_chunk = (
-                        position_ids_seq_chunk == 0
-                    ).sum(dim=1) - 1
-                current_tkv_mask_seq_chunk = torch.min(
-                    torch.tensor(
-                        (chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
-                    ),
-                    current_tkv,
+                assert slot_mapping_seq_chunk.size(1) == prefill_chunk_size, (
+                    f"prefill chunk size was not equal to the chunk size for slot_mapping. Found {slot_mapping_seq_chunk.size(0)}"
+                )
+
+                assert position_ids_seq_chunk.size(1) == prefill_chunk_size, (
+                    f"prefill chunk size was not equal to the chunk size for position_ids. Found {position_ids_seq_chunk.size(0)}"
+                )
+
+                current_tkv_mask_seq_chunk = torch.tensor(
+                    (chunk_j + 1) * prefill_chunk_size, dtype=torch.int64
                 ).unsqueeze(0)

-                table_length = len(block_table[seq_i])
-                block_start = -current_tkv // BLOCK_SIZE + table_length
-                block_end = chunk_end // BLOCK_SIZE + table_length
+                block_end = chunk_end // BLOCK_SIZE
+                # length of padding or index until padding has occured in block table
+                block_pad_len = (input_ids.shape[1] - current_tkv) // BLOCK_SIZE
                 block_table_seq_chunk = torch.tensor(
-                    block_table[seq_i][block_start:block_end], dtype=torch.int64
+                    [pad_block_id] * (block_seq_left_padding)
+                    + block_table[seq_i][
+                        block_pad_len : block_pad_len + block_end
+                    ],
+                    dtype=torch.int64,
                 ).unsqueeze(0)

                 chunked_kwargs = {
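To make the chunking arithmetic above easier to follow, here is a small self-contained sketch (not the library code) of how the first chunk absorbs the left padding while every later chunk covers exactly prefill_chunk_size real tokens. It assumes get_pad_size rounds the sequence's tkv up to the next multiple of the chunk size, that the tkv is already block-aligned, and it uses an example tkv of 320 with a 128-token chunk.

import math

PREFILL_CHUNK_SIZE = 128  # e.g. --prefill_chunk_size=128


def round_up(n: int, multiple: int) -> int:
    # stand-in for the assumed behavior of get_pad_size: round n up to a multiple
    return math.ceil(n / multiple) * multiple


def chunk_bounds(current_tkv: int, prefill_chunk_size: int):
    # yields (chunk_start, chunk_end, pads): the index range into the unpadded
    # sequence plus how many pad tokens are prepended to that chunk
    required_extra_pads = round_up(current_tkv, prefill_chunk_size) - current_tkv
    chunk_end = 0
    for chunk_j in range(math.ceil(current_tkv / prefill_chunk_size)):
        if chunk_j == 0:
            chunk_start, chunk_end = 0, prefill_chunk_size - required_extra_pads
            pads = required_extra_pads
        else:
            chunk_start, chunk_end = chunk_end, chunk_end + prefill_chunk_size
            pads = 0
        yield chunk_start, chunk_end, pads


# a 320-token tkv rounds up to 384, so the first chunk carries 64 pads
# (one full 64-token pad block) plus tokens [0:64], then [64:192], then [192:320]
for start, end, pads in chunk_bounds(320, PREFILL_CHUNK_SIZE):
    assert pads + (end - start) == PREFILL_CHUNK_SIZE
    print(start, end, pads)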

tests/models/test_scripts.py

Lines changed: 25 additions & 9 deletions
@@ -175,17 +175,25 @@ def execute_dpp(
     test_type,
     skip_validation,
     enforce_homogeneous_prompt_programs,
+    prefill_chunk_size,
     shared_tmp_path,
     isolated_env,
 ):
     isolated_env["VLLM_DT_MAX_BATCH_TKV_LIMIT"] = "1024"
     isolated_env["VLLM_DT_MAX_CONTEXT_LEN"] = "512"
     isolated_env["VLLM_DT_MAX_BATCH_SIZE"] = "2"
+    if prefill_chunk_size > 0:
+        isolated_env["VLLM_DT_CHUNK_LEN"] = f"{prefill_chunk_size}"
     Path(os.path.join(shared_tmp_path, "sendnn_cache")).mkdir(exist_ok=True)
-    os.environ.setdefault(
-        "TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache")
-    )
-    isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+
+    # only enable for non-chunk
+    if prefill_chunk_size == 0:
+        os.environ.setdefault(
+            "TORCH_SENDNN_CACHE_DIR", os.path.join(shared_tmp_path, "sendnn_cache")
+        )
+        isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "1"
+    else:
+        isolated_env["TORCH_SENDNN_CACHE_ENABLE"] = "0"

     command_list = [
         "python3",
@@ -239,6 +247,9 @@ def execute_dpp(
     if enforce_homogeneous_prompt_programs:
        command_list += ["--enforce_homogeneous_prompt_programs"]

+    if prefill_chunk_size > 0:
+        command_list += [f"--prefill_chunk_size={prefill_chunk_size}"]
+
     # add program criteria path
     command_list += [
         f"--program_criteria_json_path={os.environ['DT_PROG_CRITERIA_FILEPATH']}"
@@ -249,21 +260,24 @@ def execute_dpp(

 dpp_possibilities = []
 dpp_possibilities.append(
-    ("paged", None, 8, "sharegpt", "metrics", False, False)
+    ("paged", None, 8, "sharegpt", "metrics", False, False, 0)
 ) # metrics and run all programs
 dpp_possibilities.append(
-    ("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False)
+    ("paged", "*:0,==256", 65, "sharegpt", "tokens", False, False, 0)
 ) # tokens and run all programs that satisfy 256 sequence length
 dpp_possibilities.append(
-    ("paged", "*:>=2,0", 65, "sharegpt", None, True, True)
+    ("paged", "*:>=2,0", 65, "sharegpt", None, True, True, 0)
 ) # metrics and run all programs that have >=2 batch size
 dpp_possibilities.append(
-    ("paged", None, 8, "custom", "tokens", False, False)
+    ("paged", None, 8, "custom", "tokens", False, False, 0)
 ) # tokens running with specific custom dataset
+dpp_possibilities.append(
+    ("paged", None, 8, "sharegpt", "tokens", False, False, 128)
+) # metrics and run all programs with chunked prefill


 @pytest.mark.parametrize(
-    "attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs",
+    "attn_type,programs,max_new_tokens,dataset_type,test_type,skip_validation,enforce_homogeneous_prompt_programs,prefill_chunk_size",
     dpp_possibilities,
 )
 def test_dpp_script(
@@ -274,6 +288,7 @@ def test_dpp_script(
     test_type,
     skip_validation,
     enforce_homogeneous_prompt_programs,
+    prefill_chunk_size,
     shared_tmp_path,
     isolated_env,
 ):
@@ -290,6 +305,7 @@ def test_dpp_script(
         test_type,
         skip_validation,
         enforce_homogeneous_prompt_programs,
+        prefill_chunk_size,
         shared_tmp_path,
         isolated_env,
     )
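As a quick illustration of how the new prefill_chunk_size test parameter is consumed, the sketch below approximates the logic added to execute_dpp (it is not the test helper itself, and the function name and return shape are made up for the example): a chunked run exports VLLM_DT_CHUNK_LEN, disables the torch_sendnn compile cache, and appends the --prefill_chunk_size flag, while a non-chunked run keeps the cache enabled.

import os
from pathlib import Path


def chunked_prefill_config(prefill_chunk_size: int, shared_tmp_path: str):
    # approximate mapping from the test parameter to env vars and CLI flags
    env = {
        "VLLM_DT_MAX_BATCH_TKV_LIMIT": "1024",
        "VLLM_DT_MAX_CONTEXT_LEN": "512",
        "VLLM_DT_MAX_BATCH_SIZE": "2",
    }
    flags = []
    cache_dir = os.path.join(shared_tmp_path, "sendnn_cache")
    Path(cache_dir).mkdir(exist_ok=True)

    if prefill_chunk_size > 0:
        # chunked prefill: pass the chunk length through and skip the compile cache
        env["VLLM_DT_CHUNK_LEN"] = f"{prefill_chunk_size}"
        env["TORCH_SENDNN_CACHE_ENABLE"] = "0"
        flags.append(f"--prefill_chunk_size={prefill_chunk_size}")
    else:
        # non-chunked runs keep the torch_sendnn compile cache enabled
        os.environ.setdefault("TORCH_SENDNN_CACHE_DIR", cache_dir)
        env["TORCH_SENDNN_CACHE_ENABLE"] = "1"
    return env, flags


# the new dpp_possibilities entry with prefill_chunk_size=128 would therefore run with
# VLLM_DT_CHUNK_LEN=128 and an extra --prefill_chunk_size=128 argument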
