# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import torch

from ..utils import get_logger


if TYPE_CHECKING:
    from ..models.attention_processor import AttentionProcessor


logger = get_logger(__name__)  # pylint: disable=invalid-name


class BaseGuidance:
    r"""Base class providing the skeleton for implementing guidance techniques."""

    _input_predictions = None

    def __init__(self, start: float = 0.0, stop: float = 1.0):
        self._start = start
        self._stop = stop
        self._step: Optional[int] = None
        self._num_inference_steps: Optional[int] = None
        self._timestep: Optional[torch.LongTensor] = None
        self._preds: Dict[str, torch.Tensor] = {}
        self._num_outputs_prepared: int = 0

        if not (0.0 <= start < 1.0):
            raise ValueError(f"Expected `start` to be between 0.0 and 1.0, but got {start}.")
        if not (start <= stop <= 1.0):
            raise ValueError(f"Expected `stop` to be between {start} and 1.0, but got {stop}.")

        if self._input_predictions is None or not isinstance(self._input_predictions, list):
            raise ValueError(
                "`_input_predictions` must be a list of required prediction names for the guidance technique."
            )

    def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
        self._step = step
        self._num_inference_steps = num_inference_steps
        self._timestep = timestep
        self._preds = {}
        self._num_outputs_prepared = 0

    def prepare_inputs(
        self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]
    ) -> Tuple[List[torch.Tensor], ...]:
        raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")

    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
        raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.")

    def __call__(self, **kwargs) -> Any:
        if len(kwargs) != self.num_conditions:
            raise ValueError(
                f"Expected {self.num_conditions} arguments, but got {len(kwargs)}. Please provide the correct number of arguments."
            )
        return self.forward(**kwargs)

    def forward(self, *args, **kwargs) -> Any:
        raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")

    @property
    def num_conditions(self) -> int:
        raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")

    @property
    def outputs(self) -> Dict[str, torch.Tensor]:
        return self._preds


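# A minimal, hypothetical subclass sketch (not part of this module's public API): it
# shows how `_input_predictions`, `num_conditions`, `prepare_inputs`, `prepare_outputs`,
# and `forward` fit together for plain classifier-free guidance with two conditions.
class _ExampleCFGSketch(BaseGuidance):
    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(self, guidance_scale: float = 7.5, start: float = 0.0, stop: float = 1.0):
        super().__init__(start, stop)
        # Assumed hyperparameter for this sketch; real guiders may expose more options.
        self.guidance_scale = guidance_scale

    @property
    def num_conditions(self) -> int:
        return 2

    def prepare_inputs(self, denoiser, *args):
        # Reuse the module-level helper defined below to duplicate/alternate inputs.
        return _default_prepare_inputs(denoiser, self.num_conditions, *args)

    def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
        # Store predictions under the names declared by `_input_predictions`, in order.
        key = self._input_predictions[self._num_outputs_prepared]
        self._preds[key] = pred
        self._num_outputs_prepared += 1

    def forward(self, pred_cond: torch.Tensor, pred_uncond: torch.Tensor) -> torch.Tensor:
        # Standard CFG combination: uncond + scale * (cond - uncond).
        return pred_uncond + self.guidance_scale * (pred_cond - pred_uncond)

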
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    r"""
    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
    Flawed](https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg (`torch.Tensor`):
            The predicted noise tensor for the guided diffusion process.
        noise_pred_text (`torch.Tensor`):
            The predicted noise tensor for the text-guided diffusion process.
        guidance_rescale (`float`, *optional*, defaults to 0.0):
            A rescale factor applied to the noise predictions.

    Returns:
        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
    """
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg


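# Illustrative usage sketch (assumed shapes and guidance scale, not from this module):
# combine conditional and unconditional predictions with classifier-free guidance,
# then rescale toward the std of the text prediction to correct overexposure.
def _example_rescale_usage() -> torch.Tensor:
    noise_pred_uncond = torch.randn(2, 4, 64, 64)  # unconditional noise prediction
    noise_pred_text = torch.randn(2, 4, 64, 64)  # text-conditioned noise prediction
    guidance_scale = 7.5  # assumed value for demonstration
    # Standard CFG combination of the two predictions
    noise_cfg = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    # A `guidance_rescale` of 0.0 is a no-op; the paper suggests values around 0.7
    return rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)

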
def _default_prepare_inputs(
    denoiser: torch.nn.Module, num_conditions: int, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]
) -> Tuple[List[torch.Tensor], ...]:
    """
    Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly
    prepared based on the required number of conditions. This function is used in the `prepare_inputs` method of the
    `BaseGuidance` class.

    Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two
    elements:
    - The first element is the conditional input.
    - The second element is the unconditional input or None.

    If only the conditional input is provided, it will be repeated for all conditions.

    If both conditional and unconditional inputs are provided, they are alternated as batches of data.
    """
    list_of_inputs = []
    for arg in args:
        if arg is None or isinstance(arg, torch.Tensor):
            list_of_inputs.append([arg] * num_conditions)
        elif isinstance(arg, (tuple, list)):
            if len(arg) != 2:
                raise ValueError(
                    f"Expected a tuple or list of length 2, but got {len(arg)} for argument {arg}. Please provide a tuple/list of length 2 "
                    f"with the first element being the conditional input and the second element being the unconditional input or None."
                )
            if arg[1] is None:
                # Only conditioning inputs for all batches
                list_of_inputs.append([arg[0]] * num_conditions)
            else:
                # Alternating conditional and unconditional inputs as batches
                inputs = [arg[i % 2] for i in range(num_conditions)]
                list_of_inputs.append(inputs)
        else:
            raise ValueError(
                f"Expected a tensor, tuple, or list, but got {type(arg)} for argument {arg}. Please provide a tensor, tuple, or list."
            )
    return tuple(list_of_inputs)
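

# Illustrative sketch (assumed tensors, not part of the module API): with two conditions,
# a (cond, uncond) pair is alternated across the condition axis, while a bare tensor is
# repeated for every condition. The denoiser argument is unused by the default helper.
def _example_prepare_inputs() -> None:
    cond = torch.ones(1, 4)
    uncond = torch.zeros(1, 4)
    latents = torch.randn(1, 4)
    latents_list, prompt_list = _default_prepare_inputs(None, 2, latents, (cond, uncond))
    assert all(t is latents for t in latents_list)  # bare tensor repeated per condition
    assert prompt_list[0] is cond and prompt_list[1] is uncond  # pair alternated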