Skip to content

Commit 8ae0ea9

Browse files
authored
Add callback to sd_samplers
1 parent 8906be8 commit 8ae0ea9

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

modules/sd_samplers.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from modules.shared import opts, cmd_opts, state
1313
import modules.shared as shared
14+
from modules.script_callbacks import CGFDenoiserParams, cfg_denoiser_callback
1415

1516

1617
SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])
@@ -278,6 +279,8 @@ def forward(self, x, sigma, uncond, cond, cond_scale, image_cond):
278279
image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_cond])
279280
sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
280281

282+
cfg_denoiser_callback(CGFDenoiserParams(x_in, image_cond_in, sigma_in, state.sampling_step, state.sampling_steps))
283+
281284
if tensor.shape[1] == uncond.shape[1]:
282285
cond_in = torch.cat([tensor, uncond])
283286

0 commit comments

Comments
 (0)