Skip to content

Commit 9eeac24

Browse files
meghaagr13 and tzhouam authored
[Feat] Phase 1 foundation types for multimodal output decoupling (vllm-project#1816)
Signed-off-by: Megha Agarwal <agarwalmegha1308@gmail.com> Signed-off-by: Megha Agarwal <agarwalmegha@microsoft.com> Co-authored-by: Zhou Taichang <tzhouam@connect.ust.hk>
1 parent 77d773a commit 9eeac24

File tree

4 files changed

+346
-0
lines changed

4 files changed

+346
-0
lines changed
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""Unit tests for Phase 1 foundation types (RFC #1601).
2+
3+
Note: Uses importlib to load modules directly, bypassing the vllm_omni
4+
package __init__ which requires the vllm base package.
5+
"""
6+
7+
import importlib.util
8+
import sys
9+
from pathlib import Path
10+
11+
import pytest
12+
import torch
13+
14+
# ── Load modules without triggering vllm_omni.__init__ ─────────────
15+
16+
_ENGINE_DIR = Path(__file__).resolve().parents[2] / "vllm_omni" / "engine"
17+
18+
19+
def _load_module(name: str, filepath: Path):
20+
spec = importlib.util.spec_from_file_location(name, filepath)
21+
mod = importlib.util.module_from_spec(spec)
22+
sys.modules[name] = mod
23+
spec.loader.exec_module(mod)
24+
return mod
25+
26+
27+
# Load the two Phase 1 modules straight from their source files so that the
# vllm_omni package __init__ (which needs the real vllm package) never runs.
_om_mod = _load_module(
    "vllm_omni.engine.output_modality", _ENGINE_DIR / "output_modality.py"
)
_mm_mod = _load_module(
    "vllm_omni.engine.mm_outputs", _ENGINE_DIR / "mm_outputs.py"
)

# Re-export the names under test at module scope for readability below.
OutputModality = _om_mod.OutputModality
TensorAccumulationStrategy = _om_mod.TensorAccumulationStrategy
get_accumulation_strategy = _om_mod.get_accumulation_strategy
MultimodalPayload = _mm_mod.MultimodalPayload
MultimodalCompletionOutput = _mm_mod.MultimodalCompletionOutput
41+
42+
43+
def test_output_modality_parsing_and_flags():
    """Test OutputModality enum: from_string, aliases, compounds, properties, and accumulation strategy."""
    # Defaults, case-insensitive names, and aliases — table-driven.
    expected = {
        None: OutputModality.TEXT,
        "": OutputModality.TEXT,
        "image": OutputModality.IMAGE,
        "Audio": OutputModality.AUDIO,
        "speech": OutputModality.AUDIO,
        "latents": OutputModality.LATENT,
        "pixel_values": OutputModality.IMAGE,
    }
    for raw, want in expected.items():
        assert OutputModality.from_string(raw) == want

    # "text+image" parses into a compound flag carrying both kinds.
    compound = OutputModality.from_string("text+image")
    assert compound.has_text
    assert compound.has_multimodal

    # Pure members expose exactly one of the two flag properties.
    assert OutputModality.TEXT.has_text
    assert not OutputModality.TEXT.has_multimodal
    assert OutputModality.IMAGE.has_multimodal
    assert not OutputModality.IMAGE.has_text

    # Audio concatenates on the last dim; images stack along dim 0.
    assert get_accumulation_strategy(OutputModality.AUDIO) == TensorAccumulationStrategy.CONCAT_LAST
    assert get_accumulation_strategy(OutputModality.IMAGE) == TensorAccumulationStrategy.CONCAT_DIM0

    # Unrecognized modality strings raise with a helpful message.
    with pytest.raises(ValueError, match="Unknown modality"):
        OutputModality.from_string("video")
73+
74+
75+
def test_multimodal_payload_and_completion_output():
    """Test MultimodalPayload and MultimodalCompletionOutput wrapper."""
    # from_dict routes tensors into .tensors and everything else into .metadata.
    waveform = torch.ones(1, 16000)
    payload = MultimodalPayload.from_dict({"waveform": waveform, "sample_rate": 16000})
    assert payload is not None
    assert "waveform" in payload.tensors
    assert torch.equal(payload.primary_tensor, waveform)
    assert payload.metadata["sample_rate"] == 16000
    assert not payload.is_empty
    assert len(payload) == 1

    # Missing or empty input yields no payload at all.
    for empty_input in (None, {}):
        assert MultimodalPayload.from_dict(empty_input) is None

    # The wrapper forwards CompletionOutput fields and keeps the payload.
    wrapper = MultimodalCompletionOutput(
        multimodal_output=payload,
        index=0,
        text="hello",
        token_ids=[],
        cumulative_logprob=None,
        logprobs=None,
    )
    assert wrapper.text == "hello"
    assert wrapper.multimodal_output is payload
99+
100+
101+
def test_output_modality_printed_examples(capsys):
    """Printed examples for output modality types."""
    # Human-readable demo of the parsing table: defaults, aliases, compounds.
    print("\n=== OutputModality Parsing ===")
    for s in [None, "", "image", "Audio", "speech", "latents", "pixel_values", "text+image"]:
        print(f"  from_string({s!r:20s}) -> {OutputModality.from_string(s)}")

    print("\n=== Flag Properties ===")
    for m in [
        OutputModality.TEXT,
        OutputModality.IMAGE,
        OutputModality.AUDIO,
        OutputModality.TEXT | OutputModality.IMAGE,
    ]:
        print(f"  {str(m):40s} has_text={m.has_text} has_multimodal={m.has_multimodal}")

    print("\n=== Accumulation Strategies ===")
    for m in [OutputModality.AUDIO, OutputModality.IMAGE, OutputModality.LATENT]:
        print(f"  {str(m):30s} -> {get_accumulation_strategy(m)}")

    # Show the tensor/metadata split and the None-for-empty contract.
    print("\n=== MultimodalPayload ===")
    data = {"waveform": torch.ones(1, 16000), "sample_rate": 16000}
    p = MultimodalPayload.from_dict(data)
    print("  from_dict({waveform: tensor, sample_rate: 16000})")
    print(f"  tensors keys  : {list(p.tensors.keys())}")
    print(f"  primary_tensor: shape={p.primary_tensor.shape}, dtype={p.primary_tensor.dtype}")
    print(f"  metadata      : {p.metadata}")
    print(f"  is_empty={p.is_empty}, len={len(p)}")
    print(f"  from_dict(None) -> {MultimodalPayload.from_dict(None)}")
    print(f"  from_dict({{}}) -> {MultimodalPayload.from_dict({})}")

    print("\n=== MultimodalCompletionOutput ===")
    wrapper = MultimodalCompletionOutput(
        multimodal_output=p,
        index=0,
        text="hello",
        token_ids=[],
        cumulative_logprob=None,
        logprobs=None,
    )
    print(f"  text             : {wrapper.text}")
    print(f"  index            : {wrapper.index}")
    print(f"  multimodal_output: {wrapper.multimodal_output}")
    print(f"  repr             : {wrapper!r}")

    print("\n=== Unknown Modality ===")
    try:
        OutputModality.from_string("video")
    except ValueError as e:
        print(f'  from_string("video") raised ValueError: {e}')

    # Sanity-check the captured stdout so the demo cannot silently go blank.
    captured = capsys.readouterr()
    assert "OutputModality Parsing" in captured.out
    assert "MultimodalPayload" in captured.out

vllm_omni/engine/arg_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from vllm.logger import init_logger
88

99
from vllm_omni.config import OmniModelConfig
10+
from vllm_omni.engine.output_modality import OutputModality
1011
from vllm_omni.plugins import load_omni_general_plugins
1112

1213
logger = init_logger(__name__)
@@ -143,3 +144,8 @@ def create_model_config(self) -> OmniModelConfig:
143144
task_type=self.task_type,
144145
)
145146
return omni_config
147+
148+
@property
149+
def output_modality(self) -> OutputModality:
150+
"""Parse engine_output_type into a type-safe OutputModality flag."""
151+
return OutputModality.from_string(self.engine_output_type)

vllm_omni/engine/mm_outputs.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""Multimodal output data structures for vLLM-Omni.
2+
3+
This module defines structured types for multimodal outputs.
4+
5+
"""
6+
7+
from __future__ import annotations
8+
9+
from dataclasses import dataclass, field
10+
from typing import Any
11+
12+
import torch
13+
from vllm.outputs import CompletionOutput
14+
15+
16+
@dataclass
class MultimodalPayload:
    """Structured multimodal output payload.

    Attributes:
        tensors: Dictionary mapping modality/key names to their tensors.
        metadata: Optional dictionary for non-tensor metadata
            (e.g., sample rate for audio, image dimensions).
    """

    tensors: dict[str, torch.Tensor] = field(default_factory=dict)
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def primary_tensor(self) -> torch.Tensor | None:
        """Return the first tensor in the payload, or None if empty.

        "First" follows dict insertion order, i.e. the first tensor-valued
        key encountered by :meth:`from_dict`.
        """
        if self.tensors:
            return next(iter(self.tensors.values()))
        return None

    @property
    def is_empty(self) -> bool:
        """Return True if the payload has no tensors (metadata may still exist)."""
        return len(self.tensors) == 0

    def get(self, key: str) -> torch.Tensor | None:
        """Get a tensor by key, returning None if not found."""
        return self.tensors.get(key)

    def __contains__(self, key: str) -> bool:
        # Membership is defined over tensor keys only, not metadata.
        return key in self.tensors

    def __len__(self) -> int:
        # Length counts tensors only, consistent with __contains__/is_empty.
        return len(self.tensors)

    @classmethod
    def from_dict(cls, data: dict[str, Any] | None) -> MultimodalPayload | None:
        """Create a MultimodalPayload from a raw dictionary.

        Separates ``torch.Tensor`` values into ``tensors`` and everything
        else into ``metadata``. Returns None for None/empty input so callers
        can use the result in a plain truthiness check.
        """
        if not data:
            return None
        tensors: dict[str, torch.Tensor] = {}
        metadata: dict[str, Any] = {}
        for key, value in data.items():
            if isinstance(value, torch.Tensor):
                tensors[key] = value
            else:
                metadata[key] = value
        # `data` is non-empty here, so at least one of the two dicts is
        # populated; the former "both empty -> None" branch was unreachable
        # and has been removed.
        return cls(tensors=tensors, metadata=metadata)
70+
71+
72+
@dataclass
class MultimodalCompletionOutput(CompletionOutput):
    """CompletionOutput with multimodal support.

    Inherits all CompletionOutput fields and adds multimodal_output.
    As a CompletionOutput subclass, compatible with all existing vLLM consumers.
    """

    # NOTE(review): @dataclass will not replace __init__/__repr__ here because
    # both are defined explicitly below, but it may still synthesize __eq__
    # from the inherited fields, which would ignore multimodal_output —
    # confirm that equality semantics are intentional.
    def __init__(
        self,
        multimodal_output: MultimodalPayload | None = None,
        **kwargs: Any,
    ):
        # All CompletionOutput fields (index, text, token_ids, ...) are
        # forwarded untouched to the base constructor via **kwargs.
        super().__init__(**kwargs)
        self.multimodal_output = multimodal_output

    def __repr__(self) -> str:
        # Splice multimodal_output into the parent repr; assumes the parent
        # repr ends with ")" — TODO confirm for CompletionOutput.
        base = super().__repr__()
        return f"{base[:-1]}, multimodal_output={self.multimodal_output!r})"
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Output modality types for vLLM-Omni.
2+
3+
This module defines the OutputModality enum and TensorAccumulationStrategy
4+
for type-safe multimodal output routing and tensor merging.
5+
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import re
11+
from enum import Enum, Flag, auto
12+
13+
# Free-text spellings callers commonly use, mapped onto canonical member
# names (lower-case) before the enum lookup in OutputModality.from_string.
_MODALITY_ALIASES: dict[str, str] = {
    "speech": "audio",
    "images": "image",
    "latents": "latent",
    "wav": "audio",
    "waveform": "audio",
    "pixel_values": "image",
    "pixels": "image",
}


class OutputModality(Flag):
    """Bit-flag enum for output modalities.

    Compose freely with ``|`` — no need to enumerate every combination.

    Single: ``OutputModality.TEXT``, ``OutputModality.IMAGE``, ...
    Compound: ``OutputModality.TEXT | OutputModality.IMAGE`` (text+image)

    Note: POOLING is intentionally excluded. Pooling/embedding is vLLM's
    native path (pooling_output → PoolingRequestOutput), handled entirely
    by the base OutputProcessor. vLLM-Omni's layer does not participate.
    """

    TEXT = auto()
    IMAGE = auto()
    AUDIO = auto()
    LATENT = auto()

    @classmethod
    def from_string(cls, s: str | None) -> OutputModality:
        """Parse a free-text modality string into an OutputModality flag.

        Handles common aliases and compound strings separated by + or ,.
        None/blank input defaults to TEXT.

        Examples::

            OutputModality.from_string("text+image")
            # → OutputModality.TEXT | OutputModality.IMAGE

        Raises:
            ValueError: If any component is not a known modality or alias.
        """
        if not s or not s.strip():
            return cls.TEXT

        parts = [p.strip().lower() for p in re.split(r"[+,]", s.strip())]
        result = cls(0)
        for p in parts:
            p = _MODALITY_ALIASES.get(p, p)
            try:
                result |= cls[p.upper()]
            except KeyError:
                # `from None`: the KeyError is an implementation detail of the
                # enum lookup and only adds noise to the caller's traceback.
                raise ValueError(
                    f"Unknown modality: {p!r}. "
                    f"Supported: {[m.name.lower() for m in cls]}"
                ) from None
        return result

    @property
    def has_text(self) -> bool:
        """True if the TEXT flag is set."""
        return OutputModality.TEXT in self

    @property
    def has_multimodal(self) -> bool:
        """True if any non-TEXT flag (image/audio/latent) is set."""
        return bool(self & ~OutputModality.TEXT)
73+
74+
75+
class TensorAccumulationStrategy(Enum):
    """Strategy for merging incremental multimodal tensors."""

    # Concatenate along dimension 0. Used for image/latent tensors.
    CONCAT_DIM0 = "concat_dim0"

    # Concatenate along the last dimension. Used for audio waveforms.
    CONCAT_LAST = "concat_last"

    # Append to a list (no tensor concatenation).
    APPEND_LIST = "append_list"

    # Replace previous tensor entirely with the latest one.
    REPLACE = "replace"
89+
90+
91+
def get_accumulation_strategy(modality: OutputModality) -> TensorAccumulationStrategy:
    """Determine tensor merge strategy from the multimodal flags."""
    # Audio streams grow along the time axis (last dim); everything else —
    # images, latents, and any other modality — stacks along dim 0, which is
    # also the default for flags with no dedicated strategy.
    if OutputModality.AUDIO in modality:
        return TensorAccumulationStrategy.CONCAT_LAST
    return TensorAccumulationStrategy.CONCAT_DIM0

0 commit comments

Comments
 (0)