Skip to content

Commit 2c3971f

Browse files
committed
local: merge PR sgl-project#18806 (rollout log_prob) + PR sgl-project#19153 (sleep/wake)
Cherry-picked from:
- PR sgl-project#18806 (MikukuOvO): flow-matching SDE/CPS log_prob
- PR sgl-project#19153 (Godmook): release/resume memory occupation

Known issues:
- t.item() in log_prob path causes GPU sync overhead
- release_memory_occupation tags only supports "weights"
1 parent 43f8352 commit 2c3971f

File tree

16 files changed

+557
-7
lines changed

16 files changed

+557
-7
lines changed

python/sglang/multimodal_gen/configs/sample/sampling_params.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -146,6 +146,9 @@ class SamplingParams:
146146
# Misc
147147
save_output: bool = True
148148
return_frames: bool = False
149+
rollout: bool = False
150+
rollout_sde_type: str = "sde"
151+
rollout_noise_level: float = 0.7
149152
return_trajectory_latents: bool = False # returns all latents for each timestep
150153
return_trajectory_decoded: bool = False # returns decoded latents for each timestep
151154
# if True, disallow user params to override subclass-defined protected fields
@@ -293,6 +296,9 @@ def _finite_non_negative_float(
293296
_finite_non_negative_float(
294297
"guidance_rescale", self.guidance_rescale, allow_none=False
295298
)
299+
_finite_non_negative_float(
300+
"rollout_noise_level", self.rollout_noise_level, allow_none=False
301+
)
296302

297303
if self.cfg_normalization is None:
298304
self.cfg_normalization = 0.0
@@ -743,6 +749,25 @@ def add_cli_args(parser: Any) -> Any:
743749
default=SamplingParams.return_trajectory_latents,
744750
help="Whether to return the trajectory",
745751
)
752+
parser.add_argument(
753+
"--rollout",
754+
action="store_true",
755+
default=SamplingParams.rollout,
756+
help="Enable rollout mode and return per-step log_prob trajectory",
757+
)
758+
parser.add_argument(
759+
"--rollout-sde-type",
760+
type=str,
761+
choices=["sde", "cps"],
762+
default=SamplingParams.rollout_sde_type,
763+
help="Rollout step objective type used in log-prob computation.",
764+
)
765+
parser.add_argument(
766+
"--rollout-noise-level",
767+
type=float,
768+
default=SamplingParams.rollout_noise_level,
769+
help="Noise level used by rollout SDE/CPS step objective.",
770+
)
746771
parser.add_argument(
747772
"--return-trajectory-decoded",
748773
action="store_true",

python/sglang/multimodal_gen/runtime/entrypoints/diffusion_generator.py

Lines changed: 37 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,8 @@
1818
GenerationResult,
1919
ListLorasReq,
2020
MergeLoraWeightsReq,
21+
ReleaseMemoryOccupationReq,
22+
ResumeMemoryOccupationReq,
2123
SetLoraReq,
2224
ShutdownReq,
2325
UnmergeLoraWeightsReq,
@@ -213,6 +215,7 @@ def generate(
213215
),
214216
trajectory_latents=output_batch.trajectory_latents,
215217
trajectory_timesteps=output_batch.trajectory_timesteps,
218+
trajectory_log_probs=output_batch.trajectory_log_probs,
216219
trajectory_decoded=output_batch.trajectory_decoded,
217220
)
218221

@@ -452,6 +455,40 @@ def generate_with_lora(
452455
)
453456
)
454457

