Commit b02a3b4

Update EAGLE offline flags to use EagleConfig

Signed-off-by: Benjamin Chislett <[email protected]>
1 parent: 6c32d54

File tree: 8 files changed, +75 −255 lines

examples/speculative_decoding/collect_hidden_states/send_conversations_for_hiddens.py

Lines changed: 0 additions & 6 deletions
@@ -166,12 +166,6 @@ async def main(args: argparse.Namespace) -> None:
 
     try:
         # Send the message to the OpenAI-compatible endpoint
-        # await client.chat.completions.create(
-        #     model=args.model,
-        #     messages=conversations,
-        #     temperature=0.0,
-        #     max_tokens=1,
-        # )
         await client.completions.create(
             model=args.model,
             prompt=input_string,
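For context, the surviving call uses the raw completions API with a pre-templated prompt string, and this hunk simply deletes the commented-out chat.completions alternative. A minimal sketch of the retained pattern (the endpoint URL, API key, and sampling parameters here are illustrative assumptions, not values from this commit):

import asyncio

from openai import AsyncOpenAI


async def send_prompt(input_string: str, model: str) -> None:
    # Assumed local vLLM endpoint; adjust base_url/api_key for your deployment.
    client = AsyncOpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
    # The prompt is already chat-templated upstream, so the raw completions
    # API is used instead of chat.completions.
    await client.completions.create(
        model=model,
        prompt=input_string,
        temperature=0.0,  # assumed: deterministic generation
        max_tokens=1,     # assumed: one token suffices to trigger hidden-state capture
    )


asyncio.run(send_prompt("Hello", "meta-llama/Llama-3.2-1B-Instruct"))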

examples/speculative_decoding/eagle_utils.py

Lines changed: 26 additions & 14 deletions
@@ -225,14 +225,17 @@ def __getitem__(self, i) -> dict[str, torch.Tensor]:
             raise ValueError(msg)
 
         ret = {**preprocessed_base}  # Shallow copy so we don't accidentally modify the cache
-        ret["hidden_states"] = offline_data["hidden_states"]
-        ret["aux_hidden_states"] = offline_data["aux_hidden_states"]
-
+        ret["kwargs"] = {
+            "base_model_outputs": {
+                "base_model_hidden_states": offline_data["hidden_states"],
+                "aux_hidden_states": offline_data["aux_hidden_states"],
+            }
+        }
         return ret
 
 
 def make_eagle_supervised_data_module(
-    tokenizer: transformers.PreTrainedTokenizer, data_args
+    tokenizer: transformers.PreTrainedTokenizer, data_args, use_offline_training: bool
 ) -> dict:
     """Make dataset and collator for supervised fine-tuning.
@@ -250,11 +253,14 @@ def make_eagle_supervised_data_module(
     else:
         data_json = json.load(f)
 
-    if data_args.offline_training:
+    if use_offline_training:
         print_rank_0("Loading pre-processed data for offline training...")
         dataset_cls = OfflineSupervisedDataset
 
         # Glob for all .pt files in the data_path directory
+        assert data_args.offline_data_path is not None, (
+            "offline_data_path must be provided for offline training."
+        )
         offline_data_path = Path(data_args.offline_data_path)
         all_files = {str(p) for p in offline_data_path.glob("*.pt")}
         if not all_files:
@@ -346,24 +352,30 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
 class DataCollatorForOffline(DataCollatorWithPadding):
     def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
         base_batch = super().__call__(features)
-        if "hidden_states" not in features[0]:
-            print(features[0].keys())
-            print(features[0])
-            print(features)
-            raise ValueError("Features do not contain 'hidden_states' key.")
-        max_hs_length = max(item["hidden_states"].shape[0] for item in features)
+        if "kwargs" not in features[0]:
+            raise ValueError("No kwargs found in batch features. Offline data required.")
+
+        features = [item["kwargs"]["base_model_outputs"] for item in features]
+        max_hs_length = max(item["base_model_hidden_states"].shape[0] for item in features)
 
         batch_hidden_states = torch.stack(
-            [self.paddingtensor2d(item["hidden_states"], max_hs_length) for item in features]
+            [
+                self.paddingtensor2d(item["base_model_hidden_states"], max_hs_length)
+                for item in features
+            ]
         )
         batch_aux_hidden_states = torch.stack(
             [self.paddingtensor2d(item["aux_hidden_states"], max_hs_length) for item in features]
         )
 
         batch = {
             **base_batch,
-            "hidden_states": batch_hidden_states,
-            "aux_hidden_states": batch_aux_hidden_states,
+            "kwargs": {
+                "base_model_outputs": {
+                    "base_model_hidden_states": batch_hidden_states,
+                    "aux_hidden_states": batch_aux_hidden_states,
+                }
+            },
         }
 
         return batch
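The net effect of these hunks is that offline hidden states now travel through a nested "kwargs" entry instead of top-level batch keys. A hedged sketch of what a collated batch looks like and how a consumer might unpack it (the shapes and the unpacking code are illustrative assumptions, not code from this commit):

import torch

# Assumed shapes: batch of 2 sequences padded to length 8, hidden size 16.
batch = {
    "input_ids": torch.zeros(2, 8, dtype=torch.long),
    "kwargs": {
        "base_model_outputs": {
            "base_model_hidden_states": torch.randn(2, 8, 16),
            "aux_hidden_states": torch.randn(2, 8, 3 * 16),  # e.g. 3 concatenated layers
        }
    },
}

# Hypothetical consumer: the draft model receives the pre-computed base-model
# outputs via the nested kwargs rather than reading top-level "hidden_states".
outputs = batch["kwargs"]["base_model_outputs"]
hidden = outputs["base_model_hidden_states"]  # (batch, seq, hidden)
aux = outputs["aux_hidden_states"]            # (batch, seq, n_layers * hidden)
assert hidden.shape[:2] == aux.shape[:2]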

examples/speculative_decoding/gen_synthetic_conversations/run_vllm_server.sh

Lines changed: 1 addition & 22 deletions
@@ -14,25 +14,4 @@
 # limitations under the License.
 
 # Example launch configuration for a vLLM server
-# On 8xB200, Llama 3.3 70B runs comfortably with TP=2 at high batch sizes.
-
-# Achieve data parallelism by running multiple vLLM servers on different GPUs.
-# CUDA_VISIBLE_DEVICES=0,1 vllm serve meta-llama/Llama-3.3-70B-Instruct --tensor-parallel-size 2 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8000 &
-# CUDA_VISIBLE_DEVICES=2,3 vllm serve meta-llama/Llama-3.3-70B-Instruct --tensor-parallel-size 2 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8001 &
-# CUDA_VISIBLE_DEVICES=4,5 vllm serve meta-llama/Llama-3.3-70B-Instruct --tensor-parallel-size 2 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8002 &
-# CUDA_VISIBLE_DEVICES=6,7 vllm serve meta-llama/Llama-3.3-70B-Instruct --tensor-parallel-size 2 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8003 &
-
-# Alternatively, use vLLM's built-in data parallelism.
-# vllm serve meta-llama/Llama-3.3-70B-Instruct --tensor-parallel-size 2 --data-parallel-size 4 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8000
-
-# Default to a small model for testing.
-# vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --data-parallel-size 8 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8000
-
-CUDA_VISIBLE_DEVICES=0 vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-seq-len 8192 --max-num-seqs 1024 --port 8000
-# CUDA_VISIBLE_DEVICES=1 vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8001 &
-# CUDA_VISIBLE_DEVICES=2 vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8002 &
-# CUDA_VISIBLE_DEVICES=3 vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8003 &
-# CUDA_VISIBLE_DEVICES=4 vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8004 &
-# CUDA_VISIBLE_DEVICES=5 vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8005 &
-# CUDA_VISIBLE_DEVICES=6 vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8006 &
-# CUDA_VISIBLE_DEVICES=7 vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8007 &
+vllm serve meta-llama/Llama-3.2-1B-Instruct --tensor-parallel-size 1 --data-parallel-size 8 --max-num-batched-tokens 32768 --max-seq-len 8192 --disable-log-requests --max-num-seqs 1024 --port 8000
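The deleted comments achieved data parallelism by running one server per port (8000-8003) and spreading requests across them client-side; the retained command replaces that with vLLM's built-in --data-parallel-size, so a client only needs the single port-8000 endpoint. For reference, a hedged sketch of the client-side round-robin the multi-server pattern implied (ports follow the removed comments; the helper itself is an illustrative assumption):

import itertools

from openai import AsyncOpenAI

# One client per server port, matching the removed multi-server example.
clients = [
    AsyncOpenAI(base_url=f"http://localhost:{port}/v1", api_key="EMPTY")
    for port in range(8000, 8004)
]
client_cycle = itertools.cycle(clients)


async def send(prompt: str, model: str) -> None:
    # Round-robin requests across the servers to spread load.
    client = next(client_cycle)
    await client.completions.create(model=model, prompt=prompt, max_tokens=1)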

examples/speculative_decoding/launch.sh

Lines changed: 0 additions & 175 deletions
This file was deleted.

examples/speculative_decoding/launch_train.sh

Lines changed: 19 additions & 2 deletions
@@ -30,6 +30,10 @@ while [ $# -gt 0 ]; do
         if [[ "$1" != *=* ]]; then shift; fi
         DATA="${1#*=}"
         ;;
+    --offline-data*)
+        if [[ "$1" != *=* ]]; then shift; fi
+        OFFLINE_DATA_PATH="${1#*=}"
+        ;;
     --mode*)
         if [[ "$1" != *=* ]]; then shift; fi
         MODE="${1#*=}"
@@ -87,7 +91,7 @@ set -x
 # Get the default value for save_steps based on the available number of GPUs
 GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
 # Calculate save_steps
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+DEFAULT_SAVE_STEPS=$((192 / GPU_COUNT))
 
 MODEL=${MODEL:-"TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
 MODE=${MODE:-"eagle3"}
@@ -104,7 +108,8 @@ REDRAFTER_TOKENS=${REDRAFTER_TOKENS:-1}
 REDRAFTER_NUM_LAYERS=${REDRAFTER_NUM_LAYERS:-1}
 FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP=${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-"LlamaDecoderLayer"}
 NUM_GPU=${NUM_GPU:-1}
-TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-512}
+TRAINING_SEQ_LEN=${TRAINING_SEQ_LEN:-2048}
+OFFLINE_DATA_PATH=${OFFLINE_DATA_PATH:-""}
 
 if [[ "$MODE" == "medusa" ]]; then
     SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS"
@@ -119,6 +124,17 @@ else
     exit 1
 fi
 
+if [[ "$OFFLINE_DATA_PATH" != "" ]]; then
+    if [[ ! -d "$OFFLINE_DATA_PATH" ]]; then
+        echo "Offline data path $OFFLINE_DATA_PATH does not exist or is not a directory."
+        exit 1
+    else
+        OFFLINE_TRAINING_ARGS="--offline-data-path $OFFLINE_DATA_PATH"
+    fi
+else
+    OFFLINE_TRAINING_ARGS=""
+fi
+
 if [[ "$NUM_GPU" == 1 ]]; then
     MULTI_GPU=""
 else
@@ -149,6 +165,7 @@ CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
     --logging_steps 100 \
     --tf32 True \
     --data_path $DATA \
+    $OFFLINE_TRAINING_ARGS \
     $SPECULATIVE_ARGS
 "