Skip to content

Commit 5b4e36f

Browse files
committed
reference delta
1 parent 9907fcc commit 5b4e36f

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

src/train/model.py

Lines changed: 3 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -401,6 +401,7 @@ def step(self, batch):
401401

402402
# Apply position delta to reference image IDs
403403
delta = reference_deltas[0]
404+
delta[0] = delta[0] + 1
404405

405406

406407
# Combine input and reference images
@@ -412,6 +413,7 @@ def step(self, batch):
412413
(1, reference_height // self.qwen_image_edit_pipe.vae_scale_factor // 2, reference_width // self.qwen_image_edit_pipe.vae_scale_factor // 2, *delta),
413414
]
414415
]
416+
print(img_shapes)
415417

416418
u = compute_density_for_timestep_sampling(
417419
weighting_scheme="none",
@@ -430,19 +432,13 @@ def step(self, batch):
430432
# print(x_t.shape, condition.shape, sigmas)
431433
latent_model_input = torch.cat([x_t, condition], dim=1)
432434

433-
# Prepare guidance
434-
guidance = (
435-
torch.ones_like(t).to(self.device)
436-
if self.transformer.config.guidance_embeds
437-
else None
438-
)
439435

440436
# Forward pass
441437
pred = forward(
442438
self.transformer,
443439
hidden_states=latent_model_input,
444440
timestep=timesteps/1000,
445-
guidance=guidance,
441+
guidance=None,
446442
encoder_hidden_states_mask=prompt_embeds_mask,
447443
encoder_hidden_states=prompt_embeds,
448444
img_shapes=img_shapes,

0 commit comments

Comments
 (0)