Skip to content

Commit 77ca55d

Browse files
committed
Add Tiled Image Upscale (With Model)
1 parent 5401aba commit 77ca55d

File tree

2 files changed

+316
-1
lines changed

2 files changed

+316
-1
lines changed

nodes/TiledUpscaleModel.py

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
import torch
2+
3+
import comfy.utils
4+
from comfy import model_management
5+
6+
7+
@torch.inference_mode()
8+
def dynamic_tiled_upscale_with_custom_feather(
9+
samples,
10+
function,
11+
tile_size=512,
12+
overlap=32,
13+
output_device="cpu",
14+
pbar=None,
15+
feather=0,
16+
target_height=None,
17+
target_width=None,
18+
resample_method="lanczos",
19+
):
20+
if samples.ndim != 4:
21+
raise ValueError("Expected samples with shape [B, C, H, W].")
22+
23+
batch_size, channels, in_height, in_width = samples.shape
24+
25+
if target_height is None or target_width is None:
26+
raise ValueError("target_height and target_width must be provided.")
27+
28+
tile_size = int(tile_size)
29+
if tile_size <= 0:
30+
raise ValueError("tile_size must be positive.")
31+
32+
overlap = max(0, int(overlap))
33+
if overlap >= tile_size:
34+
overlap = tile_size - 1 if tile_size > 1 else 0
35+
36+
tile_step = tile_size - overlap if tile_size > overlap else tile_size
37+
38+
scale_y_global = float(target_height) / float(in_height)
39+
scale_x_global = float(target_width) / float(in_width)
40+
41+
blended_output = None
42+
43+
for batch_index in range(batch_size):
44+
source = samples[batch_index : batch_index + 1]
45+
46+
output_for_batch = None
47+
weight_for_batch = None
48+
49+
y_position = 0
50+
while y_position < in_height:
51+
x_position = 0
52+
while x_position < in_width:
53+
y_end = min(y_position + tile_size, in_height)
54+
x_end = min(x_position + tile_size, in_width)
55+
56+
tile_source = source[:, :, y_position:y_end, x_position:x_end]
57+
tile_output_native = function(tile_source).to(output_device)
58+
59+
if output_for_batch is None:
60+
out_channels = tile_output_native.shape[1]
61+
output_for_batch = torch.zeros(
62+
(1, out_channels, target_height, target_width),
63+
device=output_device,
64+
dtype=tile_output_native.dtype,
65+
)
66+
weight_for_batch = torch.zeros_like(output_for_batch)
67+
68+
if blended_output is None:
69+
blended_output = torch.zeros(
70+
(batch_size, out_channels, target_height, target_width),
71+
device=output_device,
72+
dtype=tile_output_native.dtype,
73+
)
74+
75+
out_y_start = int(round(y_position * target_height / in_height))
76+
out_y_end = int(round(y_end * target_height / in_height))
77+
out_x_start = int(round(x_position * target_width / in_width))
78+
out_x_end = int(round(x_end * target_width / in_width))
79+
80+
tile_target_height = max(1, out_y_end - out_y_start)
81+
tile_target_width = max(1, out_x_end - out_x_start)
82+
83+
if (
84+
tile_output_native.shape[2] != tile_target_height
85+
or tile_output_native.shape[3] != tile_target_width
86+
):
87+
tile_output = comfy.utils.common_upscale(
88+
tile_output_native,
89+
tile_target_width,
90+
tile_target_height,
91+
resample_method,
92+
"disabled",
93+
)
94+
else:
95+
tile_output = tile_output_native
96+
97+
mask = torch.ones_like(tile_output)
98+
99+
if feather is None or feather <= 0:
100+
feather_pixels_y = int(round(overlap * scale_y_global))
101+
feather_pixels_x = int(round(overlap * scale_x_global))
102+
else:
103+
feather_pixels_y = int(feather)
104+
feather_pixels_x = int(feather)
105+
106+
if feather_pixels_y > 0:
107+
max_vertical = tile_output.shape[2] // 2
108+
feather_pixels_y = min(feather_pixels_y, max_vertical)
109+
for t in range(feather_pixels_y):
110+
weight_value = float(t + 1) / float(feather_pixels_y)
111+
row_start = t
112+
row_end = t + 1
113+
inv_row_start = tile_output.shape[2] - 1 - t
114+
inv_row_end = tile_output.shape[2] - t
115+
mask[:, :, row_start:row_end, :].mul_(weight_value)
116+
mask[:, :, inv_row_start:inv_row_end, :].mul_(weight_value)
117+
118+
if feather_pixels_x > 0:
119+
max_horizontal = tile_output.shape[3] // 2
120+
feather_pixels_x = min(feather_pixels_x, max_horizontal)
121+
for t in range(feather_pixels_x):
122+
weight_value = float(t + 1) / float(feather_pixels_x)
123+
col_start = t
124+
col_end = t + 1
125+
inv_col_start = tile_output.shape[3] - 1 - t
126+
inv_col_end = tile_output.shape[3] - t
127+
mask[:, :, :, col_start:col_end].mul_(weight_value)
128+
mask[:, :, :, inv_col_start:inv_col_end].mul_(weight_value)
129+
130+
out_y_end = out_y_start + tile_output.shape[2]
131+
out_x_end = out_x_start + tile_output.shape[3]
132+
133+
output_for_batch[:, :, out_y_start:out_y_end, out_x_start:out_x_end] += (
134+
tile_output * mask
135+
)
136+
weight_for_batch[:, :, out_y_start:out_y_end, out_x_start:out_x_end] += mask
137+
138+
if pbar is not None:
139+
pbar.update(1)
140+
141+
x_position += tile_step
142+
y_position += tile_step
143+
144+
weight_for_batch = torch.where(
145+
weight_for_batch == 0.0,
146+
torch.ones_like(weight_for_batch),
147+
weight_for_batch,
148+
)
149+
output_for_batch = output_for_batch / weight_for_batch
150+
151+
blended_output[batch_index : batch_index + 1] = output_for_batch
152+
153+
return blended_output
154+
155+
156+
class WASTiledImageUpscaleWithModel:
157+
@classmethod
158+
def INPUT_TYPES(cls):
159+
return {
160+
"required": {
161+
"upscale_model": ("UPSCALE_MODEL", {}),
162+
"image": ("IMAGE", {}),
163+
"upscale_factor": (
164+
"FLOAT",
165+
{
166+
"default": 4.0,
167+
"min": 1.0,
168+
"max": 16.0,
169+
"step": 0.1,
170+
"tooltip": "Final scale relative to input image size. Output resolution ~= input * upscale_factor.",
171+
},
172+
),
173+
"tile_size": (
174+
"INT",
175+
{
176+
"default": 512,
177+
"min": 64,
178+
"max": 4096,
179+
"step": 16,
180+
"tooltip": "Tile size in input pixels. Larger tiles are faster but use more VRAM.",
181+
},
182+
),
183+
"overlap": (
184+
"INT",
185+
{
186+
"default": 32,
187+
"min": 0,
188+
"max": 1024,
189+
"step": 1,
190+
"tooltip": "Tile overlap in input pixels. Higher overlap reduces seams but increases compute.",
191+
},
192+
),
193+
"feather": (
194+
"INT",
195+
{
196+
"default": 0,
197+
"min": 0,
198+
"max": 4096,
199+
"step": 1,
200+
"tooltip": "Feather width in output pixels for tile blending. 0 = auto from overlap.",
201+
},
202+
),
203+
"resample_method": (
204+
[
205+
"nearest-exact",
206+
"bilinear",
207+
"area",
208+
"bicubic",
209+
"lanczos",
210+
],
211+
{
212+
"default": "lanczos",
213+
"tooltip": "Resampling kernel used to reach the final upscale_factor resolution.",
214+
},
215+
),
216+
}
217+
}
218+
219+
RETURN_TYPES = ("IMAGE",)
220+
FUNCTION = "upscale"
221+
CATEGORY = "image/upscaling"
222+
223+
def upscale(
224+
self,
225+
upscale_model,
226+
image,
227+
upscale_factor,
228+
tile_size,
229+
overlap,
230+
feather,
231+
resample_method,
232+
):
233+
device = model_management.get_torch_device()
234+
235+
scale_estimate = getattr(upscale_model, "scale", 4.0)
236+
element_size = image.element_size()
237+
238+
memory_required = model_management.module_size(upscale_model.model)
239+
memory_required += (
240+
tile_size * tile_size * 3
241+
) * element_size * max(scale_estimate, 1.0) * 384.0
242+
memory_required += image.nelement() * element_size
243+
244+
model_management.free_memory(memory_required, device)
245+
246+
upscale_model.to(device)
247+
248+
batch_size, in_h, in_w, _ = image.shape
249+
250+
upscale_factor = float(upscale_factor)
251+
if upscale_factor < 1.0:
252+
upscale_factor = 1.0
253+
254+
target_height = max(1, int(round(in_h * upscale_factor)))
255+
target_width = max(1, int(round(in_w * upscale_factor)))
256+
257+
input_image = image.movedim(-1, -3).to(device)
258+
259+
current_tile_size = int(tile_size)
260+
minimum_tile_size = 64
261+
262+
upscale_result = None
263+
output_device = device
264+
265+
oom = True
266+
last_exception = None
267+
268+
while oom:
269+
try:
270+
steps = input_image.shape[0] * comfy.utils.get_tiled_scale_steps(
271+
input_image.shape[3],
272+
input_image.shape[2],
273+
tile_x=current_tile_size,
274+
tile_y=current_tile_size,
275+
overlap=overlap,
276+
)
277+
progress = comfy.utils.ProgressBar(steps)
278+
279+
upscale_result = dynamic_tiled_upscale_with_custom_feather(
280+
samples=input_image,
281+
function=lambda a: upscale_model(a),
282+
tile_size=current_tile_size,
283+
overlap=overlap,
284+
output_device=output_device,
285+
pbar=progress,
286+
feather=feather,
287+
target_height=target_height,
288+
target_width=target_width,
289+
resample_method=resample_method,
290+
)
291+
292+
oom = False
293+
except model_management.OOM_EXCEPTION as exception:
294+
last_exception = exception
295+
current_tile_size //= 2
296+
if current_tile_size < minimum_tile_size:
297+
upscale_model.to("cpu")
298+
raise last_exception
299+
300+
upscale_model.to("cpu")
301+
302+
upscale_result = torch.clamp(
303+
upscale_result.movedim(-3, -1), min=0.0, max=1.0
304+
)
305+
306+
return (upscale_result,)
307+
308+
309+
NODE_CLASS_MAPPINGS = {
310+
"WASTiledImageUpscaleWithModel": WASTiledImageUpscaleWithModel,
311+
}
312+
313+
NODE_DISPLAY_NAME_MAPPINGS = {
314+
"WASTiledImageUpscaleWithModel": "Tiled Image Upscale (With Model)",
315+
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "was-extras"
3-
version = "1.1.0"
3+
version = "1.2.0"
44
description = "A collection of experimental WAS nodes and utilities for ComfyUI."
55
readme = "README.md"
66
requires-python = ">=3.10"

0 commit comments

Comments
 (0)