@@ -1386,8 +1386,6 @@ def generate(
         batch_size = input_ids.shape[0] // self.num_codebooks
 
         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale
 
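This hunk stops seeding the decoder's model_kwargs with the output flags; they stay on the generation config instead. A minimal sketch of the resulting split, assuming a stock GenerationConfig (values illustrative, not from the diff):

from transformers import GenerationConfig

generation_config = GenerationConfig(use_cache=True, guidance_scale=3.0, output_attentions=False)

# only the kwargs the decoder forward pass still consumes are seeded here;
# output_attentions / output_hidden_states now travel on generation_config itself
model_kwargs = {
    "use_cache": generation_config.use_cache,
    "guidance_scale": generation_config.guidance_scale,
}
print(model_kwargs)  # {'use_cache': True, 'guidance_scale': 3.0}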
@@ -1481,14 +1479,11 @@ def generate(
                 )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
@@ -1506,15 +1501,12 @@ def generate(
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
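Both call sites migrate from the deprecated public greedy_search/sample to the private _greedy_search/_sample, which accept the whole GenerationConfig rather than individually unpacked fields. A hedged sketch of the two conventions (token ids illustrative; model calls shown as comments since they need a loaded model):

from transformers import GenerationConfig

generation_config = GenerationConfig(
    pad_token_id=0,   # illustrative token ids, not from the diff
    eos_token_id=2,
    output_scores=True,
    return_dict_in_generate=True,
)

# old convention (deprecated): every field unpacked at the call site
#   outputs = model.greedy_search(input_ids, pad_token_id=0, eos_token_id=2, ...)
# new convention: the config object is forwarded as a whole
#   outputs = model._greedy_search(input_ids, generation_config=generation_config, ...)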
@@ -2198,8 +2190,8 @@ def _prepare_text_encoder_kwargs_for_generation(
         self,
         inputs_tensor: torch.Tensor,
         model_kwargs,
-        model_input_name: Optional[str] = None,
-        guidance_scale: Optional[float] = None,
+        model_input_name: Optional[str],
+        generation_config: GenerationConfig,
     ) -> Dict[str, Any]:
         # 1. get text encoder
         encoder = self.get_text_encoder()
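A side effect worth noting: once generation_config becomes a required parameter at the end of the signature, model_input_name cannot keep its default, because Python forbids a non-default parameter after a defaulted one. A toy illustration of that constraint (hypothetical function, not from the diff):

# def f(self, a, b=None, generation_config): ...   # SyntaxError: non-default
#                                                  # argument follows default argument

def f(self, a, b, generation_config):  # so both defaults are dropped instead
    ...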
@@ -2221,6 +2213,9 @@ def _prepare_text_encoder_kwargs_for_generation(
         encoder_kwargs = {
             argument: value for argument, value in encoder_kwargs.items() if argument in encoder_signature
         }
+        encoder_kwargs["output_attentions"] = generation_config.output_attentions
+        encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+        guidance_scale = generation_config.guidance_scale
 
         # 3. make sure that encoder returns `ModelOutput`
         model_input_name = model_input_name if model_input_name is not None else self.text_encoder.main_input_name
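The flags removed from model_kwargs in the earlier hunks resurface here, on the text encoder's forward kwargs. A simplified, hypothetical rendering of what the updated helper effectively does (helper name is mine; the real method also runs the encoder and wraps the result in a ModelOutput):

from typing import Any, Dict, Set

def build_encoder_kwargs(
    model_kwargs: Dict[str, Any],
    encoder_signature: Set[str],
    generation_config,  # duck-typed: any object carrying the two output flags
) -> Dict[str, Any]:
    # keep only the kwargs the encoder's forward() actually accepts
    encoder_kwargs = {k: v for k, v in model_kwargs.items() if k in encoder_signature}
    # then force the output flags from the generation config
    encoder_kwargs["output_attentions"] = generation_config.output_attentions
    encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
    return encoder_kwargs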
@@ -2452,8 +2447,6 @@ def generate(
         batch_size = inputs_tensor.shape[0]
 
         # 4. Define other model kwargs
-        model_kwargs["output_attentions"] = generation_config.output_attentions
-        model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
         model_kwargs["use_cache"] = generation_config.use_cache
         model_kwargs["guidance_scale"] = generation_config.guidance_scale
 
@@ -2467,10 +2460,7 @@ def generate(
         if "encoder_outputs" not in model_kwargs:
             # encoder_outputs are created and added to `model_kwargs`
             model_kwargs = self._prepare_text_encoder_kwargs_for_generation(
-                inputs_tensor,
-                model_kwargs,
-                model_input_name,
-                guidance_scale=generation_config.guidance_scale,
+                inputs_tensor, model_kwargs, model_input_name, generation_config,
             )
 
         if "prompt_hidden_states" not in model_kwargs and "prompt_input_ids" in model_kwargs:
@@ -2579,14 +2569,11 @@ def generate(
                 )
 
             # 11. run greedy search
-            outputs = self.greedy_search(
+            outputs = self._greedy_search(
                 input_ids,
                 logits_processor=logits_processor,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
@@ -2605,15 +2592,12 @@ def generate(
             )
 
             # 12. run sample
-            outputs = self.sample(
+            outputs = self._sample(
                 input_ids,
                 logits_processor=logits_processor,
                 logits_warper=logits_warper,
                 stopping_criteria=stopping_criteria,
-                pad_token_id=generation_config.pad_token_id,
-                eos_token_id=generation_config.eos_token_id,
-                output_scores=generation_config.output_scores,
-                return_dict_in_generate=generation_config.return_dict_in_generate,
+                generation_config=generation_config,
                 synced_gpus=synced_gpus,
                 streamer=streamer,
                 **model_kwargs,
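None of this changes the public surface: callers still go through generate(), which builds the GenerationConfig the refactored internals now read from. Judging by prompt_hidden_states, this looks like MusicGen-Melody-style generation code, so a typical call would look roughly as below; the checkpoint name and parameters are illustrative, not taken from the diff:

from transformers import AutoProcessor, MusicgenMelodyForConditionalGeneration

processor = AutoProcessor.from_pretrained("facebook/musicgen-melody")
model = MusicgenMelodyForConditionalGeneration.from_pretrained("facebook/musicgen-melody")

inputs = processor(text=["upbeat acoustic folk"], padding=True, return_tensors="pt")
# guidance_scale and the output_* flags land on the GenerationConfig,
# from which the refactored internals read them
audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3.0, max_new_tokens=256)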