Skip to content

Commit 9da8a9d

Browse files
committed
support sdxl controlnet
1 parent 0c4c1a8 commit 9da8a9d

File tree

4 files changed

+119
-120
lines changed

4 files changed

+119
-120
lines changed

src/diffusers/guiders/classifier_free_guidance.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,10 @@ def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] =
9292

9393
return pred
9494

95+
@property
96+
def is_conditional(self) -> bool:
97+
return self._num_outputs_prepared == 0
98+
9599
@property
96100
def num_conditions(self) -> int:
97101
num_conditions = 1
@@ -100,6 +104,8 @@ def num_conditions(self) -> int:
100104
return num_conditions
101105

102106
def _is_cfg_enabled(self) -> bool:
107+
if not self._enabled:
108+
return False
103109
skip_start_step = int(self._start * self._num_inference_steps)
104110
skip_stop_step = int(self._stop * self._num_inference_steps)
105111
is_within_range = skip_start_step <= self._step < skip_stop_step

src/diffusers/guiders/guider_utils.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(self, start: float = 0.0, stop: float = 1.0):
3939
self._timestep: torch.LongTensor = None
4040
self._preds: Dict[str, torch.Tensor] = {}
4141
self._num_outputs_prepared: int = 0
42+
self._enabled = True
4243

4344
if not (0.0 <= start < 1.0):
4445
raise ValueError(
@@ -54,6 +55,12 @@ def __init__(self, start: float = 0.0, stop: float = 1.0):
5455
"`_input_predictions` must be a list of required prediction names for the guidance technique."
5556
)
5657

58+
def force_disable(self):
59+
self._enabled = False
60+
61+
def force_enable(self):
62+
self._enabled = True
63+
5764
def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTensor) -> None:
5865
self._step = step
5966
self._num_inference_steps = num_inference_steps
@@ -62,10 +69,10 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen
6269
self._num_outputs_prepared = 0
6370

6471
def prepare_inputs(self, denoiser: torch.nn.Module, *args: Union[Tuple[torch.Tensor], List[torch.Tensor]]) -> Tuple[List[torch.Tensor], ...]:
65-
raise NotImplementedError("GuidanceMixin::prepare_inputs must be implemented in subclasses.")
72+
raise NotImplementedError("BaseGuidance::prepare_inputs must be implemented in subclasses.")
6673

6774
def prepare_outputs(self, denoiser: torch.nn.Module, pred: torch.Tensor) -> None:
68-
raise NotImplementedError("GuidanceMixin::prepare_outputs must be implemented in subclasses.")
75+
raise NotImplementedError("BaseGuidance::prepare_outputs must be implemented in subclasses.")
6976

7077
def __call__(self, **kwargs) -> Any:
7178
if len(kwargs) != self.num_conditions:
@@ -75,11 +82,19 @@ def __call__(self, **kwargs) -> Any:
7582
return self.forward(**kwargs)
7683

7784
def forward(self, *args, **kwargs) -> Any:
78-
raise NotImplementedError("GuidanceMixin::forward must be implemented in subclasses.")
85+
raise NotImplementedError("BaseGuidance::forward must be implemented in subclasses.")
7986

87+
@property
88+
def is_conditional(self) -> bool:
89+
raise NotImplementedError("BaseGuidance::is_conditional must be implemented in subclasses.")
90+
91+
@property
92+
def is_unconditional(self) -> bool:
93+
return not self.is_conditional
94+
8095
@property
8196
def num_conditions(self) -> int:
82-
raise NotImplementedError("GuidanceMixin::num_conditions must be implemented in subclasses.")
97+
raise NotImplementedError("BaseGuidance::num_conditions must be implemented in subclasses.")
8398

8499
@property
85100
def outputs(self) -> Dict[str, torch.Tensor]:
@@ -114,7 +129,7 @@ def _default_prepare_inputs(denoiser: torch.nn.Module, num_conditions: int, *arg
114129
"""
115130
Prepares the inputs for the denoiser by ensuring that the conditional and unconditional inputs are correctly
116131
prepared based on required number of conditions. This function is used in the `prepare_inputs` method of the
117-
`GuidanceMixin` class.
132+
`BaseGuidance` class.
118133
119134
Either tensors or tuples/lists of tensors can be provided. If a tuple/list is provided, it should contain two elements:
120135
- The first element is the conditional input.

src/diffusers/guiders/skip_layer_guidance.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,10 @@ def forward(
189189
pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)
190190

191191
return pred
192+
193+
@property
194+
def is_conditional(self) -> bool:
195+
return self._num_outputs_prepared == 0 or self._num_outputs_prepared == 2
192196

193197
@property
194198
def num_conditions(self) -> int:
@@ -200,6 +204,8 @@ def num_conditions(self) -> int:
200204
return num_conditions
201205

202206
def _is_cfg_enabled(self) -> bool:
207+
if not self._enabled:
208+
return False
203209
skip_start_step = int(self._start * self._num_inference_steps)
204210
skip_stop_step = int(self._stop * self._num_inference_steps)
205211
is_within_range = skip_start_step <= self._step < skip_stop_step
@@ -211,6 +217,8 @@ def _is_cfg_enabled(self) -> bool:
211217
return is_within_range and not is_close
212218

213219
def _is_slg_enabled(self) -> bool:
220+
if not self._enabled:
221+
return False
214222
skip_start_step = int(self._start * self._num_inference_steps)
215223
skip_stop_step = int(self._stop * self._num_inference_steps)
216224
is_within_range = skip_start_step < self._step < skip_stop_step

0 commit comments

Comments
 (0)