Bugs in the speculative decoding example #211

@Framartin

Describe the bug

I encountered three issues while preparing the data for the speculative decoding example with examples/speculative_decoding/server_generate.py:

  1. the prompt argument passed to client.chat.completions.create() on line 125 does not seem to be compatible with recent versions of the openai package (1.86.0)
    • The error is the following: "Missing required arguments; Expected either ('messages' and 'model') or ('messages', 'model' and 'stream') arguments to be given"
    • The openai package version is not pinned in examples/speculative_decoding/requirements.txt.
    • I think it would be best to replace the prompt argument with the messages argument.
  2. the --system_prompt command line argument of the server_generate.py script seems to be ignored
    • system_prompt is only used with the --chat flag, which is disabled by default and is not used in the example.
    • The command provided in the README (python3 server_generate.py --data_path Daring-Anteater/train.jsonl --output_path finetune/data.jsonl --max_token 512 --system_prompt <system_prompt_text>) does not match its description: "To add a system prompt, use the --system_prompt argument"
    • No error or warning is raised; the system prompt is silently ignored.
  3. the --chat flag is not passed to server_generate.py, but the training script seems to expect the chat data format
    • The command provided in the README (python3 server_generate.py --data_path Daring-Anteater/train.jsonl --output_path finetune/data.jsonl --max_token 512 --system_prompt <system_prompt_text>) does not include --chat
    • So the saved data have the format {"text": prompt + response}
    • But the training script examples/speculative_decoding/main.py seems to expect data in the format {"conversation_id": idx, "conversations": output_messages}, according to the preprocess() function in both examples/speculative_decoding/eagle_utils.py and examples/speculative_decoding/medusa_utils.py
    • Could you confirm that server_generate.py should be run with --chat?
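
A minimal sketch of the change suggested in points 1 and 2, assuming the script only needs the generated text back: build a messages list (prepending the system prompt when one is given) and pass it to client.chat.completions.create() instead of prompt. The helpers build_messages and generate_response are hypothetical names, not functions from server_generate.py; the client is passed in so the sketch does not depend on a running server.

```python
from typing import Any, Dict, List, Optional


def build_messages(user_prompt: str, system_prompt: Optional[str] = None) -> List[Dict[str, str]]:
    """Build an OpenAI-style chat message list (hypothetical helper)."""
    messages: List[Dict[str, str]] = []
    # Prepend the system prompt only when one was supplied (e.g. via --system_prompt)
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt})
    return messages


def generate_response(
    client: Any,
    model: str,
    user_prompt: str,
    system_prompt: Optional[str] = None,
    max_tokens: int = 512,
) -> str:
    """Query a chat endpoint with `messages` instead of the unsupported `prompt` argument."""
    response = client.chat.completions.create(
        model=model,
        messages=build_messages(user_prompt, system_prompt),
        max_tokens=max_tokens,
    )
    # With openai>=1.x the generated text is in choices[0].message.content
    return response.choices[0].message.content
```

Against the vLLM server from the reproduction steps, the client could be created with OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123") and model set to "TinyLlama/TinyLlama-1.1B-Chat-v1.0". Building the system message in one place means --system_prompt would take effect regardless of whether --chat is set.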

Steps/Code to reproduce bug

I followed the instructions in the README to prepare the data for the speculative decoding example:

# serve the model
vllm serve TinyLlama/TinyLlama-1.1B-Chat-v1.0 --api-key token-abc123 --port 8000  --tensor-parallel-size 1

# try to generate the data
python3 server_generate.py --data_path ./Daring-Anteater/train.jsonl --output_path finetune/data.jsonl --max_token 512 --system_prompt "You are a helpful assistant."

Expected behavior

  1. No exception is raised when server_generate.py calls the OpenAI chat completions API.
  2. I expect the system prompt to be added to the prompt fed to the model in examples/speculative_decoding/server_generate.py line 123.
  3. No KeyError: 'conversations' is raised when running examples/speculative_decoding/launch.sh.
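
Until the README and the script agree, a quick check like the one below can catch the format mismatch from point 3 before launch.sh fails with KeyError: 'conversations'. It is a sketch that only tests for the conversations key (as a list) named in the preprocess() functions cited above; it does not validate the message layout inside that list, and find_bad_records is a hypothetical helper, not part of ModelOpt.

```python
import json
from typing import Iterable, List, Tuple


def find_bad_records(lines: Iterable[str]) -> List[Tuple[int, str]]:
    """Return (line_number, reason) for JSONL records lacking a usable 'conversations' list."""
    problems: List[Tuple[int, str]] = []
    for lineno, line in enumerate(lines, start=1):
        if not line.strip():
            continue  # ignore blank lines
        record = json.loads(line)
        if "conversations" not in record:
            # e.g. {"text": prompt + response}, written when --chat is not passed
            problems.append((lineno, "missing 'conversations' key"))
        elif not isinstance(record["conversations"], list):
            problems.append((lineno, "'conversations' is not a list"))
    return problems
```

For example, opening finetune/data.jsonl and passing the file object to find_bad_records lists every line that would trigger the KeyError; an empty list means the file at least has the expected top-level key.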

System information

  • Container used (if applicable): nvcr.io/nvidia/pytorch:24.12-py3
  • OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): Ubuntu 24.04.1 LTS
  • CPU architecture (x86_64, aarch64): x86_64
  • GPU name (e.g. H100, A100, L40S): Tesla V100-PCIE-32GB
  • GPU memory size: 32.0 GB
  • Number of GPUs: 1
  • Library versions (if applicable):
    • Python: 3.12.3
    • ModelOpt version or commit hash: 7af33d2 (example script)
    • CUDA: 12.6
    • PyTorch: 2.7.0+cu126
    • Transformers: 4.52.4
    • TensorRT-LLM: ?
    • ONNXRuntime: ?
    • TensorRT: ?
Python script to automatically collect system information:
import platform
import re
import subprocess


def get_nvidia_gpu_info():
    try:
        nvidia_smi = (
            subprocess.check_output(
                "nvidia-smi --query-gpu=name,memory.total,count --format=csv,noheader,nounits",
                shell=True,
            )
            .decode("utf-8")
            .strip()
            .split("\n")
        )
        if len(nvidia_smi) > 0:
            gpu_name = nvidia_smi[0].split(",")[0].strip()
            gpu_memory = round(float(nvidia_smi[0].split(",")[1].strip()) / 1024, 1)
            gpu_count = len(nvidia_smi)
            return gpu_name, f"{gpu_memory} GB", gpu_count
    except Exception:
        pass
    # Fall back to placeholders if nvidia-smi is missing or its output cannot be parsed
    return "?", "?", "?"


def get_cuda_version():
    try:
        nvcc_output = subprocess.check_output("nvcc --version", shell=True).decode("utf-8")
        match = re.search(r"release (\d+\.\d+)", nvcc_output)
        if match:
            return match.group(1)
    except Exception:
        pass
    # Fall back to "?" so the string concatenation below never receives None
    return "?"


def get_package_version(package):
    try:
        return getattr(__import__(package), "__version__", "?")
    except Exception:
        return "?"


# Get system info
os_info = f"{platform.system()} {platform.release()}"
if platform.system() == "Linux":
    try:
        os_info = (
            subprocess.check_output("cat /etc/os-release | grep PRETTY_NAME | cut -d= -f2", shell=True)
            .decode("utf-8")
            .strip()
            .strip('"')
        )
    except Exception:
        pass
elif platform.system() == "Windows":
    print("Please add the `windows` label to the issue.")

cpu_arch = platform.machine()
gpu_name, gpu_memory, gpu_count = get_nvidia_gpu_info()
cuda_version = get_cuda_version()

# Print system information in the format required for the issue template
print("=" * 70)
print("- Container used (if applicable): " + "?")
print("- OS (e.g., Ubuntu 22.04, CentOS 7, Windows 10): " + os_info)
print("- CPU architecture (x86_64, aarch64): " + cpu_arch)
print("- GPU name (e.g. H100, A100, L40S): " + gpu_name)
print("- GPU memory size: " + gpu_memory)
print("- Number of GPUs: " + str(gpu_count))
print("- Library versions (if applicable):")
print("  - Python: " + platform.python_version())
print("  - ModelOpt version or commit hash: " + get_package_version("modelopt"))
print("  - CUDA: " + cuda_version)
print("  - PyTorch: " + get_package_version("torch"))
print("  - Transformers: " + get_package_version("transformers"))
print("  - TensorRT-LLM: " + get_package_version("tensorrt_llm"))
print("  - ONNXRuntime: " + get_package_version("onnxruntime"))
print("  - TensorRT: " + get_package_version("tensorrt"))
print("- Any other details that may help: " + "?")
print("=" * 70)

Labels: bug (Something isn't working)