458+
def release_memory_occupation(self, tags: List[str] | None = None) -> dict:
459+
"""Release GPU memory (sleep). Offloads model weights to CPU.
460+
461+
Args:
462+
tags: Which memory regions to release. Currently only "weights" is
463+
supported for diffusion. If omitted, all regions are released.
464+
465+
Returns:
466+
dict with "success" and "message" keys.
467+
"""
468+
req = ReleaseMemoryOccupationReq(tags=tags)
469+
response = sync_scheduler_client.forward(req)
470+
if response.error:
471+
raise RuntimeError(f"Failed to release memory: {response.error}")
472+
logger.info("Successfully released GPU memory occupation (sleeping).")
473+
return response.output
474+
475+
def resume_memory_occupation(self, tags: List[str] | None = None) -> dict:
476+
"""Resume GPU memory (wake up). Loads model weights back to GPU.
477+
478+
Args:
479+
tags: Which memory regions to resume. Currently only "weights" is
480+
supported for diffusion. If omitted, all regions are resumed.
481+
482+
Returns:
483+
dict with "success" and "message" keys.
484+
"""
485+
req = ResumeMemoryOccupationReq(tags=tags)
486+
response = sync_scheduler_client.forward(req)
487+
if response.error:
488+
raise RuntimeError(f"Failed to resume memory: {response.error}")
489+
logger.info("Successfully resumed GPU memory occupation (waking up).")
490+
return response.output
491+
455492
def shutdown(self):
456493
"""
457494
Shutdown the generator.

python/sglang/multimodal_gen/runtime/entrypoints/http_server.py

Lines changed: 74 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,6 +19,8 @@
1919
from sglang.multimodal_gen.runtime.entrypoints.openai.utils import build_sampling_params
2020
from sglang.multimodal_gen.runtime.entrypoints.post_training import weights_api
2121
from sglang.multimodal_gen.runtime.entrypoints.utils import (
22+
ReleaseMemoryOccupationReq,
23+
ResumeMemoryOccupationReq,
2224
prepare_request,
2325
save_outputs,
2426
)
@@ -101,6 +103,78 @@ async def health_generate():
101103
return {"status": "ok"}
102104

103105

106+
def _extract_tags_from_body(body: dict) -> list[str] | None:
107+
"""Return the ``tags`` field from a parsed request body, or ``None``."""
108+
if not isinstance(body, dict):
109+
return None
110+
tags = body.get("tags", None)
111+
if tags is not None and not isinstance(tags, list):
112+
raise ValueError(
113+
f"'tags' must be a list of strings, got: {type(tags).__name__}"
114+
)
115+
return tags
116+
117+
118+
@health_router.post("/release_memory_occupation")
119+
async def release_memory_occupation(request: Request):
120+
"""Release GPU memory occupation (sleep).
121+
122+
Offloads all model weights to CPU so the GPU is free for another
123+
workload (e.g. RL training). The server process stays alive;
124+
call ``/resume_memory_occupation`` to reload weights before the
125+
next generation.
126+
127+
Body (optional JSON):
128+
tags (list[str]): memory regions to release.
129+
Supported value: ``"weights"``. Omit to release all.
130+
"""
131+
try:
132+
body = await request.json()
133+
except Exception:
134+
body = {}
135+
try:
136+
tags = _extract_tags_from_body(body)
137+
except ValueError as exc:
138+
return ORJSONResponse({"success": False, "message": str(exc)}, status_code=422)
139+
140+
req = ReleaseMemoryOccupationReq(tags=tags)
141+
response = await async_scheduler_client.forward(req)
142+
if response.error:
143+
return ORJSONResponse(
144+
{"success": False, "message": response.error}, status_code=400
145+
)
146+
return ORJSONResponse(response.output)
147+
148+
149+
@health_router.post("/resume_memory_occupation")
150+
async def resume_memory_occupation(request: Request):
151+
"""Resume GPU memory occupation (wake up).
152+
153+
Loads model weights back onto the GPU so the server can serve
154+
generation requests again.
155+
156+
Body (optional JSON):
157+
tags (list[str]): memory regions to resume.
158+
Supported value: ``"weights"``. Omit to resume all.
159+
"""
160+
try:
161+
body = await request.json()
162+
except Exception:
163+
body = {}
164+
try:
165+
tags = _extract_tags_from_body(body)
166+
except ValueError as exc:
167+
return ORJSONResponse({"success": False, "message": str(exc)}, status_code=422)
168+
169+
req = ResumeMemoryOccupationReq(tags=tags)
170+
response = await async_scheduler_client.forward(req)
171+
if response.error:
172+
return ORJSONResponse(
173+
{"success": False, "message": response.error}, status_code=400
174+
)
175+
return ORJSONResponse(response.output)
176+
177+
104178
def make_serializable(obj):
105179
"""Recursively converts Tensors to None for JSON serialization."""
106180
if isinstance(obj, torch.Tensor):

python/sglang/multimodal_gen/runtime/entrypoints/openai/image_api.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -117,6 +117,9 @@ async def generations(
117117
true_cfg_scale=request.true_cfg_scale,
118118
negative_prompt=request.negative_prompt,
119119
enable_teacache=request.enable_teacache,
120+
rollout=request.rollout,
121+
rollout_sde_type=request.rollout_sde_type,
122+
rollout_noise_level=request.rollout_noise_level,
120123
output_compression=request.output_compression,
121124
output_quality=request.output_quality,
122125
)

python/sglang/multimodal_gen/runtime/entrypoints/openai/protocol.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,9 @@ class ImageGenerationsRequest(BaseModel):
4646
output_quality: Optional[str] = "default"
4747
output_compression: Optional[int] = None
4848
enable_teacache: Optional[bool] = False
49+
rollout: Optional[bool] = False
50+
rollout_sde_type: Optional[str] = "sde"
51+
rollout_noise_level: Optional[float] = 0.7
4952
diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend
5053

5154

@@ -93,6 +96,9 @@ class VideoGenerationsRequest(BaseModel):
9396
output_quality: Optional[str] = "default"
9497
output_compression: Optional[int] = None
9598
output_path: Optional[str] = None
99+
rollout: Optional[bool] = False
100+
rollout_sde_type: Optional[str] = "sde"
101+
rollout_noise_level: Optional[float] = 0.7
96102
diffusers_kwargs: Optional[Dict[str, Any]] = None # kwargs for diffusers backend
97103

98104

python/sglang/multimodal_gen/runtime/entrypoints/openai/video_api.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,9 @@ def _build_video_sampling_params(request_id: str, request: VideoGenerationsReque
6969
guidance_scale_2=request.guidance_scale_2,
7070
negative_prompt=request.negative_prompt,
7171
enable_teacache=request.enable_teacache,
72+
rollout=request.rollout,
73+
rollout_sde_type=request.rollout_sde_type,
74+
rollout_noise_level=request.rollout_noise_level,
7275
output_path=request.output_path,
7376
output_compression=request.output_compression,
7477
output_quality=request.output_quality,
@@ -159,6 +162,9 @@ async def create_video(
159162
guidance_scale: Optional[float] = Form(None),
160163
num_inference_steps: Optional[int] = Form(None),
161164
enable_teacache: Optional[bool] = Form(False),
165+
rollout: Optional[bool] = Form(False),
166+
rollout_sde_type: Optional[str] = Form("sde"),
167+
rollout_noise_level: Optional[float] = Form(0.7),
162168
output_quality: Optional[str] = Form("default"),
163169
output_compression: Optional[int] = Form(None),
164170
extra_body: Optional[str] = Form(None),
@@ -212,6 +218,9 @@ async def create_video(
212218
negative_prompt=negative_prompt,
213219
num_inference_steps=num_inference_steps,
214220
enable_teacache=enable_teacache,
221+
rollout=rollout,
222+
rollout_sde_type=rollout_sde_type,
223+
rollout_noise_level=rollout_noise_level,
215224
output_compression=output_compression,
216225
output_quality=output_quality,
217226
**(

python/sglang/multimodal_gen/runtime/entrypoints/utils.py

Lines changed: 15 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -69,6 +69,20 @@ class ShutdownReq:
6969
pass
7070

7171

72+
@dataclass
73+
class ReleaseMemoryOccupationReq:
74+
"""Request to release GPU memory (sleep). Offloads model weights to CPU."""
75+
76+
tags: Optional[List[str]] = None
77+
78+
79+
@dataclass
80+
class ResumeMemoryOccupationReq:
81+
"""Request to resume GPU memory (wake up). Loads model weights back to GPU."""
82+
83+
tags: Optional[List[str]] = None
84+
85+
7286
def format_lora_message(
7387
lora_nickname: Union[str, List[str]],
7488
target: Union[str, List[str]],
@@ -108,6 +122,7 @@ class GenerationResult:
108122
metrics: dict = field(default_factory=dict)
109123
trajectory_latents: Any = None
110124
trajectory_timesteps: Any = None
125+
trajectory_log_probs: Any = None
111126
trajectory_decoded: Any = None
112127
prompt_index: int = 0
113128
output_file_path: str | None = None

0 commit comments

Comments (0)