diff --git a/invokeai/app/invocations/denoise_latents_extensions.py b/invokeai/app/invocations/denoise_latents_extensions.py
new file mode 100644
index 00000000000..7d454cc7c65
--- /dev/null
+++ b/invokeai/app/invocations/denoise_latents_extensions.py
@@ -0,0 +1,262 @@
+from abc import ABC, abstractmethod
+import einops
+import torch
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from torchvision.transforms.functional import resize as tv_resize
+from dataclasses import dataclass
+from invokeai.app.invocations.t2i_adapter import T2IAdapterField
+from invokeai.app.invocations.ip_adapter import IPAdapterField
+from .controlnet_image_processors import ControlField
+from invokeai.invocation_api import (
+    InvocationContext,
+    ConditioningField,
+    LatentsField,
+    UNetField,
+)
+
+@dataclass
+class DenoiseLatentsData:
+    context: InvocationContext
+    positive_conditioning: ConditioningField
+    negative_conditioning: ConditioningField
+    noise: LatentsField | None
+    latents: LatentsField | None
+    steps: int
+    cfg_scale: float
+    denoising_start: float
+    denoising_end: float
+    scheduler: SchedulerMixin
+    unet: UNetField
+    unet_model: UNet2DConditionModel
+    control: ControlField | list[ControlField] | None
+    ip_adapter: IPAdapterField | list[IPAdapterField] | None
+    t2i_adapter: T2IAdapterField | list[T2IAdapterField] | None
+    seed: int
+
+    def copy(self):
+        return DenoiseLatentsData(
+            context=self.context,
+            positive_conditioning=self.positive_conditioning,
+            negative_conditioning=self.negative_conditioning,
+            noise=self.noise,
+            latents=self.latents,
+            steps=self.steps,
+            cfg_scale=self.cfg_scale,
+            denoising_start=self.denoising_start,
+            denoising_end=self.denoising_end,
+            scheduler=self.scheduler,
+            unet=self.unet,
+            unet_model=self.unet_model,
+            control=self.control,
+            ip_adapter=self.ip_adapter,
+            t2i_adapter=self.t2i_adapter,
+            seed=self.seed,
+        )
+
+
+class DenoiseExtensionSD12X(ABC):
+    def __init__(self, denoise_latents_data: DenoiseLatentsData, priority: int, extension_kwargs: dict):
+        """
+        Do not override: use __post_init__ to handle extension-specific parameters.
+        During injection calls, extensions are invoked in order of self.priority (ascending).
+        self.denoise_latents_data is available in case you need data from the calling node.
+        """
+        self.denoise_latents_data = denoise_latents_data
+        self.priority = priority
+        self.__post_init__(**extension_kwargs)
+
+    def __post_init__(self):
+        """
+        Called after the object is created.
+        Override this method to perform additional initialization steps.
+        """
+        pass
+
+    def list_modifies(self) -> list[str]:
+        """
+        A list of all the modify methods that this extension provides.
+        e.g. ['modify_latents_before_scaling', 'modify_latents_before_noise_prediction']
+        The injection names must match the method names in this class.
+        """
+        return []
+
+    def list_provides(self) -> list[str]:
+        """
+        A list of all the provide methods that this extension provides.
+        e.g. ['provide_latents', 'provide_noise']
+        The injection names must match the method names in this class.
+        """
+        return []
+
+    def list_swaps(self) -> list[str]:
+        """
+        A list of all the swap methods that this extension provides.
+        e.g. ['swap_latents', 'swap_noise']
+        The injection names must match the method names in this class.
+        """
+        return []
+
+    def modify_latents_before_scaling(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        """
+        Samplers apply scalar multiplication to the latents before predicting noise.
+        This method allows you to modify the latents before that scaling is applied each step.
+        Useful if the modifications need to align with image or color in the normal latent space.
+        """
+        return latents
+
+    def modify_latents_before_noise_prediction(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        """
+        Last chance to modify latents before noise is predicted.
+        Additional channels for inpaint models are added here.
+        """
+        return latents
+
+    def modify_result_before_callback(self, step_output, t) -> torch.Tensor:
+        """
+        step_output.prev_sample is the current latents that will be used in the next step.
+        If step_output.pred_original_sample is provided/modified, it is used in the image preview for the user.
+        """
+        return step_output
+
+    def modify_latents_after_denoising(self, latents: torch.Tensor) -> torch.Tensor:
+        """
+        Final result of the latents after all steps are complete.
+        """
+        return latents
+
+
+class AddsMaskGuidance(DenoiseExtensionSD12X):
+    def __post_init__(self, mask_name: str, masked_latents_name: str | None, gradient_mask: bool):
+        """Align internal data and create noise if necessary"""
+        context = self.denoise_latents_data.context
+        if self.denoise_latents_data.latents is not None:
+            self.orig_latents = context.tensors.load(self.denoise_latents_data.latents.latents_name)
+        else:
+            raise ValueError("Latents input is required for the denoise mask extension")
+        if self.denoise_latents_data.noise is not None:
+            self.noise = context.tensors.load(self.denoise_latents_data.noise.latents_name)
+        else:
+            # if no noise is provided, create it deterministically from the node seed
+            self.noise = torch.randn(
+                self.orig_latents.shape,
+                dtype=torch.float32,
+                device="cpu",
+                generator=torch.Generator(device="cpu").manual_seed(self.denoise_latents_data.seed),
+            ).to(device=self.orig_latents.device, dtype=self.orig_latents.dtype)
+
+        self.mask: torch.Tensor = context.tensors.load(mask_name)
+        self.masked_latents = None if masked_latents_name is None else context.tensors.load(masked_latents_name)
+        self.scheduler: SchedulerMixin = self.denoise_latents_data.scheduler
+        self.gradient_mask: bool = gradient_mask
+        self.unet_type: str = self.denoise_latents_data.unet.unet.base
+        self.inpaint_model = self.denoise_latents_data.unet_model.conv_in.in_channels == 9
+        self.seed: int = self.denoise_latents_data.seed
+
+        self.mask = tv_resize(self.mask, list(self.orig_latents.shape[-2:]))
+        self.mask = self.mask.to(device=self.orig_latents.device, dtype=self.orig_latents.dtype)
+
+    def list_modifies(self) -> list[str]:
+        return [
+            "modify_latents_before_scaling",
+            "modify_latents_before_noise_prediction",
+            "modify_result_before_callback",
+            "modify_latents_after_denoising",
+        ]
+
+    def mask_from_timestep(self, t: torch.Tensor) -> torch.Tensor:
+        """Create a mask based on the current timestep"""
+        if self.inpaint_model:
+            # inpaint models expect a binary mask: floor all partial values to 0
+            mask_bool = self.mask < 1
+            floored_mask = torch.where(mask_bool, 0, 1)
+            return floored_mask
+        elif self.gradient_mask:
+            # progressively reveal more of the mask as the timestep decreases
+            threshold = t.item() / self.scheduler.config.num_train_timesteps
+            mask_bool = self.mask < 1 - threshold
+            timestep_mask = torch.where(mask_bool, 0, 1)
+            return timestep_mask.to(device=self.mask.device)
+        else:
+            return self.mask.clone()
+
+    def modify_latents_before_scaling(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        """Replace the unmasked region with the original latents.
+        Called before the scheduler scales the latent values."""
+        if self.inpaint_model:
+            return latents  # skip this stage
+
+        # expand to match batch size if necessary
+        batch_size = latents.size(0)
+        mask = self.mask_from_timestep(t).to(device=latents.device, dtype=latents.dtype)
+        mask = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
+        if t.dim() == 0:
+            t = einops.repeat(t, "-> batch", batch=batch_size)
+
+        # create a noised version of the original latents and blend it in by mask
+        noised_latents = self.scheduler.add_noise(self.orig_latents, self.noise, t)
+        noised_latents = einops.repeat(noised_latents, "b c h w -> (repeat b) c h w", repeat=batch_size).to(device=latents.device, dtype=latents.dtype)
+        masked_input = torch.lerp(latents, noised_latents, mask)
+        return masked_input
+
+    def shrink_mask(self, mask: torch.Tensor, n_operations: int) -> torch.Tensor:
+        # each 3x3 convolution + clamp grows the unmasked (1) region by one pixel,
+        # shrinking the masked (0) region
+        kernel = torch.ones(1, 1, 3, 3).to(device=mask.device, dtype=mask.dtype)
+        for _ in range(n_operations):
+            mask = torch.nn.functional.conv2d(mask, kernel, padding=1).clamp(0, 1)
+        return mask
+
+    def modify_latents_before_noise_prediction(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        """Expand latents with the extra channels needed by inpaint models"""
+        if not self.inpaint_model:
+            return latents  # skip this stage
+
+        mask = self.mask_from_timestep(t).to(device=latents.device, dtype=latents.dtype)
+        if self.masked_latents is None:
+            # latent values for a black region after VAE encode
+            if self.unet_type == "sd-1":
+                latent_zeros = [0.78857421875, -0.638671875, 0.576171875, 0.12213134765625]
+            elif self.unet_type == "sd-2":
+                latent_zeros = [0.7890625, -0.638671875, 0.576171875, 0.12213134765625]
+                print("WARNING: SD-2 Inpaint Models are not yet supported")
+            elif self.unet_type == "sdxl":
+                latent_zeros = [-0.578125, 0.501953125, 0.59326171875, -0.393798828125]
+            else:
+                raise ValueError(f"Unet type {self.unet_type} not supported as an inpaint model. Where did you get this?")
+
+            # replace the masked region with the specified values
+            mask_values = torch.tensor(latent_zeros).view(1, 4, 1, 1).expand_as(latents).to(device=latents.device, dtype=latents.dtype)
+            small_mask = self.shrink_mask(mask, 1)  # shrink so the synthetic fill sits inside the mask channel
+            self.masked_latents = torch.where(small_mask == 0, mask_values, self.orig_latents)
+
+        masked_latents = self.scheduler.scale_model_input(self.masked_latents, t)
+        masked_latents = einops.repeat(masked_latents, "b c h w -> (repeat b) c h w", repeat=latents.size(0))
+        mask = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=latents.size(0))
+        model_input = torch.cat([latents, 1 - mask, masked_latents], dim=1).to(dtype=latents.dtype, device=latents.device)
+        return model_input
+
+    def modify_result_before_callback(self, step_output, t) -> torch.Tensor:
+        """Fix preview images to show the original image in the unmasked region"""
+        if hasattr(step_output, "denoised"):  # LCM sampler
+            prediction = step_output.denoised
+        elif hasattr(step_output, "pred_original_sample"):  # samplers with final predictions
+            prediction = step_output.pred_original_sample
+        else:  # all other samplers (no prediction available)
+            prediction = step_output.prev_sample
+
+        mask = self.mask_from_timestep(t)
+        mask = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=prediction.size(0))
+        step_output.pred_original_sample = torch.lerp(prediction, self.orig_latents.to(dtype=prediction.dtype), mask.to(dtype=prediction.dtype))
+
+        return step_output
+
+    def modify_latents_after_denoising(self, latents: torch.Tensor) -> torch.Tensor:
+        """Restore the original latents in the unmasked region of the final result"""
+        if self.inpaint_model:
+            if self.masked_latents is None:
+                mask = self.shrink_mask(self.mask, 1)
+            else:
+                return latents
+        else:
+            mask = self.mask_from_timestep(torch.Tensor([0]))
+        mask = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=latents.size(0))
+        latents = torch.lerp(latents, self.orig_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype)).to(device=latents.device)
+        return latents
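For orientation, a minimal extension under this API might look like the sketch below. It exercises only the hooks defined in the new base class; the class name and the `strength` kwarg are illustrative, not part of this diff, and the dispatch plumbing that reads `list_modifies()` is assumed to live in the calling node.

```python
import torch

from invokeai.app.invocations.denoise_latents_extensions import DenoiseExtensionSD12X


class DampenLatentsExtension(DenoiseExtensionSD12X):
    """Hypothetical extension: scales latents just before noise prediction."""

    def __post_init__(self, strength: float = 0.95):
        # extension_kwargs passed to __init__ arrive here
        self.strength = strength

    def list_modifies(self) -> list[str]:
        return ["modify_latents_before_noise_prediction"]

    def modify_latents_before_noise_prediction(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        return latents * self.strength


# Construction mirrors the base-class __init__ signature, e.g.:
# ext = DampenLatentsExtension(denoise_data, priority=100, extension_kwargs={"strength": 0.9})
```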
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index ce63d568c62..77531aec702 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -854,8 +854,6 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         if seed is None:
             seed = 0
 
-        mask, masked_latents, gradient_mask = self.prep_inpaint_mask(context, latents)
-
         # TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
         # below. Investigate whether this is appropriate.
         t2i_adapter_data = self.run_t2i_adapters(
@@ -893,10 +891,6 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
             latents = latents.to(device=unet.device, dtype=unet.dtype)
             if noise is not None:
                 noise = noise.to(device=unet.device, dtype=unet.dtype)
-            if mask is not None:
-                mask = mask.to(device=unet.device, dtype=unet.dtype)
-            if masked_latents is not None:
-                masked_latents = masked_latents.to(device=unet.device, dtype=unet.dtype)
 
             scheduler = get_scheduler(
                 context=context,
@@ -945,9 +939,7 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
                 init_timestep=init_timestep,
                 noise=noise,
                 seed=seed,
-                mask=mask,
-                masked_latents=masked_latents,
-                gradient_mask=gradient_mask,
+                additional_guidance=additional_guidance,
                 num_inference_steps=num_inference_steps,
                 scheduler_step_kwargs=scheduler_step_kwargs,
                 conditioning_data=conditioning_data,
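The mask conventions used throughout (0 = masked, 1 = unmasked) and the `shrink_mask` behavior are easy to misread, so here is a self-contained snippet replicating that logic outside the pipeline; the standalone function is a copy for demonstration, not part of the diff.

```python
import torch


def shrink_mask(mask: torch.Tensor, n_operations: int) -> torch.Tensor:
    # same logic as AddsMaskGuidance.shrink_mask: each 3x3 convolution + clamp
    # grows the unmasked (1) region by one pixel, shrinking the masked (0) hole
    kernel = torch.ones(1, 1, 3, 3, dtype=mask.dtype, device=mask.device)
    for _ in range(n_operations):
        mask = torch.nn.functional.conv2d(mask, kernel, padding=1).clamp(0, 1)
    return mask


mask = torch.ones(1, 1, 7, 7)
mask[:, :, 2:5, 2:5] = 0           # 3x3 masked hole in the center
print(shrink_mask(mask, 1)[0, 0])  # only the single center pixel remains masked
```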
diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
index b4d1b3381c7..c65a643d664 100644
--- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py
+++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py
@@ -18,6 +18,7 @@
 from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.utils.import_utils import is_xformers_available
 from pydantic import Field
+from torchvision.transforms.functional import resize as tv_resize
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config.config_default import get_config
@@ -81,36 +82,127 @@ def are_like_tensors(a: torch.Tensor, b: object) -> bool:
 @dataclass
 class AddsMaskGuidance:
-    mask: torch.FloatTensor
-    mask_latents: torch.FloatTensor
+    orig_latents: torch.Tensor
+    mask: torch.Tensor  # 0 is masked, 1 is unmasked
+    masked_latents: torch.Tensor | None
     scheduler: SchedulerMixin
-    noise: torch.Tensor
+    noise: torch.Tensor | None
     gradient_mask: bool
+    unet_type: str
+    inpaint_model: bool
+    seed: int
+
+    def __post_init__(self):
+        """Align internal data and create noise if necessary"""
+        self.mask = tv_resize(self.mask, self.orig_latents.shape[-2:])
+        self.mask = self.mask.to(device=self.orig_latents.device, dtype=self.orig_latents.dtype)
+        if self.noise is None:
+            self.noise = torch.randn(
+                self.orig_latents.shape,
+                dtype=torch.float32,
+                device="cpu",
+                generator=torch.Generator(device="cpu").manual_seed(self.seed),
+            ).to(device=self.orig_latents.device, dtype=self.orig_latents.dtype)
+
+    def mask_from_timestep(self, t: torch.Tensor) -> torch.Tensor:
+        """Create a mask based on the current timestep"""
+        if self.inpaint_model:
+            # inpaint models expect a binary mask: floor all partial values to 0
+            mask_bool = self.mask < 1
+            floored_mask = torch.where(mask_bool, 0, 1)
+            return floored_mask
+        elif self.gradient_mask:
+            # progressively reveal more of the mask as the timestep decreases
+            threshold = t.item() / self.scheduler.config.num_train_timesteps
+            mask_bool = self.mask < 1 - threshold
+            timestep_mask = torch.where(mask_bool, 0, 1)
+            return timestep_mask.to(device=self.mask.device)
+        else:
+            return self.mask.clone()
 
-    def __call__(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-        return self.apply_mask(latents, t)
+    def modify_latents_before_scaling(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        """Replace the unmasked region with the original latents.
+        Called before the scheduler scales the latent values."""
+        if self.inpaint_model:
+            return latents  # skip this stage
 
-    def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
+        # expand to match batch size if necessary
         batch_size = latents.size(0)
-        mask = einops.repeat(self.mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
+        mask = self.mask_from_timestep(t).to(device=latents.device, dtype=latents.dtype)
+        mask = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=batch_size)
         if t.dim() == 0:
-            # some schedulers expect t to be one-dimensional.
-            # TODO: file diffusers bug about inconsistency?
             t = einops.repeat(t, "-> batch", batch=batch_size)
-        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
-        # get very confused about what is happening from step to step when we do that.
-        mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
-        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
-        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
-        mask_latents = einops.repeat(mask_latents, "b c h w -> (repeat b) c h w", repeat=batch_size)
-        if self.gradient_mask:
-            threshhold = (t.item()) / self.scheduler.config.num_train_timesteps
-            mask_bool = mask > threshhold  # I don't know when mask got inverted, but it did
-            masked_input = torch.where(mask_bool, latents, mask_latents)
-        else:
-            masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
+
+        # create a noised version of the original latents and blend it in by mask
+        noised_latents = self.scheduler.add_noise(self.orig_latents, self.noise, t)
+        noised_latents = einops.repeat(noised_latents, "b c h w -> (repeat b) c h w", repeat=batch_size).to(device=latents.device, dtype=latents.dtype)
+        masked_input = torch.lerp(latents, noised_latents, mask)
         return masked_input
 
+    def shrink_mask(self, mask: torch.Tensor, n_operations: int) -> torch.Tensor:
+        # each 3x3 convolution + clamp grows the unmasked (1) region by one pixel,
+        # shrinking the masked (0) region
+        kernel = torch.ones(1, 1, 3, 3).to(device=mask.device, dtype=mask.dtype)
+        for _ in range(n_operations):
+            mask = torch.nn.functional.conv2d(mask, kernel, padding=1).clamp(0, 1)
+        return mask
+
+    def modify_latents_before_noise_prediction(self, latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        """Expand latents with the extra channels needed by inpaint models"""
+        if not self.inpaint_model:
+            return latents  # skip this stage
+
+        mask = self.mask_from_timestep(t).to(device=latents.device, dtype=latents.dtype)
+        if self.masked_latents is None:
+            # latent values for a black region after VAE encode
+            if self.unet_type == "sd-1":
+                latent_zeros = [0.78857421875, -0.638671875, 0.576171875, 0.12213134765625]
+            elif self.unet_type == "sd-2":
+                latent_zeros = [0.7890625, -0.638671875, 0.576171875, 0.12213134765625]
+                print("WARNING: SD-2 Inpaint Models are not yet supported")
+            elif self.unet_type == "sdxl":
+                latent_zeros = [-0.578125, 0.501953125, 0.59326171875, -0.393798828125]
+            else:
+                raise ValueError(f"Unet type {self.unet_type} not supported as an inpaint model. Where did you get this?")
+
+            # replace the masked region with the specified values
+            mask_values = torch.tensor(latent_zeros).view(1, 4, 1, 1).expand_as(latents).to(device=latents.device, dtype=latents.dtype)
+            small_mask = self.shrink_mask(mask, 1)  # shrink so the synthetic fill sits inside the mask channel
+            masked_latents = self.scheduler.scale_model_input(torch.where(small_mask == 0, mask_values, self.orig_latents), t)
+        else:
+            masked_latents = self.scheduler.scale_model_input(self.masked_latents, t)
+
+        masked_latents = einops.repeat(masked_latents, "b c h w -> (repeat b) c h w", repeat=latents.size(0))
+        mask = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=latents.size(0))
+        model_input = torch.cat([latents, 1 - mask, masked_latents], dim=1).to(dtype=latents.dtype, device=latents.device)
+        return model_input
+
+    def modify_result_before_callback(self, step_output, t) -> torch.Tensor:
+        """Fix preview images to show the original image in the unmasked region"""
+        if hasattr(step_output, "denoised"):  # LCM sampler
+            prediction = step_output.denoised
+        elif hasattr(step_output, "pred_original_sample"):  # samplers with final predictions
+            prediction = step_output.pred_original_sample
+        else:  # all other samplers (no prediction available)
+            prediction = step_output.prev_sample
+
+        mask = self.mask_from_timestep(t)
+        mask = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=prediction.size(0))
+        step_output.pred_original_sample = torch.lerp(prediction, self.orig_latents.to(dtype=prediction.dtype), mask.to(dtype=prediction.dtype))
+
+        return step_output
+
+    def modify_latents_after_denoising(self, latents: torch.Tensor) -> torch.Tensor:
+        """Restore the original latents in the unmasked region of the final result"""
+        if self.inpaint_model:
+            if self.masked_latents is None:
+                mask = self.shrink_mask(self.mask, 1)
+            else:
+                return latents
+        else:
+            mask = self.mask_from_timestep(torch.Tensor([0]))
+        mask = einops.repeat(mask, "b c h w -> (repeat b) c h w", repeat=latents.size(0))
+        latents = torch.lerp(latents, self.orig_latents.to(dtype=latents.dtype), mask.to(dtype=latents.dtype)).to(device=latents.device)
+        return latents
+
 
 def trim_to_multiple_of(*args, multiple_of=8):
     return tuple((x - x % multiple_of) for x in args)
@@ -318,26 +410,6 @@ def latents_from_embeddings(
             # latents = noise * self.scheduler.init_noise_sigma  # it's like in t2l according to diffusers
             latents = self.scheduler.add_noise(latents, noise, batched_t)
 
-            if mask is not None:
-                if is_inpainting_model(self.unet):
-                    if masked_latents is None:
-                        raise Exception("Source image required for inpaint mask when inpaint model used!")
-
-                    self.invokeai_diffuser.model_forward_callback = AddsMaskLatents(
-                        self._unet_forward, mask, masked_latents
-                    )
-                else:
-                    # if no noise provided, noisify unmasked area based on seed
-                    if noise is None:
-                        noise = torch.randn(
-                            orig_latents.shape,
-                            dtype=torch.float32,
-                            device="cpu",
-                            generator=torch.Generator(device="cpu").manual_seed(seed),
-                        ).to(device=orig_latents.device, dtype=orig_latents.dtype)
-
-                    additional_guidance.append(AddsMaskGuidance(mask, orig_latents, self.scheduler, noise, gradient_mask))
-
         try:
             latents = self.generate_latents_from_embeddings(
                 latents,
@@ -355,13 +427,8 @@
 
         # restore unmasked part after the last step is completed
         # in-process masking happens before each step
-        if mask is not None:
-            if gradient_mask:
-                latents = torch.where(mask > 0, latents, orig_latents)
-            else:
-                latents = torch.lerp(
-                    orig_latents, latents.to(dtype=orig_latents.dtype), mask.to(dtype=orig_latents.dtype)
-                )
+        for guidance in additional_guidance:
+            latents = guidance.modify_latents_after_denoising(latents)
 
         return latents
 
@@ -426,6 +493,10 @@ def generate_latents_from_embeddings(
                 ip_adapter_data=ip_adapter_data,
                 t2i_adapter_data=t2i_adapter_data,
             )
+
+            for guidance in additional_guidance:  # fix preview images to show the original image in the unmasked region
+                step_output = guidance.modify_result_before_callback(step_output, t)
+
             latents = step_output.prev_sample
 
             predicted_original = getattr(step_output, "pred_original_sample", None)
@@ -462,9 +533,8 @@ def step(
         if additional_guidance is None:
             additional_guidance = []
 
-        # one day we will expand this extension point, but for now it just does denoise masking
-        for guidance in additional_guidance:
-            latents = guidance(latents, timestep)
+        for guidance in additional_guidance:  # apply denoise mask based on unscaled input latents
+            latents = guidance.modify_latents_before_scaling(latents, timestep)
 
         # TODO: should this scaling happen here or inside self._unet_forward?
         #     i.e. before or after passing it to InvokeAIDiffuserComponent
@@ -512,6 +582,9 @@
 
             down_intrablock_additional_residuals = accum_adapter_state
 
+        for guidance in additional_guidance:  # add mask channels for inpaint models
+            latent_model_input = guidance.modify_latents_before_noise_prediction(latent_model_input, timestep)
+
         uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
             sample=latent_model_input,
             timestep=t,  # TODO: debug how handled batched and non batched timesteps
@@ -538,19 +611,7 @@
         )
 
         # compute the previous noisy sample x_t -> x_t-1
-        step_output = self.scheduler.step(noise_pred, timestep, latents, **scheduler_step_kwargs)
-
-        # TODO: discuss injection point options. For now this is a patch to get progress images working with inpainting again.
-        for guidance in additional_guidance:
-            # apply the mask to any "denoised" or "pred_original_sample" fields
-            if hasattr(step_output, "denoised"):
-                step_output.pred_original_sample = guidance(step_output.denoised, self.scheduler.timesteps[-1])
-            elif hasattr(step_output, "pred_original_sample"):
-                step_output.pred_original_sample = guidance(
-                    step_output.pred_original_sample, self.scheduler.timesteps[-1]
-                )
-            else:
-                step_output.pred_original_sample = guidance(latents, self.scheduler.timesteps[-1])
+        step_output = self.scheduler.step(noise_pred, timestep, latents, **conditioning_data.scheduler_args)
 
         return step_output
 
@@ -573,17 +634,6 @@ def _unet_forward(
         **kwargs,
     ):
         """predict the noise residual"""
-        if is_inpainting_model(self.unet) and latents.size(1) == 4:
-            # Pad out normal non-inpainting inputs for an inpainting model.
-            # FIXME: There are too many layers of functions and we have too many different ways of
-            # overriding things! This should get handled in a way more consistent with the other
-            # use of AddsMaskLatents.
-            latents = AddsMaskLatents(
-                self._unet_forward,
-                mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
-                initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype),
-            ).add_mask_channels(latents)
-
         # First three args should be positional, not keywords, so torch hooks can see them.
         return self.unet(
             latents,
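Taken together, the pipeline changes route every guidance object through four injection points per run. The following schematic (simplified: CFG, adapter residuals, and batching are omitted; names are local to this sketch, not the real method signatures) shows the order in which the hooks fire:

```python
def denoise(latents, timesteps, scheduler, unet, additional_guidance):
    for t in timesteps:
        for g in additional_guidance:  # 1. apply denoise mask on unscaled latents
            latents = g.modify_latents_before_scaling(latents, t)
        model_input = scheduler.scale_model_input(latents, t)
        for g in additional_guidance:  # 2. add mask channels for inpaint models
            model_input = g.modify_latents_before_noise_prediction(model_input, t)
        noise_pred = unet(model_input, t)
        step_output = scheduler.step(noise_pred, t, latents)
        for g in additional_guidance:  # 3. fix pred_original_sample for previews
            step_output = g.modify_result_before_callback(step_output, t)
        latents = step_output.prev_sample
    for g in additional_guidance:  # 4. restore the unmasked region in the final result
        latents = g.modify_latents_after_denoising(latents)
    return latents
```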