modules/impact/core.py — 18 changes: 17 additions & 1 deletion
@@ -255,7 +255,7 @@ def enhance_detail(image, model, clip, vae, guide_size, guide_size_for_bbox, max
refiner_ratio=None, refiner_model=None, refiner_clip=None, refiner_positive=None,
refiner_negative=None, control_net_wrapper=None, cycle=1,
inpaint_model=False, noise_mask_feather=0, scheduler_func=None,
-vae_tiled_encode=False, vae_tiled_decode=False):
+vae_tiled_encode=False, vae_tiled_decode=False, return_by_cycle_step=False):

if noise_mask is not None:
noise_mask = utils.tensor_gaussian_blur_mask(noise_mask, noise_mask_feather)
@@ -277,6 +277,9 @@ def enhance_detail(image, model, clip, vae, guide_size, guide_size_for_bbox, max
elif 'pooled_output' in positive[0][1]:
del positive[0][1]['pooled_output']

+# collect per-cycle intermediate results when return_by_cycle_step is enabled
+refined_latents_by_step: list[torch.Tensor] = []
+refined_images_by_step: list[torch.Tensor] = []

h = image.shape[1]
w = image.shape[2]

@@ -363,6 +366,7 @@ def enhance_detail(image, model, clip, vae, guide_size, guide_size_for_bbox, max
sampler_opt = detailer_hook.get_custom_sampler()

# ksampler
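+# run the sampler for `cycle` passes; when return_by_cycle_step is enabled, each pass's latent is snapshotted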

for i in range(0, cycle):
if detailer_hook is not None:
if detailer_hook is not None:
@@ -384,42 +388,54 @@ def enhance_detail(image, model, clip, vae, guide_size, guide_size_for_bbox, max
refined_latent, denoise2, refiner_ratio, refiner_model, refiner_clip, refiner_positive, refiner_negative,
noise=noise, scheduler_func=scheduler_func, sampler_opt=sampler_opt)

+if return_by_cycle_step: refined_latents_by_step.append(refined_latent)  # snapshot this cycle's latent

if detailer_hook is not None:
refined_latent = detailer_hook.pre_decode(refined_latent)
+if return_by_cycle_step: refined_latents_by_step = [detailer_hook.pre_decode(latent) for latent in refined_latents_by_step]

# non-latent downscale - latent downscale causes bad quality
start = time.time()
if vae_tiled_decode:
(refined_image,) = nodes.VAEDecodeTiled().decode(vae, refined_latent, 512) # using default settings
+if return_by_cycle_step: refined_images_by_step = [nodes.VAEDecodeTiled().decode(vae, refined_latent, 512)[0] for refined_latent in refined_latents_by_step]  # decode() returns a 1-tuple
logging.info(f"[Impact Pack] vae decoded (tiled) in {time.time() - start:.1f}s")
else:
try:
refined_image = vae.decode(refined_latent['samples'])
+if return_by_cycle_step: refined_images_by_step = [vae.decode(refined_latent['samples']) for refined_latent in refined_latents_by_step]
except Exception:
# usually an out-of-memory exception from the decode, so try a tiled approach
logging.warning(f"[Impact Pack] failed after {time.time() - start:.1f}s, doing vae.decode_tiled 64...")
refined_image = vae.decode_tiled(refined_latent["samples"], tile_x=64, tile_y=64, )
+if return_by_cycle_step: refined_images_by_step = [vae.decode_tiled(refined_latent["samples"], tile_x=64, tile_y=64) for refined_latent in refined_latents_by_step]
logging.info(f"[Impact Pack] vae decoded in {time.time() - start:.1f}s")
else:
# skipped
refined_image = upscaled_image
+if return_by_cycle_step: refined_images_by_step = [upscaled_image]

if detailer_hook is not None:
refined_image = detailer_hook.post_decode(refined_image)
+if return_by_cycle_step: refined_images_by_step = [detailer_hook.post_decode(refined_image) for refined_image in refined_images_by_step]

# downscale

# workaround: support WAN as an i2i model
if len(refined_image.shape) == 5:
refined_image = refined_image.squeeze(0)
+if return_by_cycle_step: refined_images_by_step = [i.squeeze(0) for i in refined_images_by_step]

refined_image = utils.tensor_resize(refined_image, w, h)
+if return_by_cycle_step: refined_images_by_step = [utils.tensor_resize(i, w, h) for i in refined_images_by_step]

# prevent mixing of device
refined_image = refined_image.cpu()
+if return_by_cycle_step: refined_images_by_step = [i.cpu() for i in refined_images_by_step]

# don't convert to latent - latent conversion breaks the image
# preserving PIL is much better
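+# when return_by_cycle_step is enabled, additionally return the decoded image from every cycle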
+if return_by_cycle_step: return refined_image, cnet_pils, refined_images_by_step
return refined_image, cnet_pils
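
For reference, a minimal caller-side sketch (not part of this PR) of consuming the new flag, assuming the module is imported as core and that prepared_args is only a placeholder for the existing positional arguments of enhance_detail, which this diff does not change:

# hypothetical caller; `prepared_args` stands in for the usual
# enhance_detail arguments (image, model, clip, vae, guide_size, ...)
result = core.enhance_detail(*prepared_args, cycle=4, return_by_cycle_step=True)

if len(result) == 3:
    # new behavior: the third element holds one decoded image tensor per cycle
    refined_image, cnet_pils, refined_images_by_step = result
else:
    # with the default return_by_cycle_step=False, the original 2-tuple is kept
    refined_image, cnet_pils = result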

