
Commit e6334c4

address review comments

Signed-off-by: h-guo18 <[email protected]>

1 parent: c79e8e2

File tree

5 files changed: +17 −194 lines

- examples/speculative_decoding/README.md
- examples/speculative_decoding/main.py
- examples/speculative_decoding/server_generate.py
- examples/speculative_decoding/train_eagle3_and_export.sh
- examples/speculative_decoding/vllm_inference_demo.py (deleted)

examples/speculative_decoding/README.md

Lines changed: 4 additions & 10 deletions

````diff
@@ -184,16 +184,6 @@ This will export the model from a modelopt checkpoint to a deployment-compatible
 
 The exported checkpoint can be deployed on TRT-LLM or vLLM.
 
-#### vLLM
-
-To test AR on MT-bench with vLLM:
-
-```python
-python vllm_inference_demo.py --base_model $BASE_MODEL --eagle_model $EXPORT_PATH --mode mt-bench
-```
-
-Please refer to [vLLM Doc: Speculative Decoding](https://docs.vllm.ai/en/v0.9.0/features/spec_decode.html) for detailed usage.
-
 #### TRT-LLM
 
 To serve the checkpoint with trtllm, we can run trtllm-serve with:
@@ -223,6 +213,10 @@ kv_cache_config:
 
 Please refer to [TRT-LLM Doc: Speculative Decoding](https://nvidia.github.io/TensorRT-LLM/examples/llm_speculative_decoding.html) for detailed usage.
 
+#### vLLM
+
+Please refer to [vLLM Doc: Speculative Decoding](https://docs.vllm.ai/en/v0.9.0/features/spec_decode.html) for detailed usage.
+
 #### Deploying Quantized model
 
 See more details on deployment of quantized model to TRTLLM [here](../llm_ptq/README.md).
````
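Since the README now defers vLLM usage entirely to the linked docs, here is a minimal offline-inference sketch of what deploying the exported checkpoint there can look like, following the `speculative_config` API in the vLLM v0.9 docs linked above. The base model name, export path, and draft-token count are placeholder assumptions, not values from this repo:

```python
# Hypothetical sketch (not part of this commit): running the exported EAGLE
# draft checkpoint with vLLM's offline API, per the v0.9 speculative decoding
# docs linked above. Model name, path, and token count are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",   # placeholder base model
    speculative_config={
        "method": "eagle3",                     # assumed; matches the EAGLE-3 export
        "model": "export/my-eagle-checkpoint",  # placeholder for $EXPORT_PATH
        "num_speculative_tokens": 3,            # placeholder draft length
    },
)
outputs = llm.generate(
    ["Write a short story about a cat."],
    SamplingParams(max_tokens=256),
)
print(outputs[0].outputs[0].text)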

examples/speculative_decoding/main.py

Lines changed: 9 additions & 12 deletions

```diff
@@ -46,7 +46,13 @@
 import modelopt.torch.opt as mto
 import modelopt.torch.speculative as mtsp
 from modelopt.torch.utils import print_rank_0
-from modelopt.torch.utils.distributed import is_master
+
+try:
+    import wandb
+
+    wandb.init()
+except ImportError:
+    wandb = None
 
 torch.manual_seed(0)
 mto.enable_huggingface_checkpointing()
@@ -205,15 +211,6 @@ def train():
     class ARValidationCallback(TrainerCallback):
         def __init__(self, ar_validate_steps: int = 500):
             self.ar_validate_steps = ar_validate_steps
-            self.wandb = None
-            if is_master():
-                try:
-                    import wandb
-
-                    self.wandb = wandb
-                    self.wandb.init()
-                except ImportError:
-                    pass
 
         def on_step_end(self, args, state, control, **kwargs):
             if state.global_step % self.ar_validate_steps == 0 and state.global_step > 0:
@@ -225,8 +222,8 @@ def on_step_end(self, args, state, control, **kwargs):
                     device=kwargs["model"].device,
                 )
                 print_rank_0(f"Step {state.global_step} AR: {sum(ars) / len(ars):.4f}")
-                if self.wandb:
-                    self.wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
+                if wandb:
+                    wandb.log({"validate_ar": sum(ars) / len(ars)}, step=state.global_step)
             return control
 
     trainer = Trainer(
```
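The main.py change swaps the callback's per-instance wandb state for a module-level optional import, so every call site reduces to a truthiness check on the `wandb` name. A standalone sketch of that pattern; the `log_metric` helper is an illustrative addition, not code from the repo:

```python
# Optional-dependency pattern from the diff above: bind the module (or None)
# once at import time, then gate call sites on truthiness. The log_metric
# helper is a hypothetical illustration, not part of the commit.
try:
    import wandb

    wandb.init()  # starts a run as soon as the module is imported
except ImportError:
    wandb = None  # wandb not installed; logging becomes a no-op


def log_metric(name: str, value: float, step: int) -> None:
    """Log to Weights & Biases when it is available; otherwise do nothing."""
    if wandb:
        wandb.log({name: value}, step=step)
```

One trade-off visible in the diff: `wandb.init()` now runs in every process that imports main.py, whereas the old callback guarded initialization behind `is_master()`.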

examples/speculative_decoding/server_generate.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -46,7 +46,7 @@
 parser.add_argument(
     "--max_tokens", type=int, default=2048, help="Maximum number of tokens to generate"
 )
-parser.add_argument("--chat", action="store_true", default=True, help="Use chat mode")
+parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
 parser.add_argument("--model", type=str, default="model", help="Model name")
 parser.add_argument("--url", type=str, default="http://localhost:8000/v1", help="URL of the API")
 parser.add_argument("--api_key", type=str, default="token-abc123", help="API key (if any)")
```

examples/speculative_decoding/train_eagle3_and_export.sh

Lines changed: 3 additions & 6 deletions

```diff
@@ -55,7 +55,7 @@ fi
 
 MODEL_BASENAME=$(basename "$BASE_MODEL")
 
-echo "==== [1/4] Training draft model ===="
+echo "==== [1/3] Training draft model ===="
 OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
@@ -64,12 +64,9 @@ OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
     --num_epochs 2 \
     --eagle_config eagle_config.json
 
-echo "==== [2/4] Evaluating ModelOpt checkpoint on MT-Bench ===="
+echo "==== [2/3] Evaluating ModelOpt checkpoint on MT-Bench ===="
 python ar_validate.py --model_path $OUTPUT_DIR
 
-echo "==== [3/4] Exporting checkpoint to deployment format ===="
+echo "==== [3/3] Exporting checkpoint to deployment format ===="
 EXPORT_PATH=export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
 python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
-
-echo "==== [4/4] Text Generation with speculative decoding in vLLM===="
-python vllm_inference_demo.py --base-model $BASE_MODEL --eagle-model $EXPORT_PATH --mode generate --prompt "Write a short story about a cat."
```

examples/speculative_decoding/vllm_inference_demo.py

Lines changed: 0 additions & 165 deletions
This file was deleted.
