Skip to content

Commit 5466192

Browse files
celve and claude
committed
[diffusion] feat: add rollout log_prob with flow-matching SDE/CPS support (sgl-project#18806)
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent e9b5706 commit 5466192

File tree

14 files changed

+229
-7
lines changed

14 files changed

+229
-7
lines changed

python/sglang/multimodal_gen/configs/sample/sampling_params.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ class SamplingParams:
159159
# Misc
160160
save_output: bool = True
161161
return_frames: bool = False
162+
rollout: bool = False
163+
rollout_sde_type: str = "sde"
164+
rollout_noise_level: float = 0.7
162165
return_trajectory_latents: bool = False # returns all latents for each timestep
163166
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
164167
# if True, disallow user params to override subclass-defined protected fields
@@ -306,6 +309,9 @@ def _finite_non_negative_float(
306309
_finite_non_negative_float(
307310
"guidance_rescale", self.guidance_rescale, allow_none=False
308311
)
312+
_finite_non_negative_float(
313+
"rollout_noise_level", self.rollout_noise_level, allow_none=False
314+
)
309315

310316
if self.cfg_normalization is None:
311317
self.cfg_normalization = 0.0
@@ -808,6 +814,25 @@ def add_cli_args(parser: Any) -> Any:
808814
default=SamplingParams.return_trajectory_latents,
809815
help="Whether to return the trajectory",
810816
)
817+
parser.add_argument(
818+
"--rollout",
819+
action="store_true",
820+
default=SamplingParams.rollout,
821+
help="Enable rollout mode and return per-step log_prob trajectory",
822+
)
823+
parser.add_argument(
824+
"--rollout-sde-type",
825+
type=str,
826+
choices=["sde", "cps"],
827+
default=SamplingParams.rollout_sde_type,
828+
help="Rollout step objective type used in log-prob computation.",
829+
)
830+
parser.add_argument(
831+
"--rollout-noise-level",
832+
type=float,
833+
default=SamplingParams.rollout_noise_level,
834+
help="Noise level used by rollout SDE/CPS step objective.",
835+
)
811836
parser.add_argument(
812837
"--return-trajectory-decoded",
813838
action="store_true",

python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def generate(
216216
),
217217
trajectory_latents=output_batch.trajectory_latents,
218218
trajectory_timesteps=output_batch.trajectory_timesteps,
219+
trajectory_log_probs=output_batch.trajectory_log_probs,
219220
trajectory_decoded=output_batch.trajectory_decoded,
220221
)
221222

python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ async def generations(
129129
true_cfg_scale=request.true_cfg_scale,
130130
negative_prompt=request.negative_prompt,
131131
enable_teacache=request.enable_teacache,
132+
rollout=request.rollout,
133+
rollout_sde_type=request.rollout_sde_type,
134+
rollout_noise_level=request.rollout_noise_level,
132135
output_compression=request.output_compression,
133136
output_quality=request.output_quality,
134137
)

python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ class ImageGenerationsRequest(BaseModel):
4646
output_quality: Optional[str] = "default"
4747
output_compression: Optional[int] = None
4848
enable_teacache: Optional[bool] = False
49+
rollout: Optional[bool] = False
50+
rollout_sde_type: Optional[str] = "sde"
51+
rollout_noise_level: Optional[float] = 0.7
4952
diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend
5053

5154

@@ -98,6 +101,9 @@ class VideoGenerationsRequest(BaseModel):
98101
output_quality: Optional[str] = "default"
99102
output_compression: Optional[int] = None
100103
output_path: Optional[str] = None
104+
rollout: Optional[bool] = False
105+
rollout_sde_type: Optional[str] = "sde"
106+
rollout_noise_level: Optional[float] = 0.7
101107
diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend
102108

103109

python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,9 @@ def _build_video_sampling_params(request_id: str, request: VideoGenerationsReque
7575
frame_interpolation_exp=request.frame_interpolation_exp,
7676
frame_interpolation_scale=request.frame_interpolation_scale,
7777
frame_interpolation_model_path=request.frame_interpolation_model_path,
78+
rollout=request.rollout,
79+
rollout_sde_type=request.rollout_sde_type,
80+
rollout_noise_level=request.rollout_noise_level,
7881
output_path=request.output_path,
7982
output_compression=request.output_compression,
8083
output_quality=request.output_quality,
@@ -181,6 +184,9 @@ async def create_video(
181184
frame_interpolation_exp: Optional[int] = Form(1),
182185
frame_interpolation_scale: Optional[float] = Form(1.0),
183186
frame_interpolation_model_path: Optional[str] = Form(None),
187+
rollout: Optional[bool] = Form(False),
188+
rollout_sde_type: Optional[str] = Form("sde"),
189+
rollout_noise_level: Optional[float] = Form(0.7),
184190
output_quality: Optional[str] = Form("default"),
185191
output_compression: Optional[int] = Form(None),
186192
extra_body: Optional[str] = Form(None),
@@ -256,6 +262,9 @@ async def create_video(
256262
frame_interpolation_exp=frame_interpolation_exp,
257263
frame_interpolation_scale=frame_interpolation_scale,
258264
frame_interpolation_model_path=frame_interpolation_model_path,
265+
rollout=rollout,
266+
rollout_sde_type=rollout_sde_type,
267+
rollout_noise_level=rollout_noise_level,
259268
output_compression=output_compression,
260269
output_quality=output_quality,
261270
**(

python/sglang/multimodal_gen/runtime/entrypoints/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ class GenerationResult:
108108
metrics: dict = field(default_factory=dict)
109109
trajectory_latents: Any = None
110110
trajectory_timesteps: Any = None
111+
trajectory_log_probs: Any = None
111112
trajectory_decoded: Any = None
112113
prompt_index: int = 0
113114
output_file_path: str | None = None

python/sglang/multimodal_gen/runtime/managers/gpu_worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,7 @@ def execute_forward(self, batch: List[Req]) -> OutputBatch:
233233
metrics=result.metrics,
234234
trajectory_timesteps=getattr(result, "trajectory_timesteps", None),
235235
trajectory_latents=getattr(result, "trajectory_latents", None),
236+
trajectory_log_probs=getattr(result, "trajectory_log_probs", None),
236237
noise_pred=getattr(result, "noise_pred", None),
237238
trajectory_decoded=getattr(result, "trajectory_decoded", None),
238239
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# SPDX-License-Identifier: Apache-2.0
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# SPDX-License-Identifier: Apache-2.0
2+
"""Flow-matching rollout step utilities for log-prob computation."""
3+
4+
import math
5+
from typing import Any, Optional, Union
6+
7+
import torch
8+
from diffusers.utils.torch_utils import randn_tensor
9+
10+
11+
def _as_timestep_tensor(
12+
timestep: Union[float, torch.Tensor], batch_size: int, device: torch.device
13+
) -> torch.Tensor:
14+
"""Normalize timestep input to a 1D tensor on the target device."""
15+
if torch.is_tensor(timestep):
16+
ts = timestep.to(device=device)
17+
else:
18+
ts = torch.tensor([timestep], device=device)
19+
20+
if ts.ndim == 0:
21+
ts = ts.view(1)
22+
else:
23+
ts = ts.view(-1)
24+
25+
# Broadcast scalar timestep to match batch size.
26+
if ts.numel() == 1 and batch_size > 1:
27+
ts = ts.repeat(batch_size)
28+
return ts
29+
30+
31+
def sde_step_with_logprob(
    self: Any,
    model_output: torch.FloatTensor,
    timestep: Union[float, torch.FloatTensor],
    sample: torch.FloatTensor,
    noise_level: float = 0.7,
    prev_sample: Optional[torch.FloatTensor] = None,
    generator: Optional[Union[torch.Generator, list[torch.Generator]]] = None,
    sde_type: str = "sde",
):
    """Run one rollout step and compute per-sample log_prob.

    Advances ``sample`` from ``timestep`` to the next scheduler step and
    scores the transition.  ``self`` is a diffusers-style flow-matching
    scheduler exposing ``sigmas``, ``timesteps`` and ``index_for_timestep``
    (passed explicitly so this function can be bound onto an existing
    scheduler instance).

    sde_type="sde" uses the Gaussian transition objective.
    sde_type="cps" uses the simplified CPS objective.

    Args:
        self: Scheduler providing ``sigmas``/``timesteps``/``index_for_timestep``.
        model_output: Model prediction at the current step.
        timestep: Current timestep, scalar or per-sample tensor.
        sample: Current latent sample of shape ``(batch, ...)``.
        noise_level: Scale of the injected stochasticity.
        prev_sample: If given, score this transition instead of sampling one.
        generator: RNG used when ``prev_sample`` must be drawn.
        sde_type: ``"sde"`` or ``"cps"``.

    Returns:
        Tuple ``(prev_sample, log_prob, prev_sample_mean, std_dev_t)``;
        ``prev_sample`` is cast back to the input dtype and ``log_prob`` is
        averaged over every non-batch dimension (one value per sample).

    Raises:
        ValueError: If ``sde_type`` is neither ``"sde"`` nor ``"cps"``.
    """
    sample_dtype = sample.dtype
    # Do the arithmetic in float32: the log-prob terms (squares, logs of
    # small std values) are numerically fragile in half precision.
    model_output = model_output.float()
    sample = sample.float()
    if prev_sample is not None:
        prev_sample = prev_sample.float()

    batch_size = sample.shape[0]
    timestep_tensor = _as_timestep_tensor(timestep, batch_size, sample.device)
    step_indices = torch.tensor(
        [self.index_for_timestep(t.to(self.timesteps.device)) for t in timestep_tensor],
        device=sample.device,
        dtype=torch.long,
    )
    # Clamp so the final step reuses the last sigma instead of indexing
    # out of range.
    prev_step_indices = (step_indices + 1).clamp_max(len(self.sigmas) - 1)
    step_indices = step_indices.to(device=self.sigmas.device)
    prev_step_indices = prev_step_indices.to(device=self.sigmas.device)

    sigma = self.sigmas[step_indices].to(device=sample.device, dtype=sample.dtype)
    sigma_prev = self.sigmas[prev_step_indices].to(
        device=sample.device, dtype=sample.dtype
    )
    # Reshape to (batch, 1, 1, ...) so sigmas broadcast against the sample.
    broadcast_shape = (-1, *([1] * (sample.ndim - 1)))
    sigma = sigma.view(broadcast_shape)
    sigma_prev = sigma_prev.view(broadcast_shape)
    # NOTE(review): self.sigmas[1] is substituted below wherever sigma == 1
    # to keep (1 - sigma) away from zero — presumably the first sigma in the
    # schedule strictly below 1; confirm against the scheduler's schedule.
    sigma_max = self.sigmas[min(1, len(self.sigmas) - 1)].to(
        device=sample.device, dtype=sample.dtype
    )
    dt = sigma_prev - sigma  # negative when sigmas decrease along the schedule

    if sde_type == "sde":
        denom_sigma = 1 - torch.where(sigma == 1, sigma_max, sigma)
        std_dev_t = torch.sqrt((sigma / denom_sigma).clamp_min(1e-12)) * noise_level
        prev_sample_mean = (
            sample * (1 + std_dev_t**2 / (2 * sigma) * dt)
            + model_output * (1 + std_dev_t**2 * (1 - sigma) / (2 * sigma)) * dt
        )

        sqrt_neg_dt = torch.sqrt((-dt).clamp_min(1e-12))
        if prev_sample is None:
            variance_noise = randn_tensor(
                model_output.shape,
                generator=generator,
                device=model_output.device,
                dtype=model_output.dtype,
            )
            prev_sample = prev_sample_mean + std_dev_t * sqrt_neg_dt * variance_noise

        # Gaussian log-density of prev_sample under N(mean, std^2).  The
        # normalization constant log(sqrt(2*pi)) is a hoisted Python scalar
        # rather than a per-call tensor construction.
        std = (std_dev_t * sqrt_neg_dt).clamp_min(1e-12)
        log_prob = (
            -((prev_sample.detach() - prev_sample_mean) ** 2) / (2 * (std**2))
            - torch.log(std)
            - 0.5 * math.log(2 * math.pi)
        )
    elif sde_type == "cps":
        std_dev_t = sigma_prev * math.sin(noise_level * math.pi / 2)
        pred_original_sample = sample - sigma * model_output
        noise_estimate = sample + model_output * (1 - sigma)
        # Guard against tiny negative values from float round-off before sqrt.
        sigma_delta = (sigma_prev**2 - std_dev_t**2).clamp_min(0.0)
        prev_sample_mean = pred_original_sample * (
            1 - sigma_prev
        ) + noise_estimate * torch.sqrt(sigma_delta)

        if prev_sample is None:
            variance_noise = randn_tensor(
                model_output.shape,
                generator=generator,
                device=model_output.device,
                dtype=model_output.dtype,
            )
            prev_sample = prev_sample_mean + std_dev_t * variance_noise

        # Keep the same simplified cps objective used in the original patch:
        # an unnormalized negative squared error, not a true log-density.
        log_prob = -((prev_sample.detach() - prev_sample_mean) ** 2)
    else:
        raise ValueError(f"Unsupported sde_type: {sde_type}")

    # Reduce over all non-batch dims -> one log-prob per sample.
    log_prob = log_prob.mean(dim=tuple(range(1, log_prob.ndim)))
    return prev_sample.to(sample_dtype), log_prob, prev_sample_mean, std_dev_t

python/sglang/multimodal_gen/runtime/pipelines_core/schedule_batch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ class Req:
136136

137137
trajectory_timesteps: list[torch.Tensor] | None = None
138138
trajectory_latents: torch.Tensor | None = None
139+
trajectory_log_probs: torch.Tensor | None = None
139140
trajectory_audio_latents: torch.Tensor | None = None
140141

141142
# Extra parameters that might be needed by specific pipeline implementations
@@ -334,6 +335,7 @@ class OutputBatch:
334335
audio_sample_rate: int | None = None
335336
trajectory_timesteps: list[torch.Tensor] | None = None
336337
trajectory_latents: torch.Tensor | None = None
338+
trajectory_log_probs: torch.Tensor | None = None
337339
trajectory_decoded: list[torch.Tensor] | None = None
338340
error: str | None = None
339341
output_file_paths: list[str] | None = None

0 commit comments

Comments
 (0)