@@ -1491,12 +1491,11 @@ def pseudo_speculative_generate(
 
         draft_tokens = []
         for _ in range(steps):
-            if self.eagle_config.parallel_draft_step > 1:
-                for i in range(self.eagle_config.parallel_draft_step - 1):
-                    eagle_ids = torch.cat(
-                        (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1
-                    )
-                    hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)
+            for i in range(self.eagle_config.parallel_draft_step - 1):
+                eagle_ids = torch.cat(
+                    (eagle_ids, getattr(self, f"mask_token_{i}").view((1, 1))), dim=-1
+                )
+                hidden_states = torch.cat((hidden_states, hidden_states[-1:]), dim=0)
             padded_eagle_ids, seq_len, padded_hidden_states = right_padding(
                 eagle_ids, hidden_states
             )
@@ -1530,31 +1529,25 @@ def pseudo_speculative_generate(
             )
             eagle_next_hidden_states_input = eagle_next_hidden_states_input[:seq_len, :, :]
 
-            if self.eagle_config.parallel_draft_step > 1:
-                draft_token = (
-                    gather_from_tensor_model_parallel_region(eagle_logits)[
-                        -self.eagle_config.parallel_draft_step :, :, :
-                    ]
-                    .argmax(dim=-1)
-                    .transpose(0, 1)
-                )
-            else:
-                draft_token = (
-                    gather_from_tensor_model_parallel_region(eagle_logits)[-1:, :, :]
-                    .argmax(dim=-1)
-                    .transpose(0, 1)
-                )
+            draft_token = (
+                gather_from_tensor_model_parallel_region(eagle_logits)[
+                    -self.eagle_config.parallel_draft_step :, :, :
+                ]
+                .argmax(dim=-1)
+                .transpose(0, 1)
+            )
             if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size:
                 draft_token += self.eagle_module.d2t[draft_token]
 
-            if self.eagle_config.parallel_draft_step > 1:
-                return base_token, draft_token
-
             draft_tokens.append(draft_token)
 
             eagle_ids = torch.cat((eagle_ids, draft_token), dim=-1)
             hidden_states = torch.cat(
-                (hidden_states, eagle_next_hidden_states_input[-1:, :, :]), dim=0
+                (
+                    hidden_states,
+                    eagle_next_hidden_states_input[-self.eagle_config.parallel_draft_step :, :, :],
+                ),
+                dim=0,
             )
 
         draft_tokens = torch.cat(draft_tokens, dim=-1)
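
The branch unification above can be read as relying on two identities that hold when parallel_draft_step == 1 (this reading is an interpretation of the change, not text from the commit): the mask-token loop range(parallel_draft_step - 1) is empty, and the slice [-parallel_draft_step:] selects the same trailing position as the old [-1:], so the single-draft and multi-draft paths can share one code path. A minimal standalone sketch of the slicing equivalence, using made-up tensor shapes in place of the real eagle_logits:

import torch

# Hypothetical stand-ins: [seq_len, batch, vocab] logits and a draft-step count.
eagle_logits = torch.randn(7, 1, 32)
parallel_draft_step = 1

# Old single-draft branch: take only the last position.
old_token = eagle_logits[-1:, :, :].argmax(dim=-1).transpose(0, 1)
# Unified branch: take the last `parallel_draft_step` positions.
new_token = eagle_logits[-parallel_draft_step:, :, :].argmax(dim=-1).transpose(0, 1)

assert torch.equal(old_token, new_token)  # identical when parallel_draft_step == 1
assert list(range(parallel_draft_step - 1)) == []  # mask-token loop body never runs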