`FluxPipeline` has utilities that give us `img_ids` and `txt_ids`:

```python
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
```

```python
text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
```

So in Flux, these are created in the pipeline, not inside the transformer class.
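For context, that helper builds the same kind of 3-channel row/column index grid that HiDream later constructs inline. A paraphrased sketch (not verbatim; exact shapes/packing may differ by diffusers version, see `pipeline_flux.py` for the real implementation):

```python
import torch

# Paraphrased sketch of FluxPipeline._prepare_latent_image_ids -- not verbatim.
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
    # Channel 0 stays zero; channels 1 and 2 carry row / column indices.
    latent_image_ids = torch.zeros(height, width, 3)
    latent_image_ids[..., 1] += torch.arange(height)[:, None]
    latent_image_ids[..., 2] += torch.arange(width)[None, :]
    # Flatten the spatial grid into a (height * width, 3) id sequence.
    return latent_image_ids.reshape(height * width, 3).to(device=device, dtype=dtype)
```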
In HiDream, we have something different. `txt_ids` are created inside the transformer class:

```python
txt_ids = torch.zeros(  # truncated excerpt; see transformer_hidream_image.py
```
and `img_ids` are overwritten (probably intentional because it's conditional):

https://github.com/huggingface/diffusers/blob/ce1063acfa0cbc2168a7e9dddd4282ab8013b810/src/diffusers/models/transformers/transformer_hidream_image.py#L771C13-L771C20
Then the entire computation below happens inside the pipeline `__call__()` (`src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py`, lines 726 to 744 at ce1063a):
```python
if latents.shape[-2] != latents.shape[-1]:
    B, C, H, W = latents.shape
    pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
    img_ids = torch.zeros(pH, pW, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
    img_ids = img_ids.reshape(pH * pW, -1)
    img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
    img_ids_pad[: pH * pW, :] = img_ids
    img_sizes = img_sizes.unsqueeze(0).to(latents.device)
    img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
    if self.do_classifier_free_guidance:
        img_sizes = img_sizes.repeat(2 * B, 1)
        img_ids = img_ids.repeat(2 * B, 1, 1)
else:
    img_sizes = img_ids = None
```
Maybe this could take place inside a dedicated method, similar to the `FluxPipeline`? (Rough sketch below.)
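Purely as an illustration, a minimal sketch of such a helper. The name `_prepare_image_ids` is hypothetical (it does not exist in diffusers); the body just lifts the inline block above out of `__call__()`:

```python
import torch

# Hypothetical helper -- not an existing diffusers method. It lifts the inline
# block above out of __call__(), mirroring FluxPipeline's helper style.
def _prepare_image_ids(latents, patch_size, max_seq, do_classifier_free_guidance):
    if latents.shape[-2] == latents.shape[-1]:
        # Square latents: the transformer builds its own ids, so pass None.
        return None, None

    B, C, H, W = latents.shape
    pH, pW = H // patch_size, W // patch_size

    img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
    img_ids = torch.zeros(pH, pW, 3)
    img_ids[..., 1] += torch.arange(pH)[:, None]
    img_ids[..., 2] += torch.arange(pW)[None, :]
    img_ids = img_ids.reshape(pH * pW, -1)

    # Pad the ids up to the transformer's maximum sequence length.
    img_ids_pad = torch.zeros(max_seq, 3)
    img_ids_pad[: pH * pW, :] = img_ids

    img_sizes = img_sizes.unsqueeze(0).to(latents.device)
    img_ids = img_ids_pad.unsqueeze(0).to(latents.device)

    if do_classifier_free_guidance:
        # Duplicate for the unconditional + conditional batch halves.
        img_sizes = img_sizes.repeat(2 * B, 1)
        img_ids = img_ids.repeat(2 * B, 1, 1)

    return img_sizes, img_ids
```

The `__call__()` body would then reduce to a single call, e.g. `img_sizes, img_ids = _prepare_image_ids(latents, self.transformer.config.patch_size, self.transformer.max_seq, self.do_classifier_free_guidance)`.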
In general, the preparation of these ids could be standardized a bit across pipelines.
Cc: @yiyixuxu @a-r-r-o-w