Skip to content

Commit f9c69a8

Browse files
[BugFix] Standardize StableAudio audio output (vllm-project#842)
Signed-off-by: LudovicoYIN <hankeyin@gmail.com> Co-authored-by: Hongsheng Liu <liuhongsheng4@huawei.com>
1 parent 3d9cd16 commit f9c69a8

File tree

6 files changed

+90
-30
lines changed

examples/offline_inference/text_to_audio/text_to_audio.py

Lines changed: 16 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -142,7 +142,7 @@ def main():
142142
generation_start = time.perf_counter()
143143

144144
# Generate audio
145-
audio = omni.generate(
145+
outputs = omni.generate(
146146
args.prompt,
147147
negative_prompt=args.negative_prompt,
148148
generator=generator,
@@ -166,6 +166,21 @@ def main():
166166
suffix = output_path.suffix or ".wav"
167167
stem = output_path.stem or "stable_audio_output"
168168

169+
# Extract audio from omni.generate() outputs
170+
if not outputs:
171+
raise ValueError("No output generated from omni.generate()")
172+
173+
output = outputs[0]
174+
if not hasattr(output, "request_output") or not output.request_output:
175+
raise ValueError("No request_output found in OmniRequestOutput")
176+
request_output = output.request_output[0]
177+
if not hasattr(request_output, "multimodal_output"):
178+
raise ValueError("No multimodal_output found in request_output")
179+
180+
audio = request_output.multimodal_output.get("audio")
181+
if audio is None:
182+
raise ValueError("No audio output found in request_output")
183+
169184
# Handle different output formats
170185
if isinstance(audio, torch.Tensor):
171186
audio = audio.cpu().float().numpy()

tests/e2e/offline_inference/test_stable_audio_model.py

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -44,15 +44,14 @@ def test_stable_audio_model(model_name: str):
4444
# Extract audio from OmniRequestOutput
4545
assert outputs is not None
4646
first_output = outputs[0]
47-
assert first_output.final_output_type == "image" # Generic output type
47+
assert first_output.final_output_type == "image"
4848
assert hasattr(first_output, "request_output") and first_output.request_output
4949

5050
req_out = first_output.request_output[0]
5151
assert isinstance(req_out, OmniRequestOutput)
52-
assert hasattr(req_out, "images") and len(req_out.images) >= 1
53-
54-
# For stable audio, the "images" field contains audio numpy arrays
55-
audio = req_out.images[0]
52+
assert req_out.final_output_type == "audio"
53+
assert hasattr(req_out, "multimodal_output") and req_out.multimodal_output
54+
audio = req_out.multimodal_output.get("audio")
5655
assert isinstance(audio, np.ndarray)
5756
# audio shape: (batch, channels, samples)
5857
# For stable-audio-open-1.0: sample_rate=44100, so 2 seconds = 88200 samples

vllm_omni/diffusion/diffusion_engine.py

Lines changed: 57 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,13 @@ def supports_image_input(model_class_name: str) -> bool:
2929
return bool(getattr(model_cls, "support_image_input", False))
3030

3131

32+
def supports_audio_output(model_class_name: str) -> bool:
33+
model_cls = DiffusionModelRegistry._try_load_model_cls(model_class_name)
34+
if model_cls is None:
35+
return False
36+
return bool(getattr(model_cls, "support_audio_output", False))
37+
38+
3239
class DiffusionEngine:
3340
"""The diffusion engine for vLLM-Omni diffusion models."""
3441

@@ -86,14 +93,14 @@ def step(self, requests: list[OmniDiffusionRequest]):
8693
return None
8794

8895
postprocess_start_time = time.time()
89-
images = self.post_process_func(output.output) if self.post_process_func is not None else output.output
96+
outputs = self.post_process_func(output.output) if self.post_process_func is not None else output.output
9097
postprocess_time = time.time() - postprocess_start_time
9198
logger.info(f"Post-processing completed in {postprocess_time:.4f} seconds")
9299

93100
# Convert to OmniRequestOutput format
94-
# Ensure images is a list
95-
if not isinstance(images, list):
96-
images = [images] if images is not None else []
101+
# Ensure outputs is a list
102+
if not isinstance(outputs, list):
103+
outputs = [outputs] if outputs is not None else []
97104

98105
# Handle single request or multiple requests
99106
if len(requests) == 1:
@@ -108,18 +115,30 @@ def step(self, requests: list[OmniDiffusionRequest]):
108115
if output.trajectory_timesteps is not None:
109116
metrics["trajectory_timesteps"] = output.trajectory_timesteps
110117

111-
return OmniRequestOutput.from_diffusion(
112-
request_id=request_id,
113-
images=images,
114-
prompt=prompt,
115-
metrics=metrics,
116-
latents=output.trajectory_latents,
117-
)
118+
if supports_audio_output(self.od_config.model_class_name):
119+
audio_payload = outputs[0] if len(outputs) == 1 else outputs
120+
return OmniRequestOutput.from_diffusion(
121+
request_id=request_id,
122+
images=[],
123+
prompt=prompt,
124+
metrics=metrics,
125+
latents=output.trajectory_latents,
126+
multimodal_output={"audio": audio_payload},
127+
final_output_type="audio",
128+
)
129+
else:
130+
return OmniRequestOutput.from_diffusion(
131+
request_id=request_id,
132+
images=outputs,
133+
prompt=prompt,
134+
metrics=metrics,
135+
latents=output.trajectory_latents,
136+
)
118137
else:
119138
# Multiple requests: return list of OmniRequestOutput
120139
# Split images based on num_outputs_per_prompt for each request
121140
results = []
122-
image_idx = 0
141+
output_idx = 0
123142

124143
for request in requests:
125144
request_id = request.request_id or ""
@@ -129,22 +148,38 @@ def step(self, requests: list[OmniDiffusionRequest]):
129148

130149
# Get images for this request
131150
num_outputs = request.num_outputs_per_prompt
132-
request_images = images[image_idx : image_idx + num_outputs] if image_idx < len(images) else []
133-
image_idx += num_outputs
151+
request_outputs = (
152+
outputs[output_idx : output_idx + num_outputs] if output_idx < len(outputs) else []
153+
)
154+
output_idx += num_outputs
134155

135156
metrics = {}
136157
if output.trajectory_timesteps is not None:
137158
metrics["trajectory_timesteps"] = output.trajectory_timesteps
138159

139-
results.append(
140-
OmniRequestOutput.from_diffusion(
141-
request_id=request_id,
142-
images=request_images,
143-
prompt=prompt,
144-
metrics=metrics,
145-
latents=output.trajectory_latents,
160+
if supports_audio_output(self.od_config.model_class_name):
161+
audio_payload = request_outputs[0] if len(request_outputs) == 1 else request_outputs
162+
results.append(
163+
OmniRequestOutput.from_diffusion(
164+
request_id=request_id,
165+
images=[],
166+
prompt=prompt,
167+
metrics=metrics,
168+
latents=output.trajectory_latents,
169+
multimodal_output={"audio": audio_payload},
170+
final_output_type="audio",
171+
)
172+
)
173+
else:
174+
results.append(
175+
OmniRequestOutput.from_diffusion(
176+
request_id=request_id,
177+
images=request_outputs,
178+
prompt=prompt,
179+
metrics=metrics,
180+
latents=output.trajectory_latents,
181+
)
146182
)
147-
)
148183

149184
return results
150185
except Exception as e:

vllm_omni/diffusion/models/interface.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,3 +10,8 @@
1010
@runtime_checkable
1111
class SupportImageInput(Protocol):
1212
support_image_input: ClassVar[bool] = True
13+
14+
15+
@runtime_checkable
16+
class SupportAudioOutput(Protocol):
17+
support_audio_output: ClassVar[bool] = True

vllm_omni/diffusion/models/stable_audio/pipeline_stable_audio.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
from vllm_omni.diffusion.data import DiffusionOutput, OmniDiffusionConfig
2828
from vllm_omni.diffusion.distributed.utils import get_local_device
2929
from vllm_omni.diffusion.model_loader.diffusers_loader import DiffusersPipelineLoader
30+
from vllm_omni.diffusion.models.interface import SupportAudioOutput
3031
from vllm_omni.diffusion.models.stable_audio.stable_audio_transformer import StableAudioDiTModel
3132
from vllm_omni.diffusion.request import OmniDiffusionRequest
3233

@@ -57,7 +58,7 @@ def post_process_func(
5758
return post_process_func
5859

5960

60-
class StableAudioPipeline(nn.Module):
61+
class StableAudioPipeline(nn.Module, SupportAudioOutput):
6162
"""
6263
Pipeline for text-to-audio generation using Stable Audio Open.
6364

vllm_omni/outputs.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,7 @@ class OmniRequestOutput:
5757
prompt: str | None = None
5858
latents: torch.Tensor | None = None
5959
metrics: dict[str, Any] = field(default_factory=dict)
60+
multimodal_output: dict[str, Any] = field(default_factory=dict)
6061

6162
@classmethod
6263
def from_pipeline(
@@ -91,6 +92,8 @@ def from_diffusion(
9192
prompt: str | None = None,
9293
metrics: dict[str, Any] | None = None,
9394
latents: torch.Tensor | None = None,
95+
multimodal_output: dict[str, Any] | None = None,
96+
final_output_type: str = "image",
9497
) -> "OmniRequestOutput":
9598
"""Create output from diffusion model.
9699
@@ -106,11 +109,12 @@ def from_diffusion(
106109
"""
107110
return cls(
108111
request_id=request_id,
109-
final_output_type="image",
112+
final_output_type=final_output_type,
110113
images=images,
111114
prompt=prompt,
112115
latents=latents,
113116
metrics=metrics or {},
117+
multimodal_output=multimodal_output or {},
114118
finished=True,
115119
)
116120

@@ -171,6 +175,7 @@ def __repr__(self) -> str:
171175
f"prompt={self.prompt!r}",
172176
f"latents={self.latents}",
173177
f"metrics={self.metrics}",
178+
f"multimodal_output={self.multimodal_output}",
174179
]
175180

176181
return f"OmniRequestOutput({', '.join(parts)})"

0 commit comments

Comments (0)