Skip to content

Commit 7f821be

Browse files
Support Bagel Model (vllm-project#726)
Signed-off-by: princepride <wangzhipeng628@gmail.com> Co-authored-by: wzliu <wzliu@connect.hku.hk>
1 parent a54c323 commit 7f821be

File tree

30 files changed

+1837
-173
lines changed

30 files changed

+1837
-173
lines changed

.buildkite/scripts/simple_test.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,4 @@ VENV_PYTHON="${VENV_DIR}/bin/python"
5454
"${VENV_PYTHON}" -m pytest -v -s tests/diffusion/cache/
5555
"${VENV_PYTHON}" -m pytest -v -s tests/model_executor/models/qwen2_5_omni/test_audio_length.py
5656
"${VENV_PYTHON}" -m pytest -v -s tests/worker/
57+
"${VENV_PYTHON}" -m pytest -v -s tests/distributed/omni_connectors/test_kv_flow.py
Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
import argparse
2+
import os
3+
4+
5+
def parse_args(argv=None):
    """Parse command-line arguments for the Bagel end-to-end demo.

    Args:
        argv: Optional explicit argument list. Defaults to ``None``, in which
            case ``sys.argv[1:]`` is used — exposed mainly so the parser can
            be exercised in tests without patching ``sys.argv``.

    Returns:
        argparse.Namespace with model/prompt/modality and OmniLLM init options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model",
        default="ByteDance-Seed/BAGEL-7B-MoT",
        help="Path to merged model directory.",
    )
    parser.add_argument("--prompts", nargs="+", default=None, help="Input text prompts.")
    parser.add_argument(
        "--txt-prompts",
        type=str,
        default=None,
        help="Path to a .txt file with one prompt per line (preferred).",
    )
    parser.add_argument("--prompt_type", default="text", choices=["text"])

    parser.add_argument(
        "--modality",
        default="text2img",
        choices=["text2img", "img2img", "img2text", "text2text"],
        help="Modality mode to control stage execution.",
    )

    parser.add_argument(
        "--image-path",
        type=str,
        default=None,
        help="Path to input image for img2img.",
    )

    # OmniLLM init args
    parser.add_argument("--enable-stats", action="store_true", default=False)
    parser.add_argument("--init-sleep-seconds", type=int, default=20)
    parser.add_argument("--batch-timeout", type=int, default=5)
    parser.add_argument("--init-timeout", type=int, default=300)
    parser.add_argument("--shm-threshold-bytes", type=int, default=65536)
    parser.add_argument("--worker-backend", type=str, default="process", choices=["process", "ray"])
    parser.add_argument("--ray-address", type=str, default=None)
    parser.add_argument("--stage-configs-path", type=str, default=None)
    parser.add_argument("--steps", type=int, default=50, help="Number of inference steps.")

    return parser.parse_args(argv)
48+
49+
50+
def main():
    """Run Bagel inference in one of four modalities.

    text2img / text2text / img2text go through the multi-stage Omni pipeline;
    img2img runs only the diffusion stage via OmniDiffusion. Any images
    produced are saved to the working directory as PNG files.
    """
    args = parse_args()
    model_name = args.model

    # Preferred prompt source: a .txt file with one prompt per line.
    try:
        if args.txt_prompts and args.prompt_type == "text":
            with open(args.txt_prompts, encoding="utf-8") as f:
                lines = [ln.strip() for ln in f.readlines()]
            args.prompts = [ln for ln in lines if ln != ""]
            print(f"[Info] Loaded {len(args.prompts)} prompts from {args.txt_prompts}")
    except Exception as e:
        print(f"[Error] Failed to load prompts: {e}")
        raise

    if args.prompts is None:
        # Default prompt for text2img test if none provided
        args.prompts = ["<|im_start|>A cute cat<|im_end|>"]
        print(f"[Info] No prompts provided, using default: {args.prompts}")

    omni_outputs = []

    from PIL import Image

    if args.modality == "img2img":
        # Diffusion stage only — no AR model involved.
        from vllm_omni.entrypoints.omni_diffusion import OmniDiffusion

        print("[Info] Running in img2img mode (Stage 1 only)")
        client = OmniDiffusion(model=model_name)

        generate_kwargs = {
            "prompt": args.prompts,
            "seed": 52,
            "need_kv_receive": False,
            "num_inference_steps": args.steps,
        }

        if args.image_path:
            if os.path.exists(args.image_path):
                generate_kwargs["pil_image"] = Image.open(args.image_path).convert("RGB")
            else:
                print(f"[Warning] Image path {args.image_path} does not exist.")

        result = client.generate(**generate_kwargs)
        # Normalize to a list for the output loop below.
        omni_outputs = result if isinstance(result, list) else [result]
    else:
        import copy

        from vllm_omni.entrypoints.omni import Omni

        omni_kwargs = {}
        if args.stage_configs_path:
            omni_kwargs["stage_configs_path"] = args.stage_configs_path

        omni_kwargs.update(
            {
                "log_stats": args.enable_stats,
                "init_sleep_seconds": args.init_sleep_seconds,
                "batch_timeout": args.batch_timeout,
                "init_timeout": args.init_timeout,
                "shm_threshold_bytes": args.shm_threshold_bytes,
                "worker_backend": args.worker_backend,
                "ray_address": args.ray_address,
            }
        )

        omni = Omni(model=model_name, **omni_kwargs)

        # Wrap each raw prompt in the chat template required by the modality.
        formatted_prompts = []
        for p in args.prompts:
            if args.modality == "img2text":
                if args.image_path:
                    loaded_image = Image.open(args.image_path).convert("RGB")
                    final_prompt_text = f"<|im_start|>user\n<|image_pad|>\n{p}<|im_end|>\n<|im_start|>assistant\n"
                    prompt_dict = {
                        "prompt": final_prompt_text,
                        "multi_modal_data": {"image": loaded_image},
                        "modalities": ["text"],
                    }
                    formatted_prompts.append(prompt_dict)
            elif args.modality == "text2text":
                final_prompt_text = f"<|im_start|>user\n{p}<|im_end|>\n<|im_start|>assistant\n"
                prompt_dict = {"prompt": final_prompt_text, "modalities": ["text"]}
                formatted_prompts.append(prompt_dict)
            else:
                # text2img
                final_prompt_text = f"<|im_start|>{p}<|im_end|>"
                prompt_dict = {"prompt": final_prompt_text, "modalities": ["image"]}
                formatted_prompts.append(prompt_dict)

        # For text2img, the AR stage only needs to emit a single token; the
        # diffusion stage (index 1, when present) receives the step count.
        params_list = copy.deepcopy(omni.default_sampling_params_list)
        if args.modality == "text2img":
            params_list[0]["max_tokens"] = 1
            if len(params_list) > 1:
                params_list[1]["num_inference_steps"] = args.steps

        omni_outputs = list(omni.generate(prompts=formatted_prompts, sampling_params_list=params_list))

    # Save any images attached to the outputs, then dump the raw outputs.
    for i, req_output in enumerate(omni_outputs):
        images = getattr(req_output, "images", None)
        if not images and hasattr(req_output, "output"):
            # Some output types carry the image(s) in `.output` instead.
            images = req_output.output if isinstance(req_output.output, list) else [req_output.output]

        if images:
            for j, img in enumerate(images):
                img.save(f"output_{i}_{j}.png")

        # Per-stage outputs (e.g. intermediate diffusion results) are saved too.
        if hasattr(req_output, "request_output") and req_output.request_output:
            for stage_out in req_output.request_output:
                if hasattr(stage_out, "images") and stage_out.images:
                    for k, img in enumerate(stage_out.images):
                        save_path = f"output_{i}_stage_{getattr(stage_out, 'stage_id', '?')}_{k}.png"
                        img.save(save_path)
                        print(f"[Info] Saved stage output image to {save_path}")

    print(omni_outputs)


if __name__ == "__main__":
    main()
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
prompt="<|im_start|>user\nWhat is the capital of France?<|im_end|>\n<|im_start|>assistant\n"
2+
3+
python end2end.py --model ByteDance-Seed/BAGEL-7B-MoT \
4+
--prompt_type text \
5+
--init-sleep-seconds 0 \
6+
--prompts ${prompt}
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Bagel OpenAI-compatible chat client for image generation and multimodal tasks.
4+
5+
Usage:
6+
python openai_chat_client.py --prompt "A cute cat" --output output.png
7+
python openai_chat_client.py --prompt "Describe this image" --image-url https://example.com/image.png
8+
"""
9+
10+
import argparse
11+
import base64
12+
from pathlib import Path
13+
14+
import requests
15+
16+
17+
def generate_image(
    prompt: str,
    server_url: str = "http://localhost:8091",
    image_url: str | None = None,
    height: int | None = None,
    width: int | None = None,
    steps: int | None = None,
    seed: int | None = None,
    negative_prompt: str | None = None,
    modality: str = "text2img",  # "text2img" (default), "img2img", "img2text", "text2text"
) -> bytes | str | None:
    """Call the chat-completions endpoint and return an image or text result.

    Args:
        prompt: Text description or prompt.
        server_url: Server base URL.
        image_url: URL or local path of an input image (img2img/img2text).
        height: Image height in pixels.
        width: Image width in pixels.
        steps: Number of inference steps.
        seed: Random seed.
        negative_prompt: Negative prompt.
        modality: Task modality hint.

    Returns:
        Decoded image bytes when the server replied with a data-URL image,
        the text content when it replied with a plain string, or None on
        error / unrecognized response shape.
    """
    # Wrap the prompt in Bagel's chat delimiters.
    parts = [{"type": "text", "text": f"<|im_start|>{prompt}<|im_end|>"}]

    if image_url:
        if Path(image_url).exists():
            # Local file: inline it as a base64 data URL.
            encoded = base64.b64encode(Path(image_url).read_bytes()).decode("utf-8")
            resolved = f"data:image/jpeg;base64,{encoded}"
        else:
            resolved = image_url
        parts.append({"type": "image_url", "image_url": {"url": resolved}})

    # All parameters go at the top level of the payload — vLLM ignores
    # "extra_body", so nothing is nested under it.
    payload = {"messages": [{"role": "user", "content": parts}]}

    # Requested output modalities, derived from the task hint.
    if modality in ("text2img", "img2img"):
        payload["modalities"] = ["image"]
    elif modality in ("img2text", "text2text"):
        payload["modalities"] = ["text"]

    # Optional generation parameters: include only the ones that were given.
    optional = {
        "height": height,
        "width": width,
        "num_inference_steps": steps,
        "seed": seed,
    }
    payload.update({key: value for key, value in optional.items() if value is not None})
    if negative_prompt:
        payload["negative_prompt"] = negative_prompt

    try:
        print(f"Sending request to {server_url} with modality {modality}...")
        response = requests.post(
            f"{server_url}/v1/chat/completions",
            headers={"Content-Type": "application/json"},
            json=payload,
            timeout=300,
        )
        response.raise_for_status()
        choices = response.json().get("choices", [])

        # Scan every choice: the server may return text in choices[0] and the
        # image in choices[1]. An image result takes precedence over text.
        for choice in choices:
            body = choice.get("message", {}).get("content")
            if isinstance(body, list) and body:
                head = body[0]
                if isinstance(head, dict) and "image_url" in head:
                    url_value = head["image_url"].get("url", "")
                    if url_value.startswith("data:image"):
                        _, encoded = url_value.split(",", 1)
                        return base64.b64decode(encoded)

        # No image anywhere — fall back to the first non-empty text content.
        for choice in choices:
            body = choice.get("message", {}).get("content")
            if isinstance(body, str) and body:
                return body

        print(f"Unexpected response format: {choices}")
        return None

    except Exception as e:
        print(f"Error: {e}")
        return None
124+
125+
126+
def main():
    """CLI entry point: parse arguments, query the server, save or print the result.

    Exits with status 1 when the server returned no usable result.
    """
    parser = argparse.ArgumentParser(description="Bagel multimodal chat client")
    parser.add_argument("--prompt", "-p", default="<|im_start|>A cute cat<|im_end|>", help="Text prompt")
    parser.add_argument("--output", "-o", default="bagel_output.png", help="Output file (for image results)")
    parser.add_argument("--server", "-s", default="http://localhost:8091", help="Server URL")

    # Modality Control
    parser.add_argument("--image-url", "-i", type=str, help="Input image URL or local path")
    parser.add_argument(
        "--modality",
        "-m",
        default="text2img",
        choices=["text2img", "img2img", "img2text", "text2text"],
        help="Task modality",
    )

    # Generation Params
    parser.add_argument("--height", type=int, default=512, help="Image height")
    parser.add_argument("--width", type=int, default=512, help="Image width")
    parser.add_argument("--steps", type=int, default=25, help="Inference steps")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--negative", help="Negative prompt")

    args = parser.parse_args()

    print(f"Mode: {args.modality}")
    if args.image_url:
        print(f"Input Image: {args.image_url}")

    result = generate_image(
        prompt=args.prompt,
        server_url=args.server,
        image_url=args.image_url,
        height=args.height,
        width=args.width,
        steps=args.steps,
        seed=args.seed,
        negative_prompt=args.negative,
        modality=args.modality,
    )

    # BUG FIX: generate_image signals failure with None, so test identity.
    # The previous truthiness check (`if result:`) misreported a valid but
    # empty text reply ("") as a failure.
    if result is None:
        print("Failed to generate response")
        # `raise SystemExit` instead of `exit()`: the latter is injected by
        # the `site` module and may be absent (e.g. under `python -S`).
        raise SystemExit(1)

    if isinstance(result, bytes):
        # It's an image
        output_path = Path(args.output)
        output_path.write_bytes(result)
        print(f"Image saved to: {output_path}")
        print(f"Size: {len(result) / 1024:.1f} KB")
    elif isinstance(result, str):
        # It's text
        print("Response:")
        print(result)


if __name__ == "__main__":
    main()
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#!/bin/bash
2+
# Bagel online serving startup script
3+
4+
MODEL="${MODEL:-ByteDance-Seed/BAGEL-7B-MoT}"
5+
PORT="${PORT:-8091}"
6+
7+
echo "Starting Bagel server..."
8+
echo "Model: $MODEL"
9+
echo "Port: $PORT"
10+
11+
vllm serve "$MODEL" --omni \
12+
--port "$PORT"

0 commit comments

Comments
 (0)