@@ -39,6 +39,7 @@ def __init__(self, start: float = 0.0, stop: float = 1.0):
3939 self ._timestep : torch .LongTensor = None
4040 self ._preds : Dict [str , torch .Tensor ] = {}
4141 self ._num_outputs_prepared : int = 0
42+ self ._enabled = True
4243
4344 if not (0.0 <= start < 1.0 ):
4445 raise ValueError (
@@ -54,6 +55,12 @@ def __init__(self, start: float = 0.0, stop: float = 1.0):
5455 "`_input_predictions` must be a list of required prediction names for the guidance technique."
5556 )
5657
58+ def force_disable (self ):
59+ self ._enabled = False
60+
61+ def force_enable (self ):
62+ self ._enabled = True
63+
5764 def set_state (self , step : int , num_inference_steps : int , timestep : torch .LongTensor ) -> None :
5865 self ._step = step
5966 self ._num_inference_steps = num_inference_steps
@@ -62,10 +69,10 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen
6269 self ._num_outputs_prepared = 0
6370
6471 def prepare_inputs (self , denoiser : torch .nn .Module , * args : Union [Tuple [torch .Tensor ], List [torch .Tensor ]]) -> Tuple [List [torch .Tensor ], ...]:
65- raise NotImplementedError ("GuidanceMixin ::prepare_inputs must be implemented in subclasses." )
72+ raise NotImplementedError ("BaseGuidance ::prepare_inputs must be implemented in subclasses." )
6673
6774 def prepare_outputs (self , denoiser : torch .nn .Module , pred : torch .Tensor ) -> None :
68- raise NotImplementedError ("GuidanceMixin ::prepare_outputs must be implemented in subclasses." )
75+ raise NotImplementedError ("BaseGuidance ::prepare_outputs must be implemented in subclasses." )
6976
7077 def __call__ (self , ** kwargs ) -> Any :
7178 if len (kwargs ) != self .num_conditions :
@@ -75,11 +82,19 @@ def __call__(self, **kwargs) -> Any:
7582 return self .forward (** kwargs )
7683
7784 def forward (self , * args , ** kwargs ) -> Any :
78- raise NotImplementedError ("GuidanceMixin ::forward must be implemented in subclasses." )
85+ raise NotImplementedError ("BaseGuidance ::forward must be implemented in subclasses." )
7986
87+ @property
88+ def is_conditional (self ) -> bool :
89+ raise NotImplementedError ("BaseGuidance::is_conditional must be implemented in subclasses." )
90+
91+ @property
92+ def is_unconditional (self ) -> bool :
93+ return not self .is_conditional
94+
8095 @property
8196 def num_conditions (self ) -> int :
82- raise NotImplementedError ("GuidanceMixin ::num_conditions must be implemented in subclasses." )
97+ raise NotImplementedError ("BaseGuidance ::num_conditions must be implemented in subclasses." )
8398
8499 @property
85100 def outputs (self ) -> Dict [str , torch .Tensor ]:
@@ -114,7 +129,7 @@ def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *arg
114129 """
115130 Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly
116131 prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the
117- `GuidanceMixin ` class.
132+ `BaseGuidance ` class.
118133
119134 Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements:
120135 - The first element is the conditional input.
0 commit comments