Skip to content

Commit 9059a52

Browse files
committed
update
1 parent 3e019f2 commit 9059a52

File tree

1 file changed

+24
-13
lines changed

1 file changed

+24
-13
lines changed

src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,24 @@ def __init__(self, *args, **kwargs):
5555
Examples:
5656
```python
5757
>>> import torch
58-
>>> from diffusers import CosmosTextToImagePipeline
58+
>>> from diffusers import Cosmos2VideoToWorldPipeline
59+
>>> from diffusers.utils import export_to_video, load_image
5960
60-
>>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Text2Image, nvidia/Cosmos-Predict2-14B-Text2Image
61-
>>> model_id = "nvidia/Cosmos-Predict2-2B-Text2Image"
62-
>>> pipe = CosmosTextToImagePipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
61+
>>> # Available checkpoints: nvidia/Cosmos-Predict2-2B-Video2World, nvidia/Cosmos-Predict2-14B-Video2World
62+
>>> model_id = "nvidia/Cosmos-Predict2-2B-Video2World"
63+
>>> pipe = Cosmos2VideoToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
6364
>>> pipe.to("cuda")
6465
6566
>>> prompt = "A close-up shot captures a vibrant yellow scrubber vigorously working on a grimy plate, its bristles moving in circular motions to lift stubborn grease and food residue. The dish, once covered in remnants of a hearty meal, gradually reveals its original glossy surface. Suds form and bubble around the scrubber, creating a satisfying visual of cleanliness in progress. The sound of scrubbing fills the air, accompanied by the gentle clinking of the dish against the sink. As the scrubber continues its task, the dish transforms, gleaming under the bright kitchen lights, symbolizing the triumph of cleanliness over mess."
6667
>>> negative_prompt = "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and flickering. Overall, the video is of poor quality."
67-
68-
>>> output = pipe(
69-
... prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
70-
... ).images[0]
71-
>>> output.save("output.png")
68+
>>> image = load_image(
69+
... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/yellow-scrubber.png"
70+
... )
71+
72+
>>> video = pipe(
73+
... image=image, prompt=prompt, negative_prompt=negative_prompt, generator=torch.Generator().manual_seed(1)
74+
... ).frames[0]
75+
>>> export_to_video(video, "output.mp4", fps=16)
7276
```
7377
"""
7478

@@ -485,12 +489,15 @@ def __call__(
485489
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
486490
max_sequence_length: int = 512,
487491
sigma_conditioning: float = 0.0001,
488-
drop_unconditional: bool = False,
489492
):
490493
r"""
491494
The call function to the pipeline for generation.
492495
493496
Args:
497+
image (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, *optional*):
498+
The image to be used as a conditioning input for the video generation.
499+
video (`List[PIL.Image.Image]`, `np.ndarray`, `torch.Tensor`, *optional*):
500+
The video to be used as a conditioning input for the video generation.
494501
prompt (`str` or `List[str]`, *optional*):
495502
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
496503
instead.
@@ -538,6 +545,12 @@ def __call__(
538545
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
539546
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
540547
`._callback_tensor_inputs` attribute of your pipeline class.
548+
max_sequence_length (`int`, defaults to `512`):
549+
The maximum number of tokens in the prompt. If the prompt exceeds this length, it will be truncated. If
550+
the prompt is shorter than this length, it will be padded.
551+
sigma_conditioning (`float`, defaults to `0.0001`):
552+
The sigma value used for scaling conditioning latents. Ideally, it should not be changed or should be
553+
set to a small value close to zero.
541554
542555
Examples:
543556
@@ -634,9 +647,7 @@ def __call__(
634647
cond_mask = cond_mask.to(transformer_dtype)
635648
if self.do_classifier_free_guidance:
636649
uncond_mask = uncond_mask.to(transformer_dtype)
637-
unconditioning_latents = (
638-
torch.zeros_like(conditioning_latents) if drop_unconditional else conditioning_latents
639-
)
650+
unconditioning_latents = conditioning_latents
640651

641652
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
642653
sigma_conditioning = torch.full((batch_size,), sigma_conditioning, dtype=torch.float32, device=device)

0 commit comments

Comments
 (0)