
Commit 08c19d0

Author: Vikram Voleti (committed)
Combined sv4dv2 and sv4dv2_8views sampling scripts
1 parent 5f5aaf9 commit 08c19d0

File tree: 5 files changed (+245, -293 lines)


README.md

Lines changed: 21 additions & 6 deletions
@@ -5,20 +5,19 @@
 ## News
 
 
-**April 3, 2025**
+**April 4, 2025**
 - We are releasing **[Stable Video 4D 2.0 (SV4D 2.0)](https://huggingface.co/stabilityai/sv4d2.0)**, an enhanced video-to-4D diffusion model for high-fidelity novel-view video synthesis and 4D asset generation. For research purposes:
   - **SV4D 2.0** was trained to generate 48 frames (12 video frames x 4 camera views) at 576x576 resolution, given a 12-frame input video of the same size, ideally consisting of white-background images of a moving object.
   - Compared to our previous 4D model [SV4D](https://huggingface.co/stabilityai/sv4d), **SV4D 2.0** can generate videos with higher fidelity, sharper details during motion, and better spatio-temporal consistency. It also generalizes much better to real-world videos. Moreover, it does not rely on reference multi-view images of the first frame generated by SV3D, making it more robust to self-occlusions.
   - To generate longer novel-view videos, we autoregressively generate 12 frames at a time and use the previous generation as conditioning views for the remaining frames.
   - Please check our [project page](https://sv4d20.github.io), [arxiv paper](https://arxiv.org/pdf/2503.16396) and [video summary](https://www.youtube.com/watch?v=dtqj-s50ynU) for more details.
 
-**QUICKSTART** :
-- `python scripts/sampling/simple_video_sample_4d2.py --input_path assets/sv4d_videos/camel.gif --output_folder outputs/sv4d2`
-- We also train an 8-view model that generates 5 frames x 8 views at a time (same as SV4D). For example, run `python scripts/sampling/simple_video_sample_4d2_8views.py --input_path assets/sv4d_videos/chest.gif --output_folder outputs/sv4d2_8views`
+**QUICKSTART** :
+- `python scripts/sampling/simple_video_sample_4d2.py --input_path assets/sv4d_videos/camel.gif --output_folder outputs` (after downloading [sv4d2.safetensors](https://huggingface.co/stabilityai/sv4d2.0) from HuggingFace into `checkpoints/`)
 
 To run **SV4D 2.0** on a single input video of 21 frames:
-- Download the SV4D 2.0 models (`sv4d2.safetensors` and `sv4d2_8views.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d2.0) to `checkpoints/`
-- Run `python scripts/sampling/simple_video_sample_4d2.py --input_path <path/to/video>`
+- Download the SV4D 2.0 model (`sv4d2.safetensors`) from [here](https://huggingface.co/stabilityai/sv4d2.0) to `checkpoints/`: `huggingface-cli download stabilityai/sv4d2.0 sv4d2.safetensors --local-dir checkpoints`
+- Run inference: `python scripts/sampling/simple_video_sample_4d2.py --input_path <path/to/video>`
 - `input_path` : The input video `<path/to/video>` can be
   - a single video file in `gif` or `mp4` format, such as `assets/sv4d_videos/camel.gif`, or
   - a folder containing images of video frames in `.jpg`, `.jpeg`, or `.png` format, or
@@ -28,6 +27,21 @@ To run **SV4D 2.0** on a single input video of 21 frames:
 - **Background removal** : For input videos with a plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove the background and crop video frames by setting `--remove_bg=True`. To obtain higher-quality outputs on real-world input videos with a noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) or [SAM2](https://github.com/facebookresearch/segment-anything-2) before running SV4D.
 - **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--encoding_t=1` (number of frames encoded at a time) and `--decoding_t=1` (number of frames decoded at a time), or a lower video resolution like `--img_size=512`.
 
+Notes:
+- We also train an 8-view model that generates 5 frames x 8 views at a time (same as SV4D).
+  - Download the model from HuggingFace: `huggingface-cli download stabilityai/sv4d2.0 sv4d2_8views.safetensors --local-dir checkpoints`
+  - Run inference: `python scripts/sampling/simple_video_sample_4d2.py --model_path checkpoints/sv4d2_8views.safetensors --input_path assets/sv4d_videos/chest.gif --output_folder outputs`
+  - The 5x8 model takes 5 frames of input at a time, but the inference script for both models takes a 21-frame video as input by default (same as SV3D and SV4D); we run the model autoregressively until all 21 frames are generated.
+- Install dependencies before running:
+```
+python3.10 -m venv .generativemodels
+source .generativemodels/bin/activate
+pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 # check CUDA version
+pip3 install -r requirements/pt2.txt
+pip3 install .
+pip3 install -e git+https://github.com/Stability-AI/datapipelines.git@main#egg=sdata
+```
+
 ![tile](assets/sv4d2.gif)
 
 
@@ -190,6 +204,7 @@ This is assuming you have navigated to the `generative-models` root after clonin
 # install required packages from pypi
 python3 -m venv .pt2
 source .pt2/bin/activate
+pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
 pip3 install -r requirements/pt2.txt
 ```

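The quickstart above depends on the autoregressive scheme the README describes: the model emits 12 frames per pass (5 for the 8-view variant) while the scripts consume a 21-frame video, reusing the previous generation as conditioning. The sketch below is purely illustrative and is not code from this commit; in particular, the one-frame overlap between passes is an assumption based on the README's wording.

```python
# Illustrative sketch, not code from this repository: cover a 21-frame input with
# autoregressive passes of t_gen frames each (12 for sv4d2, 5 for the 8-view model),
# overlapping by one frame so each pass can be conditioned on a previously
# generated frame. The overlap size is an assumption.
def autoregressive_windows(n_frames: int = 21, t_gen: int = 12) -> list[list[int]]:
    windows, start = [], 0
    while start < n_frames - 1:
        end = min(start + t_gen, n_frames)
        windows.append(list(range(start, end)))
        start = end - 1  # reuse the last generated frame as conditioning
    return windows

print(autoregressive_windows(21, 12))  # 2 passes: frames 0-11, then 11-20
print(autoregressive_windows(21, 5))   # 5 passes of 5 frames, each overlapping by one
```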
requirements/pt2.txt

Lines changed: 4 additions & 1 deletion
@@ -5,13 +5,16 @@ einops>=0.6.1
 fairscale>=0.4.13
 fire>=0.5.0
 fsspec>=2023.6.0
+imageio[ffmpeg]
+imageio[pyav]
 invisible-watermark>=0.2.0
 kornia==0.6.9
 matplotlib>=3.7.2
 natsort>=8.4.0
 ninja>=1.11.1
-numpy>=1.24.4
+numpy==2.1
 omegaconf>=2.3.0
+onnxruntime
 open-clip-torch>=2.20.0
 opencv-python==4.6.0.66
 pandas>=2.0.3

scripts/demo/sv4d_helpers.py

Lines changed: 112 additions & 56 deletions
@@ -13,9 +13,6 @@
 from omegaconf import ListConfig, OmegaConf
 from PIL import Image, ImageSequence
 from rembg import remove
-from torch import autocast
-from torchvision.transforms import ToTensor
-
 from scripts.util.detection.nsfw_and_watermark_dectection import DeepFloydDataFiltering
 from sgm.modules.autoencoding.temporal_ae import VideoDecoder
 from sgm.modules.diffusionmodules.guiders import (
@@ -34,6 +31,8 @@
     LinearMultistepSampler,
 )
 from sgm.util import default, instantiate_from_config
+from torch import autocast
+from torchvision.transforms import ToTensor
 
 
 def load_module_gpu(model):
@@ -166,14 +165,14 @@ def read_video(
 
 
 def preprocess_video(
-    input_path,
-    remove_bg=False,
-    n_frames=21,
-    W=576,
-    H=576,
-    output_folder=None,
+    input_path,
+    remove_bg=False,
+    n_frames=21,
+    W=576,
+    H=576,
+    output_folder=None,
     image_frame_ratio=0.917,
-    base_count=0
+    base_count=0,
 ):
     print(f"preprocess {input_path}")
     if output_folder is None:
@@ -208,7 +207,9 @@ def preprocess_video(
     images = [Image.open(img_path) for img_path in all_img_paths]
 
     if len(images) != n_frames:
-        raise ValueError(f"Input video contains {len(images)} frames, fewer than {n_frames} frames.")
+        raise ValueError(
+            f"Input video contains {len(images)} frames, fewer than {n_frames} frames."
+        )
 
     # Remove background
     for i, image in enumerate(images):
@@ -235,29 +236,36 @@ def preprocess_video(
         else:
             # assume the input image has white background
             ret, mask = cv2.threshold(
-                (np.array(image).mean(-1) <= white_thresh).astype(np.uint8) * 255, 0, 255, cv2.THRESH_BINARY
+                (np.array(image).mean(-1) <= white_thresh).astype(np.uint8) * 255,
+                0,
+                255,
+                cv2.THRESH_BINARY,
             )
-
+
         x, y, w, h = cv2.boundingRect(mask)
         box_coord[0] = min(box_coord[0], x)
         box_coord[1] = min(box_coord[1], y)
         box_coord[2] = max(box_coord[2], x + w)
         box_coord[3] = max(box_coord[3], y + h)
-    box_square = max(original_center[0] - box_coord[0], original_center[1] - box_coord[1])
+    box_square = max(
+        original_center[0] - box_coord[0], original_center[1] - box_coord[1]
+    )
     box_square = max(box_square, box_coord[2] - original_center[0])
     box_square = max(box_square, box_coord[3] - original_center[1])
-    x, y = max(0, original_center[0] - box_square), max(0, original_center[1] - box_square)
-    w, h = min(image_arr.shape[0], 2 * box_square), min(image_arr.shape[1], 2 * box_square)
+    x, y = max(0, original_center[0] - box_square), max(
+        0, original_center[1] - box_square
+    )
+    w, h = min(image_arr.shape[0], 2 * box_square), min(
+        image_arr.shape[1], 2 * box_square
+    )
     box_size = box_square * 2
 
     for image in images:
         if image.mode == "RGB":
             image = image.convert("RGBA")
         image_arr = np.array(image)
         side_len = (
-            int(box_size / image_frame_ratio)
-            if image_frame_ratio is not None
-            else in_w
+            int(box_size / image_frame_ratio) if image_frame_ratio is not None else in_w
         )
         padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
         center = side_len // 2
@@ -273,9 +281,9 @@ def preprocess_video(
         rgba_arr = np.array(rgba) / 255.0
         rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
         image = (rgb * 255).astype(np.uint8)
-
+
         images_v0.append(image)
-
+
     processed_file = os.path.join(output_folder, f"{base_count:06d}_process_input.mp4")
     imageio.mimwrite(processed_file, images_v0, fps=10)
     return processed_file
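For orientation, here is a hypothetical call to the `preprocess_video` helper reformatted above, as it might be invoked ahead of sampling. The signature, defaults, and return value follow the diff; the import path and the chosen argument values are illustrative assumptions, not code from this commit.

```python
# Hypothetical usage of the helper shown above; the import path is assumed.
from scripts.demo.sv4d_helpers import preprocess_video

processed_mp4 = preprocess_video(
    "assets/sv4d_videos/camel.gif",  # gif/mp4 file or a folder of frames
    remove_bg=True,       # optional rembg-based background removal
    n_frames=21,          # raises ValueError if the input is not 21 frames
    W=576,
    H=576,
    output_folder="outputs",
)
# With the default base_count=0, this writes outputs/000000_process_input.mp4
print(processed_mp4)
```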
@@ -393,13 +401,17 @@ def denoiser(input, sigma, c):
     return samples
 
 
-def decode_latents(model, samples_z, img_matrix, frame_indices, view_indices, timesteps):
+def decode_latents(
+    model, samples_z, img_matrix, frame_indices, view_indices, timesteps
+):
     load_module_gpu(model.first_stage_model)
     for t in frame_indices:
         for v in view_indices:
-            if True: # t != 0 and v != 0:
+            if True:  # t != 0 and v != 0:
                 if isinstance(model.first_stage_model.decoder, VideoDecoder):
-                    samples_x = model.decode_first_stage(samples_z[t, v][None], timesteps=timesteps)
+                    samples_x = model.decode_first_stage(
+                        samples_z[t, v][None], timesteps=timesteps
+                    )
                 else:
                     samples_x = model.decode_first_stage(samples_z[t, v][None])
                 samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
@@ -785,7 +797,19 @@ def run_img2vid(
     return samples
 
 
-def prepare_inputs_forward_backward(img_matrix, view_indices, frame_indices, v0, t0, t1, model, version_dict, seed, polars, azims):
+def prepare_inputs_forward_backward(
+    img_matrix,
+    view_indices,
+    frame_indices,
+    v0,
+    t0,
+    t1,
+    model,
+    version_dict,
+    seed,
+    polars,
+    azims,
+):
     # forward sampling
     forward_frame_indices = frame_indices.copy()
     image = img_matrix[t0][v0]
@@ -801,7 +825,7 @@ def prepare_inputs_forward_backward(img_matrix, view_indices, frame_indices, v0,
         cond_motion,
         cond_view,
     )
-
+
     # backward sampling
     backward_frame_indices = frame_indices[::-1].copy()
     image = img_matrix[t1][v0]
@@ -817,10 +841,25 @@ def prepare_inputs_forward_backward(img_matrix, view_indices, frame_indices, v0,
         cond_motion,
         cond_view,
     )
-    return forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices
+    return (
+        forward_inputs,
+        forward_frame_indices,
+        backward_inputs,
+        backward_frame_indices,
+    )
 
 
-def prepare_inputs(frame_indices, img_matrix, v0, view_indices, model, version_dict, seed, polars, azims):
+def prepare_inputs(
+    frame_indices,
+    img_matrix,
+    v0,
+    view_indices,
+    model,
+    version_dict,
+    seed,
+    polars,
+    azims,
+):
     load_module_gpu(model.conditioner)
     # forward sampling
     forward_frame_indices = frame_indices.copy()
@@ -829,35 +868,40 @@ def prepare_inputs(frame_indices, img_matrix, v0, view_indices, model, version_d
     cond_motion = torch.cat([img_matrix[t][v0] for t in forward_frame_indices], 0)
     cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
     forward_inputs = prepare_sampling(
-        version_dict,
-        model,
-        image,
-        seed,
-        polars,
-        azims,
-        cond_motion,
-        cond_view,
-    )
-
+        version_dict,
+        model,
+        image,
+        seed,
+        polars,
+        azims,
+        cond_motion,
+        cond_view,
+    )
+
     # backward sampling
     backward_frame_indices = frame_indices[::-1].copy()
     t0 = backward_frame_indices[0]
     image = img_matrix[t0][v0]
     cond_motion = torch.cat([img_matrix[t][v0] for t in backward_frame_indices], 0)
     cond_view = torch.cat([img_matrix[t0][v] for v in view_indices], 0)
     backward_inputs = prepare_sampling(
-        version_dict,
-        model,
-        image,
-        seed,
-        polars,
-        azims,
-        cond_motion,
-        cond_view,
-    )
+        version_dict,
+        model,
+        image,
+        seed,
+        polars,
+        azims,
+        cond_motion,
+        cond_view,
+    )
 
     unload_module_gpu(model.conditioner)
-    return forward_inputs, forward_frame_indices, backward_inputs, backward_frame_indices
+    return (
+        forward_inputs,
+        forward_frame_indices,
+        backward_inputs,
+        backward_frame_indices,
+    )
 
 
 def do_sample(
@@ -913,7 +957,7 @@ def do_sample(
                             lambda y: y[k][: math.prod(num_samples)].to("cuda"), (c, uc)
                         )
 
-                if value_dict['image_only_indicator'] == 0:
+                if value_dict["image_only_indicator"] == 0:
                     c["cond_view"] *= 0
                     uc["cond_view"] *= 0
 
@@ -932,9 +976,12 @@
                                 SpatiotemporalPredictionGuider,
                             ),
                         ):
-                            additional_model_inputs[k] = torch.zeros(
-                                num_samples[0] * 2, num_samples[1]
-                            ).to("cuda") + value_dict['image_only_indicator']
+                            additional_model_inputs[k] = (
+                                torch.zeros(num_samples[0] * 2, num_samples[1]).to(
+                                    "cuda"
+                                )
+                                + value_dict["image_only_indicator"]
+                            )
                         else:
                             additional_model_inputs[k] = torch.zeros(num_samples).to(
                                 "cuda"
@@ -949,6 +996,7 @@ def denoiser(input, sigma, c):
                     return model.denoiser(
                         model.model, input, sigma, c, **additional_model_inputs
                     )
+
                 load_module_gpu(model.model)
                 load_module_gpu(model.denoiser)
                 samples_z = sampler(denoiser, randn, cond=c, uc=uc)
@@ -1034,9 +1082,12 @@ def prepare_sampling_(
                                 SpatiotemporalPredictionGuider,
                             ),
                         ):
-                            additional_model_inputs[k] = torch.zeros(
-                                num_samples[0] * 2, num_samples[1]
-                            ).to("cuda") + value_dict['image_only_indicator']
+                            additional_model_inputs[k] = (
+                                torch.zeros(num_samples[0] * 2, num_samples[1]).to(
+                                    "cuda"
+                                )
+                                + value_dict["image_only_indicator"]
+                            )
                         else:
                             additional_model_inputs[k] = torch.zeros(num_samples).to(
                                 "cuda"
@@ -1047,7 +1098,9 @@ def prepare_sampling_(
     return c, uc, additional_model_inputs
 
 
-def do_sample_per_step(model, sampler, noisy_latents, c, uc, step, additional_model_inputs):
+def do_sample_per_step(
+    model, sampler, noisy_latents, c, uc, step, additional_model_inputs
+):
     precision_scope = autocast
     with torch.no_grad():
         with precision_scope("cuda"):
@@ -1337,6 +1390,7 @@ def load_model(
     num_frames: int,
     num_steps: int,
     verbose: bool = False,
+    ckpt_path: str = None,
 ):
     config = OmegaConf.load(config)
     if device == "cuda":
@@ -1349,6 +1403,8 @@ def load_model(
     config.model.params.sampler_config.params.guider_config.params.num_frames = (
         num_frames
     )
+    if ckpt_path is not None:
+        config.model.params.ckpt_path = ckpt_path
     if device == "cuda":
         with torch.device(device):
             model = instantiate_from_config(config.model).to(device).eval()

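The last two hunks give `load_model` an optional `ckpt_path` override, which is what lets a single sampling script serve both checkpoints. Below is a hedged sketch of how that override might be exercised; the config file name, the frame and step counts, and the assumption that `load_model` returns the instantiated model are illustrative, not taken from this commit.

```python
# Hedged sketch only: exercise the new ckpt_path override from the diff above.
# The import path, config name, num_frames/num_steps values, and the return-value
# assumption are all illustrative.
from scripts.demo.sv4d_helpers import load_model  # import path assumed

model = load_model(
    config="scripts/sampling/configs/sv4d2.yaml",  # assumed config name
    device="cuda",
    num_frames=48,                                 # 12 video frames x 4 views (sv4d2)
    num_steps=50,                                  # illustrative step count
    verbose=False,
    ckpt_path="checkpoints/sv4d2.safetensors",     # or checkpoints/sv4d2_8views.safetensors
)
```

Presumably the `--model_path checkpoints/sv4d2_8views.safetensors` flag from the README's quickstart is forwarded to this `ckpt_path` argument (with the matching 5 frames x 8 views settings) inside `simple_video_sample_4d2.py`.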