Skip to content

Commit 78ef849

Browse files
ydwu4ydshieh
andauthored
Avoid aliasing in cond's branches for torch 2.8 (#39488)
Avoid alaising in cond's branches Co-authored-by: Yih-Dar <[email protected]>
1 parent 9e676e6 commit 78ef849

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

src/transformers/generation/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -507,21 +507,22 @@ def _cache_dependant_input_preparation_exporting(
507507
# else:
508508
# if input_ids.shape[1] != cache_position.shape[0]:
509509
# input_ids = input_ids[:, cache_position]
510+
# We need to clone the outputs to avoid aliasing.
510511
def branch_1(inputs_embeds, cache_position):
511-
return inputs_embeds[:, -cache_position.shape[0] :]
512+
return inputs_embeds[:, -cache_position.shape[0] :].clone()
512513

513514
def branch_2(input_ids, cache_position):
514-
return input_ids[:, -cache_position.shape[0] :]
515+
return input_ids[:, -cache_position.shape[0] :].clone()
515516

516517
def branch_3(input_ids, cache_position):
517-
return input_ids[:, cache_position]
518+
return input_ids[:, cache_position].clone()
518519

519520
inputs_embeds, input_ids = torch.cond(
520521
input_ids.shape[1] == 0,
521522
(
522523
lambda input_ids, inputs_embeds, cache_position: (
523524
branch_1(inputs_embeds, cache_position),
524-
input_ids,
525+
input_ids.clone(),
525526
)
526527
),
527528
(
@@ -534,7 +535,7 @@ def branch_3(input_ids, cache_position):
534535
torch.cond(
535536
input_ids.shape[1] != cache_position.shape[0],
536537
branch_3,
537-
(lambda input_ids, cache_position: input_ids),
538+
(lambda input_ids, cache_position: input_ids.clone()),
538539
[input_ids, cache_position],
539540
)
540541
),

0 commit comments

Comments
 (0)