Skip to content

Commit 0d0eeae

Browse files
AR
1 parent 427472e commit 0d0eeae

File tree

6 files changed

+413
-227
lines changed

6 files changed

+413
-227
lines changed

scripts/convert_cosmos_to_diffusers.py

Lines changed: 21 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -94,9 +94,15 @@
9494
--transformer_type Cosmos-2.5-Transfer-General-2B \
9595
--transformer_ckpt_path $transformer_ckpt_path \
9696
--vae_type wan2.1 \
97-
--output_path converted/transfer/2b/general/depth \
97+
--output_path converted/transfer/2b/general/depth/pipeline \
9898
--save_pipeline
9999
100+
python scripts/convert_cosmos_to_diffusers.py \
101+
--transformer_type Cosmos-2.5-Transfer-General-2B \
102+
--transformer_ckpt_path $transformer_ckpt_path \
103+
--vae_type wan2.1 \
104+
--output_path converted/transfer/2b/general/depth/models
105+
100106
# edge
101107
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/edge/61f5694b-0ad5-4ecd-8ad7-c8545627d125_ema_bf16.pt
102108
@@ -120,18 +126,30 @@
120126
--transformer_type Cosmos-2.5-Transfer-General-2B \
121127
--transformer_ckpt_path $transformer_ckpt_path \
122128
--vae_type wan2.1 \
123-
--output_path converted/transfer/2b/general/blur \
129+
--output_path converted/transfer/2b/general/blur/pipeline \
124130
--save_pipeline
125131
132+
python scripts/convert_cosmos_to_diffusers.py \
133+
--transformer_type Cosmos-2.5-Transfer-General-2B \
134+
--transformer_ckpt_path $transformer_ckpt_path \
135+
--vae_type wan2.1 \
136+
--output_path converted/transfer/2b/general/blur/models
137+
126138
# seg
127139
transformer_ckpt_path=~/.cache/huggingface/hub/models--nvidia--Cosmos-Transfer2.5-2B/snapshots/eb5325b77d358944da58a690157dd2b8071bbf85/general/seg/5136ef49-6d8d-42e8-8abf-7dac722a304a_ema_bf16.pt
128140
129141
python scripts/convert_cosmos_to_diffusers.py \
130142
--transformer_type Cosmos-2.5-Transfer-General-2B \
131143
--transformer_ckpt_path $transformer_ckpt_path \
132144
--vae_type wan2.1 \
133-
--output_path converted/transfer/2b/general/seg \
145+
--output_path converted/transfer/2b/general/seg/pipeline \
134146
--save_pipeline
147+
148+
python scripts/convert_cosmos_to_diffusers.py \
149+
--transformer_type Cosmos-2.5-Transfer-General-2B \
150+
--transformer_ckpt_path $transformer_ckpt_path \
151+
--vae_type wan2.1 \
152+
--output_path converted/transfer/2b/general/seg/models
135153
```
136154
"""
137155

src/diffusers/models/controlnets/controlnet_cosmos.py

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -191,7 +191,12 @@ def forward(
191191
dim=1,
192192
)
193193

194-
control_hidden_states = torch.cat([control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1)
194+
if condition_mask is not None:
195+
control_hidden_states = torch.cat([control_hidden_states, condition_mask], dim=1)
196+
else:
197+
control_hidden_states = torch.cat(
198+
[control_hidden_states, torch.zeros_like(controls_latents[:, :1])], dim=1
199+
)
195200

196201
padding_mask_resized = transforms.functional.resize(
197202
padding_mask, list(control_hidden_states.shape[-2:]), interpolation=transforms.InterpolationMode.NEAREST

src/diffusers/pipelines/cosmos/pipeline_cosmos2_5_predict.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -847,7 +847,7 @@ def __call__(
847847
video = (video * 255).astype(np.uint8)
848848
video_batch = []
849849
for vid in video:
850-
vid = self.safety_checker.check_video_safety(vid)
850+
# vid = self.safety_checker.check_video_safety(vid)
851851
video_batch.append(vid)
852852
video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
853853
video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)

0 commit comments

Comments
 (0)