File tree Expand file tree Collapse file tree 1 file changed +12
-8
lines changed
modelopt/torch/speculative/plugins Expand file tree Collapse file tree 1 file changed +12
-8
lines changed Original file line number Diff line number Diff line change @@ -790,7 +790,9 @@ def _get_eagle_module_inputs(
790
790
eagle_inputs ["position_ids" ] = torch .empty (
791
791
0 , dtype = position_ids .dtype , device = position_ids .device
792
792
)
793
- eagle_inputs ["rotary_pos_emb" ] = rotary_pos_emb
793
+ eagle_inputs ["rotary_pos_emb" ] = torch .empty (
794
+ 0 , dtype = rotary_pos_emb .dtype , device = rotary_pos_emb .device
795
+ )
794
796
if self .config .sequence_parallel :
795
797
gathered_hidden_states = gather_from_sequence_parallel_region (hidden_states )
796
798
gathered_features = (
@@ -831,13 +833,15 @@ def _get_eagle_module_inputs(
831
833
eagle_inputs ["hidden_states" ],
832
834
gathered_hidden_states
833
835
if step == 0
834
- else (
835
- torch .zeros (
836
- (1 , b , h ),
837
- dtype = hidden_states .dtype ,
838
- device = hidden_states .device ,
839
- ),
840
- feature [:- 1 , :, :],
836
+ else torch .cat (
837
+ (
838
+ torch .zeros (
839
+ (1 , b , h ),
840
+ dtype = hidden_states .dtype ,
841
+ device = hidden_states .device ,
842
+ ),
843
+ feature [:- 1 , :, :],
844
+ )
841
845
),
842
846
),
843
847
dim = 0 ,
You can’t perform that action at this time.
0 commit comments