Commit 96d71ae
fix type; support ptq; address comments;
Signed-off-by: h-guo18 <[email protected]>
1 parent 9c791d9 commit 96d71ae

File tree: 3 files changed (+67, -50 lines)

examples/speculative_decoding/eagle_config.json

Lines changed: 1 addition & 1 deletion

@@ -7,5 +7,5 @@
         "rope_type": "llama3"
     },
     "initializer_range": 0.02,
-    "attn_implementation": "flex_attention"
+    "_attn_implementation": "sdpa"
 }
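This one-line change renames the config key to the private spelling that Hugging Face's PretrainedConfig actually exposes and switches the draft module's attention backend from flex_attention to sdpa. A minimal sketch of how such a JSON override file could be applied to a config, assuming it is consumed as attribute overrides (the exact wiring in the example's main.py may differ; the base model name below is hypothetical):

# Minimal sketch (assumptions noted above): apply eagle_config.json entries as
# attribute overrides on an HF config; "_attn_implementation" is the attribute
# name PretrainedConfig uses for the attention backend, and "sdpa" selects
# torch.nn.functional.scaled_dot_product_attention.
import json

from transformers import AutoConfig

with open("examples/speculative_decoding/eagle_config.json") as f:
    overrides = json.load(f)

cfg = AutoConfig.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # hypothetical base model
for key, value in overrides.items():
    setattr(cfg, key, value)  # e.g. cfg._attn_implementation = "sdpa"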

examples/speculative_decoding/launch_train.sh

Lines changed: 7 additions & 1 deletion

@@ -90,6 +90,10 @@ while [ $# -gt 0 ]; do
             if [[ "$1" != *=* ]]; then shift; fi
             VLM_IMG_DIR="${1#*=}"
             ;;
+        --ar_validate_steps*)
+            if [[ "$1" != *=* ]]; then shift; fi
+            AR_VALIDATE_STEPS="${1#*=}"
+            ;;
         *)
             >&2 printf "Error: Invalid argument ${1#*=}\n"
             exit 1
@@ -125,6 +129,7 @@ OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
 DISABLE_TQDM=${DISABLE_TQDM:-False}
 VLM_PROCESSOR=${VLM_PROCESSOR:-}
 VLM_IMG_DIR=${VLM_IMG_DIR:-}
+AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000}

 if [[ "$MODE" == "medusa" ]]; then
     SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -187,9 +192,10 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
     --tf32 True \
     --data_path $DATA \
     --disable_tqdm $DISABLE_TQDM \
+    --ar_validate_steps $AR_VALIDATE_STEPS \
     $VLM_ARGS \
     $OFFLINE_TRAINING_ARGS \
-    $SPECULATIVE_ARGS
+    $SPECULATIVE_ARGS \
 "

 start_time=$(date +%s)
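The new --ar_validate_steps flag (default 1000) is parsed like the other launcher options and forwarded to main.py. A hypothetical sketch of the Python side that would receive it; the field name mirrors the flag, everything else is assumed rather than taken from main.py:

# Hypothetical sketch of consuming --ar_validate_steps in the training code.
from dataclasses import dataclass, field

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    ar_validate_steps: int = field(
        default=1000,
        metadata={"help": "Run acceptance-rate (AR) validation every N steps; 0 disables it."},
    )


class ARValidateCallback(transformers.TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if args.ar_validate_steps > 0 and state.global_step % args.ar_validate_steps == 0:
            # Placeholder: evaluate draft-token acceptance rate on a held-out prompt set.
            print(f"[step {state.global_step}] running AR validation")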

modelopt/torch/speculative/plugins/transformers.py

Lines changed: 59 additions & 48 deletions
@@ -404,17 +404,31 @@ def _collect_aux_hidden_states_forward_hook(self, module, input, output) -> None
         )
         self._aux_hidden_states.append(hidden_states)

-    def pop_aux_hidden_states(self):
-        """Return aux hidden states from base model, and clear the list."""
+    def pop_and_gather_aux_hiddens(self):
+        """Pop auxiliary hidden states from base model and gather them on the draft model device."""
         # In PTQ, forward method will be called with try and except to find max batch size.
         # This leads to uncleared aux hidden states in the front of the list.
         # To fix it, we only return the last num_aux_h items in the list.
         num_aux_h = len(self.eagle_config.eagle_aux_hidden_state_layer_ids)
         aux_h_list = self._aux_hidden_states[-num_aux_h:]
         self._aux_hidden_states.clear()

+        # Gather aux hidden states on the draft model device
+        aux_h_list = [h.to(self.eagle_module.fc.weight.device) for h in aux_h_list]
+
         return aux_h_list

+    def _get_eagle_device(self):
+        """Return the device where we should place eagle module."""
+        if self.eagle_offline:
+            # For offline training, the base model has no layers.
+            # Read the device from the base model lm_head instead.
+            return self._base_model_lm_head.weight.device
+        else:
+            # When there is a base model, put eagle on the last layer's device.
+            base_model_last_layer = self._base_model.layers[-1]
+            return next(base_model_last_layer.parameters()).device
+
     def modify(
         self,
         eagle_offline,
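The docstring and comments above describe the PTQ workaround: an aborted probe forward can leave stale tensors in the hook buffer, so only the last num_aux_h entries are trusted, and they are now moved to the draft module's device before being returned. A toy illustration of the slicing part (shapes and counts are made up):

# Toy illustration of keeping only the most recent forward's aux hidden states.
import torch

aux_hidden_states = []                              # filled by the forward hook, one tensor per aux layer
num_aux_h = 3                                       # e.g. len(eagle_aux_hidden_state_layer_ids)

aux_hidden_states += [torch.zeros(1, 8, 16)] * 2    # stale leftovers from an aborted PTQ probe forward
aux_hidden_states += [torch.ones(1, 8, 16)] * 3     # the successful forward's hidden states

fresh = aux_hidden_states[-num_aux_h:]              # keep only the last num_aux_h entries
aux_hidden_states.clear()
assert all(h.eq(1).all() for h in fresh)            # only the fresh tensors survive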
@@ -469,7 +483,7 @@ def modify(

         # find base model, lm head, and embeddings paths
         self._find_base_model_parts()
-        self.eagle_module.to(self._base_model.dtype).to(self._base_model_lm_head.weight.device)
+        self.eagle_module.to(self._base_model.dtype).to(self._get_eagle_device())

         # Make sure word embedding and lm head are frozen
         for param in self._base_model_embeddings.parameters():
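This hunk only changes where the second .to() call sends the module: the device now comes from _get_eagle_device() above. For reference, nn.Module.to() returns the module, so the dtype and device casts chain; a small stand-alone sketch with a stand-in device:

# Stand-alone sketch of the .to(dtype).to(device) chaining; the device below is a
# stand-in for whatever _get_eagle_device() would return.
import torch

eagle_module = torch.nn.Linear(16, 16)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
eagle_module = eagle_module.to(torch.bfloat16).to(device)  # nn.Module.to() returns self, so calls chain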
@@ -777,52 +791,52 @@ def forward(
         # ====Run eagle forward====
         eagle_loss = None
         train_accs = []
-        if self.training:
-            # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers
-            b, seq_length, h = base_model_hidden_states.shape
-            if self.eagle_config.use_aux_hidden_state:
-                if "base_model_outputs" in kwargs:
-                    aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"]
-                else:
-                    aux_hidden_states = torch.cat(self.pop_aux_hidden_states(), dim=-1)
-                eagle_input_hidden_states = self.eagle_module.fc(aux_hidden_states)
+        # In EAGLE-3, we have an additional FC layer to concentrate hidden states from multiple base model layers
+        b, seq_length, h = base_model_hidden_states.shape
+        if self.eagle_config.use_aux_hidden_state:
+            if "base_model_outputs" in kwargs:
+                aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"]
             else:
-                eagle_input_hidden_states = base_model_hidden_states
+                aux_hidden_states = torch.cat(self.pop_and_gather_aux_hiddens(), dim=-1)
+            eagle_input_hidden_states = self.eagle_module.fc(aux_hidden_states)
+        else:
+            eagle_input_hidden_states = base_model_hidden_states

-            # Get eagle inputs for the first eagle forward pass
-            eagle_input_ids, attention_mask_0, position_ids = self._get_eagle_module_inputs(
-                input_ids,
-                eagle_input_hidden_states,
-                attention_mask,
-                position_ids,
-                eagle_cache,
-            )
-            with torch.no_grad():
-                inputs_embeds = self._base_model_embeddings(eagle_input_ids)
-            position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)
+        # Get eagle inputs for the first eagle forward pass
+        eagle_input_ids, attention_mask_0, position_ids = self._get_eagle_module_inputs(
+            input_ids,
+            eagle_input_hidden_states,
+            attention_mask,
+            position_ids,
+            eagle_cache,
+        )
+        with torch.no_grad():
+            inputs_embeds = self._base_model_embeddings(eagle_input_ids)
+        position_embeddings = self.eagle_rotary_emb(eagle_input_hidden_states, position_ids)

-            # Then, we run eagle forward
-            _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
-                eagle_input_hidden_states,
-                inputs_embeds,
-                attention_mask_0,
-                position_ids,
-                position_embeddings,
-                eagle_cache,
-            )
+        # Then, we run eagle forward
+        _, eagle_prenorm_h, eagle_logits, eagle_cache = self._eagle_forward(
+            eagle_input_hidden_states,
+            inputs_embeds,
+            attention_mask_0,
+            position_ids,
+            position_embeddings,
+            eagle_cache,
+        )

-            past_key_values.eagle_cache = eagle_cache
+        past_key_values.eagle_cache = eagle_cache

-            # Compute loss on the eagle modules
-            classification_loss, acc = self._eagle_loss(
-                base_model_logits[:, 1:],
-                eagle_logits[:, :-1],
-                loss_mask[:, 1:],
-            )
-            eagle_loss = classification_loss
-            train_accs.append(acc)
+        # Compute loss on the eagle modules
+        classification_loss, acc = self._eagle_loss(
+            base_model_logits[:, 1:],
+            eagle_logits[:, :-1],
+            loss_mask[:, 1:],
+        )
+        eagle_loss = classification_loss
+        train_accs.append(acc)

-            # ====Perform training-time-testing with 3 extra eagle forward passes====
+        # ====Perform training-time-testing with 3 extra eagle forward passes====
+        if self.training:
             for ttt_step in range(self.num_ttt_steps):
                 eagle_input_hidden_states = torch.cat(
                     (
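The net effect of this hunk is that the eagle forward pass and its loss now run even in eval mode, which is what PTQ calibration uses, while only the training-time-testing (TTT) passes stay gated on self.training. A toy mock of that control flow (not the modelopt API):

# Toy mock: in eval() the draft module still runs once; TTT passes are train-only.
import torch


class ToySpecModel(torch.nn.Module):
    def __init__(self, num_ttt_steps: int = 3):
        super().__init__()
        self.draft = torch.nn.Linear(4, 4)
        self.num_ttt_steps = num_ttt_steps
        self.draft_calls = 0

    def forward(self, hidden):
        out = self.draft(hidden)            # always runs (previously training-only)
        self.draft_calls += 1
        if self.training:                   # training-time-testing stays train-only
            for _ in range(self.num_ttt_steps):
                out = self.draft(out)
                self.draft_calls += 1
        return out


model = ToySpecModel().eval()
with torch.no_grad():
    model(torch.randn(2, 4))
assert model.draft_calls == 1               # calibration-style path: one draft pass, no TTT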
@@ -931,7 +945,7 @@ def pseudo_speculative_generate(
         # Early return
         if steps < 1:
            if hasattr(self, "_aux_hidden_states"):
-                _ = self.pop_aux_hidden_states()
+                _ = self.pop_and_gather_aux_hiddens()
             return base_token, None

         eagle_ids = torch.cat((input_ids[:, 1:], base_token), dim=-1)
@@ -940,10 +954,7 @@
             # EAGLE-3
             # Only the first iteration input_hidden_states are from aux_hidden_state layers
             # Gather _aux_hidden_states from all devices before concatenation
-            gathered_aux_hidden_states = self.pop_aux_hidden_states()
-            gathered_aux_hidden_states = [
-                h.to(input_ids.device) for h in gathered_aux_hidden_states
-            ]
+            gathered_aux_hidden_states = self.pop_and_gather_aux_hiddens()
             eagle_input_hidden_states = self.eagle_module.fc(
                 torch.cat(gathered_aux_hidden_states, dim=-1)
             )
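Both call sites in pseudo_speculative_generate now rely on pop_and_gather_aux_hiddens() to put the hooked tensors on a single device before concatenation, since torch.cat fails on mixed devices when the base model is sharded across GPUs. A minimal sketch of the pattern (devices are illustrative and fall back to CPU):

# Minimal sketch of gather-then-concat; torch.cat requires all inputs on one device.
import torch

target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # stand-in for the draft module's device
aux = [torch.randn(1, 8, 16), torch.randn(1, 8, 16)]                     # hook outputs, possibly on other devices

aux = [h.to(target) for h in aux]            # what pop_and_gather_aux_hiddens() does before returning
fused = torch.cat(aux, dim=-1)               # safe: everything now on the same device
assert fused.shape[-1] == 16 * len(aux)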
