@@ -507,21 +507,22 @@ def _cache_dependant_input_preparation_exporting(
507
507
# else:
508
508
# if input_ids.shape[1] != cache_position.shape[0]:
509
509
# input_ids = input_ids[:, cache_position]
510
+ # We need to clone the outputs to avoid aliasing.
510
511
def branch_1 (inputs_embeds , cache_position ):
511
- return inputs_embeds [:, - cache_position .shape [0 ] :]
512
+ return inputs_embeds [:, - cache_position .shape [0 ] :]. clone ()
512
513
513
514
def branch_2 (input_ids , cache_position ):
514
- return input_ids [:, - cache_position .shape [0 ] :]
515
+ return input_ids [:, - cache_position .shape [0 ] :]. clone ()
515
516
516
517
def branch_3 (input_ids , cache_position ):
517
- return input_ids [:, cache_position ]
518
+ return input_ids [:, cache_position ]. clone ()
518
519
519
520
inputs_embeds , input_ids = torch .cond (
520
521
input_ids .shape [1 ] == 0 ,
521
522
(
522
523
lambda input_ids , inputs_embeds , cache_position : (
523
524
branch_1 (inputs_embeds , cache_position ),
524
- input_ids ,
525
+ input_ids . clone () ,
525
526
)
526
527
),
527
528
(
@@ -534,7 +535,7 @@ def branch_3(input_ids, cache_position):
534
535
torch .cond (
535
536
input_ids .shape [1 ] != cache_position .shape [0 ],
536
537
branch_3 ,
537
- (lambda input_ids , cache_position : input_ids ),
538
+ (lambda input_ids , cache_position : input_ids . clone () ),
538
539
[input_ids , cache_position ],
539
540
)
540
541
),
0 commit comments