Merged
58 changes: 58 additions & 0 deletions examples/05_Llama3.1-8B_Example/README.md
# Running Endpoints with [Llama3.1-8B](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)

## Download dataset

The Llama3.1-8B benchmark uses the [cnn/dailymail](https://huggingface.co/datasets/abisee/cnn_dailymail) dataset (for summarization). Download the dataset, wrap each article in the summarization prompt, and save the result using the following command:

```
python download_cnndm.py --save-dir data --split validation
# Processed data will be saved at data/cnn_dailymail_validation.json
```
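The download step wraps each article in a fixed summarization prompt before saving. A minimal sketch of that transformation (the prompt text is copied from `download_cnndm.py`; the article and ID are hypothetical placeholders):

```python
# Prompt template used by download_cnndm.py to wrap each article.
PROMPT = (
    "Summarize the following news article in 128 tokens. "
    "Please output the summary only, without any other text."
    "\n\nArticle:\n{input}\n\nSummary:"
)

article = "LONDON (Example) -- A placeholder article body."  # hypothetical sample
record = {"id": "example-0001", "input": PROMPT.format(input=article)}
print(record["input"].endswith("Summary:"))  # True
```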

- To generate the calibration dataset, use the [cnn-dailymail-calibration-list](https://github.com/mlcommons/inference/blob/master/calibration/CNNDailyMail/calibration-list.txt):

```
curl -OL https://raw.githubusercontent.com/mlcommons/inference/v4.0/calibration/CNNDailyMail/calibration-list.txt
python download_cnndm.py --save-dir data --calibration-ids-file calibration-list.txt --split train
```
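Filtering works as in `download_cnndm.py`: IDs are read one per line from the calibration list, and only examples whose ID appears in that set are kept. A self-contained sketch with hypothetical IDs and records (an in-memory buffer stands in for `calibration-list.txt`):

```python
import io

# Hypothetical calibration-list content: one dataset ID per line.
id_file = io.StringIO("id-1\nid-3\n")
calibration_ids = {line.strip() for line in id_file if line.strip()}

records = [
    {"id": "id-1", "article": "first"},
    {"id": "id-2", "article": "second"},
    {"id": "id-3", "article": "third"},
]
# Keep only examples whose ID is in the calibration set.
kept = [r for r in records if str(r["id"]) in calibration_ids]
print(len(kept))  # 2
```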

## Launch the server

The commands below use the following environment variables to make the scripts easier to run:

```
export HF_TOKEN=<your Hugging Face token>
export HF_HOME=<Path to your hf_home, usually /USERNAME/.cache/huggingface>
export MODEL_NAME=<model to run, for instance meta-llama/Llama-3.1-8B-Instruct>
```

It is convenient to download the model prior to launch so that the container can reuse it instead of downloading it post-launch. This can be done via `hf download $MODEL_NAME`; the downloaded models can be verified via `hf cache scan`.

### [vLLM](https://github.com/vllm-project/vllm)

We can launch the latest vLLM Docker image using the command below:

```
docker run --runtime nvidia --gpus all \
  -v ${HF_HOME}:/root/.cache/huggingface \
  --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
  -p 8000:8000 --ipc=host \
  vllm/vllm-openai:latest --model ${MODEL_NAME}
```
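Once the container is up, the server exposes an OpenAI-compatible API on port 8000. A sketch of the kind of completion request a benchmark client would send (the prompt string is a truncated placeholder; `max_tokens`, `temperature`, and `top_p` mirror the values in the benchmark configs):

```python
import json

# Request body for vLLM's OpenAI-compatible /v1/completions endpoint.
payload = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "prompt": "Summarize the following news article in 128 tokens. ...",
    "max_tokens": 128,      # matches max_new_tokens in the configs
    "temperature": 1.0,
    "top_p": 1.0,
}
# POST this body to http://localhost:8000/v1/completions
# with header "Content-Type: application/json".
body = json.dumps(payload)
```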

### To run Offline mode

**Note:** Double-check the config file for correct parameters.

- Launch the benchmark with the config YAML:

```
inference-endpoint benchmark from-config -c offline_llama3_8b_cnn.yaml --timeout 600
```

### To run Online mode

**Note:** Double-check the config file for correct parameters.

- Launch the benchmark with the config YAML:

```
inference-endpoint benchmark from-config -c online_llama3_8b_cnn.yaml --timeout 600
```
110 changes: 110 additions & 0 deletions examples/05_Llama3.1-8B_Example/download_cnndm.py
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
import warnings
from argparse import ArgumentParser

from datasets import load_dataset
from tqdm import tqdm

PROMPT = "Summarize the following news article in 128 tokens. Please output the summary only, without any other text.\n\nArticle:\n{input}\n\nSummary:"


def download_cnndm(
    save_dir: str, split: str = "validation", calibration_ids_file: str | None = None
) -> None:
    """Download the CNN/DailyMail dataset and save it to the specified directory.

    Args:
        save_dir (str): The directory where the dataset will be saved.
        split (str): The dataset split to download (default: validation).
        calibration_ids_file (str): Path to a file containing calibration IDs (one per line).
            If provided, 'split' must be 'train' and only examples with these IDs will be
            prepared and saved.
    """
    os.makedirs(save_dir, exist_ok=True)
    dataset = load_dataset("cnn_dailymail", "3.0.0", split=split)

    output_file_tag = split
    calibration_ids = set()
    if calibration_ids_file:
        with open(calibration_ids_file, encoding="utf-8") as id_file:
            for line in id_file:
                calibration_ids.add(line.strip())
        output_file_tag = "calibration"
    fname = f"cnn_dailymail_{output_file_tag}.json"
    output_file = os.path.join(save_dir, fname)

    # Add the custom prompt to each example and filter if calibration IDs are provided
    with open(output_file, "w", encoding="utf-8") as f:
        for example in tqdm(dataset, total=len(dataset), desc="Processing examples"):
            if calibration_ids and str(example["id"]) not in calibration_ids:
                continue
            f.write(
                json.dumps(
                    {
                        "id": example["id"],
                        "input": PROMPT.format(input=example["article"]),
                        "highlights": example["highlights"],
                    }
                )
                + "\n"
            )
    print(f"Dataset saved to {output_file}")


if __name__ == "__main__":
    parser = ArgumentParser(description="Download CNN/DailyMail dataset")
    parser.add_argument(
        "--save-dir",
        type=str,
        required=True,
        help="Directory to save the downloaded dataset",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
        help="Dataset split to download (default: validation)",
    )
    parser.add_argument(
        "--calibration-ids-file",
        type=str,
        default=None,
        help="Path to a file containing calibration IDs (one per line)."
        " If provided, 'split' must be 'train' and only examples with these IDs will be saved.",
    )

    args = parser.parse_args()
    if args.calibration_ids_file and args.split != "train":
        warnings.warn(
            "When --calibration-ids-file is provided, --split must be 'train'. Setting split to 'train'.",
            stacklevel=2,
        )
        args.split = "train"

    if args.calibration_ids_file and not os.path.isfile(args.calibration_ids_file):
        raise FileNotFoundError(
            f"Provided calibration IDs file not found: {args.calibration_ids_file}"
        )

    download_cnndm(
        save_dir=args.save_dir,
        split=args.split,
        calibration_ids_file=args.calibration_ids_file,
    )
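The script writes one JSON object per line, so reading the output back is a matter of parsing line by line. A round-trip sketch of that format, with an in-memory buffer standing in for the output file and hypothetical record values:

```python
import io
import json

# Write two hypothetical records in the same one-object-per-line format.
buf = io.StringIO()
for rec in (
    {"id": "a1", "input": "prompt text", "highlights": "reference summary"},
    {"id": "a2", "input": "another prompt", "highlights": "another summary"},
):
    buf.write(json.dumps(rec) + "\n")

# Read the records back, one JSON object per line.
buf.seek(0)
records = [json.loads(line) for line in buf]
print(len(records), sorted(records[0]))  # 2 ['highlights', 'id', 'input']
```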
42 changes: 42 additions & 0 deletions examples/05_Llama3.1-8B_Example/offline_llama3_8b_cnn.yaml
# Offline Throughput Benchmark
name: "offline-llama3.1-8b-cnn-benchmark"
version: "1.0"
type: "offline"

model_params:
  name: "meta-llama/Llama-3.1-8B-Instruct" # Path to the model
  temperature: 1.0
  top_p: 1.0
  max_new_tokens: 128

datasets:
  - name: "perf-test"
    type: "performance"
    path: "data/cnn_dailymail_validation.json" # Path to the dataset. Note: This should be the file generated with download_cnndm.py
    samples: 13368 # Number of samples in the dataset (validation has ~13.3k samples, while training has ~287k samples)
    parser:
      prompt: "input"

settings:
  runtime:
    min_duration_ms: 6000 # 6 seconds
    max_duration_ms: 60000 # 1 minute
    scheduler_random_seed: 137 # For Poisson/distribution sampling
    dataloader_random_seed: 111 # For dataset shuffling

load_pattern:
  type: "max_throughput"

client:
  workers: 1 # Number of client workers

metrics:
  collect:
    - "throughput"
    - "latency"
    - "ttft"
    - "tpot"

endpoint_config:
  endpoint: "http://localhost:8000"
  api_key: null
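`dataloader_random_seed` makes the dataset order reproducible: the same seed yields the same shuffle on every run. A sketch of that property using Python's seeded RNG (the benchmark's actual shuffling code is not shown here, so this illustrates only the general mechanism):

```python
import random

samples = list(range(10))

def shuffled(seed: int) -> list[int]:
    # A fixed seed gives a deterministic permutation of the samples.
    order = samples[:]
    random.Random(seed).shuffle(order)
    return order

print(shuffled(111) == shuffled(111))  # True: same seed, same order
print(shuffled(111) == shuffled(137))  # almost certainly a different order
```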
45 changes: 45 additions & 0 deletions examples/05_Llama3.1-8B_Example/online_llama3_8b_cnn.yaml
# Online Benchmark
name: "online-llama3.1-8b-cnn-benchmark"
version: "1.0"
type: "online"

model_params:
  name: meta-llama/Llama-3.1-8B-Instruct # Path to the model
  temperature: 1.0
  top_p: 1.0
  max_new_tokens: 128

datasets:
  - name: "perf-test"
    type: "performance"
    path: "data/cnn_dailymail_validation.json" # Path to the dataset. Note: This should be the file generated with download_cnndm.py
    format: "json"
    samples: 13368
    parser:
      prompt: "input"
    eval_method: "rouge"

settings:
  runtime:
    min_duration_ms: 6000 # 6 seconds
    max_duration_ms: 60000 # 1 minute
    scheduler_random_seed: 137 # For Poisson/distribution sampling
    dataloader_random_seed: 111 # For dataset shuffling

load_pattern:
  type: "poisson"
  target_qps: 3 # This is the target queries per second for the Poisson load pattern (Legacy Loadgen)

client:
  workers: 1 # Number of client workers

metrics:
  collect:
    - "throughput"
    - "latency"
    - "ttft"
    - "tpot"

endpoint_config:
  endpoint: "http://localhost:8000"
  api_key: null
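Under the `poisson` load pattern, request inter-arrival times follow an exponential distribution with rate `target_qps`, and `scheduler_random_seed` fixes the draws. A sketch of that sampling (an assumed mechanism for illustration; the benchmark's scheduler internals are not shown here):

```python
import random
from statistics import mean

TARGET_QPS = 3.0
rng = random.Random(137)  # scheduler_random_seed from the config

# Exponential inter-arrival gaps produce Poisson-distributed arrivals.
gaps = [rng.expovariate(TARGET_QPS) for _ in range(10_000)]

# The mean gap should be close to 1 / target_qps ~= 0.333 s.
print(round(mean(gaps), 3))
```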