import torch

import comfy.utils
from comfy import model_management


@torch.inference_mode()
def dynamic_tiled_upscale_with_custom_feather(
    samples,
    function,
    tile_size=512,
    overlap=32,
    output_device="cpu",
    pbar=None,
    feather=0,
    target_height=None,
    target_width=None,
    resample_method="lanczos",
):
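    """Upscale ``samples`` tile-by-tile with ``function`` and blend the tiles.

    The input is scanned in ``tile_size`` windows advancing by
    ``tile_size - overlap``. Each tile is run through ``function``, resized to
    its destination rectangle on the ``target_height`` x ``target_width``
    canvas if needed, and accumulated under a mask whose edges ramp linearly
    over the feather width. Dividing the accumulation by the summed mask
    weights normalizes overlapping regions into a seamless crossfade.

    ``feather`` is in output pixels; 0 derives the feather width from
    ``overlap`` scaled to output resolution. Returns a tensor of shape
    [B, C, target_height, target_width] on ``output_device``.
    """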
    if samples.ndim != 4:
        raise ValueError("Expected samples with shape [B, C, H, W].")

    batch_size, channels, in_height, in_width = samples.shape

    if target_height is None or target_width is None:
        raise ValueError("target_height and target_width must be provided.")

    tile_size = int(tile_size)
    if tile_size <= 0:
        raise ValueError("tile_size must be positive.")

    # Clamp overlap below tile_size so the scan always advances.
    overlap = max(0, int(overlap))
    if overlap >= tile_size:
        overlap = tile_size - 1 if tile_size > 1 else 0

    tile_step = tile_size - overlap if tile_size > overlap else tile_size

    scale_y_global = float(target_height) / float(in_height)
    scale_x_global = float(target_width) / float(in_width)

    blended_output = None
    for batch_index in range(batch_size):
        source = samples[batch_index : batch_index + 1]

        output_for_batch = None
        weight_for_batch = None

        y_position = 0
        while y_position < in_height:
            x_position = 0
            while x_position < in_width:
                y_end = min(y_position + tile_size, in_height)
                x_end = min(x_position + tile_size, in_width)

                # Run the model on one tile; edge tiles may be smaller
                # than tile_size.
                tile_source = source[:, :, y_position:y_end, x_position:x_end]
                tile_output_native = function(tile_source).to(output_device)

                # Allocate the accumulation buffers lazily, once the model's
                # output channel count is known.
                if output_for_batch is None:
                    out_channels = tile_output_native.shape[1]
                    output_for_batch = torch.zeros(
                        (1, out_channels, target_height, target_width),
                        device=output_device,
                        dtype=tile_output_native.dtype,
                    )
                    weight_for_batch = torch.zeros_like(output_for_batch)

                    if blended_output is None:
                        blended_output = torch.zeros(
                            (batch_size, out_channels, target_height, target_width),
                            device=output_device,
                            dtype=tile_output_native.dtype,
                        )

                # Map the tile's input rectangle into output coordinates so
                # the canvas is exactly target_height x target_width even for
                # non-integer scale factors.
                out_y_start = int(round(y_position * target_height / in_height))
                out_y_end = int(round(y_end * target_height / in_height))
                out_x_start = int(round(x_position * target_width / in_width))
                out_x_end = int(round(x_end * target_width / in_width))

                tile_target_height = max(1, out_y_end - out_y_start)
                tile_target_width = max(1, out_x_end - out_x_start)

                # Resize the model output to its destination rectangle when
                # the model's native scale does not land on it exactly.
                if (
                    tile_output_native.shape[2] != tile_target_height
                    or tile_output_native.shape[3] != tile_target_width
                ):
                    tile_output = comfy.utils.common_upscale(
                        tile_output_native,
                        tile_target_width,
                        tile_target_height,
                        resample_method,
                        "disabled",
                    )
                else:
                    tile_output = tile_output_native

                mask = torch.ones_like(tile_output)

                # feather <= 0 means auto: derive the feather width from the
                # overlap, scaled into output pixels.
                if feather is None or feather <= 0:
                    feather_pixels_y = int(round(overlap * scale_y_global))
                    feather_pixels_x = int(round(overlap * scale_x_global))
                else:
                    feather_pixels_y = int(feather)
                    feather_pixels_x = int(feather)

                # Ramp the mask linearly from 1/feather to 1 across each edge,
                # clamped to half the tile so opposite ramps never overlap.
                if feather_pixels_y > 0:
                    feather_pixels_y = min(feather_pixels_y, tile_output.shape[2] // 2)
                    for t in range(feather_pixels_y):
                        weight_value = float(t + 1) / float(feather_pixels_y)
                        mask[:, :, t, :].mul_(weight_value)
                        mask[:, :, tile_output.shape[2] - 1 - t, :].mul_(weight_value)

                if feather_pixels_x > 0:
                    feather_pixels_x = min(feather_pixels_x, tile_output.shape[3] // 2)
                    for t in range(feather_pixels_x):
                        weight_value = float(t + 1) / float(feather_pixels_x)
                        mask[:, :, :, t].mul_(weight_value)
                        mask[:, :, :, tile_output.shape[3] - 1 - t].mul_(weight_value)

                out_y_end = out_y_start + tile_output.shape[2]
                out_x_end = out_x_start + tile_output.shape[3]

                output_for_batch[:, :, out_y_start:out_y_end, out_x_start:out_x_end] += (
                    tile_output * mask
                )
                weight_for_batch[:, :, out_y_start:out_y_end, out_x_start:out_x_end] += mask

                if pbar is not None:
                    pbar.update(1)

                # Stop once a tile has touched the far edge; advancing again
                # would re-process a sliver already covered by this tile and
                # overshoot the progress total from get_tiled_scale_steps.
                if x_end >= in_width:
                    break
                x_position += tile_step

            if y_end >= in_height:
                break
            y_position += tile_step

        # Guard against division by zero anywhere the mask never landed, then
        # normalize the accumulation by the summed mask weights.
        weight_for_batch = torch.where(
            weight_for_batch == 0.0,
            torch.ones_like(weight_for_batch),
            weight_for_batch,
        )
        output_for_batch = output_for_batch / weight_for_batch

        blended_output[batch_index : batch_index + 1] = output_for_batch

    return blended_output


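# Worked illustration of the blend normalization (assumed numbers, for
# intuition only): with a feather width of 4, the second row of an overlap
# gets weight (1 + 1) / 4 = 0.5 from the entering tile's top ramp and
# (4 - 1) / 4 = 0.75 from the previous tile's bottom ramp. The accumulators
# then hold 0.5 * b + 0.75 * a against a summed weight of 1.25, so the final
# division produces a 40/60 crossfade that shifts smoothly across the seam.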
class WASTiledImageUpscaleWithModel:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "upscale_model": ("UPSCALE_MODEL", {}),
                "image": ("IMAGE", {}),
                "upscale_factor": (
                    "FLOAT",
                    {
                        "default": 4.0,
                        "min": 1.0,
                        "max": 16.0,
                        "step": 0.1,
                        "tooltip": "Final scale relative to input image size. Output resolution ~= input * upscale_factor.",
                    },
                ),
                "tile_size": (
                    "INT",
                    {
                        "default": 512,
                        "min": 64,
                        "max": 4096,
                        "step": 16,
                        "tooltip": "Tile size in input pixels. Larger tiles are faster but use more VRAM.",
                    },
                ),
                "overlap": (
                    "INT",
                    {
                        "default": 32,
                        "min": 0,
                        "max": 1024,
                        "step": 1,
                        "tooltip": "Tile overlap in input pixels. Higher overlap reduces seams but increases compute.",
                    },
                ),
                "feather": (
                    "INT",
                    {
                        "default": 0,
                        "min": 0,
                        "max": 4096,
                        "step": 1,
                        "tooltip": "Feather width in output pixels for tile blending. 0 = auto from overlap.",
                    },
                ),
                "resample_method": (
                    [
                        "nearest-exact",
                        "bilinear",
                        "area",
                        "bicubic",
                        "lanczos",
                    ],
                    {
                        "default": "lanczos",
                        "tooltip": "Resampling kernel used to reach the final upscale_factor resolution.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "upscale"
    CATEGORY = "image/upscaling"

    def upscale(
        self,
        upscale_model,
        image,
        upscale_factor,
        tile_size,
        overlap,
        feather,
        resample_method,
    ):
        device = model_management.get_torch_device()

        scale_estimate = getattr(upscale_model, "scale", 4.0)
        element_size = image.element_size()

        # Rough VRAM estimate (model weights + per-tile activations + the
        # input image), mirroring the heuristic in ComfyUI's built-in
        # ImageUpscaleWithModel node.
        memory_required = model_management.module_size(upscale_model.model)
        memory_required += (
            tile_size * tile_size * 3
        ) * element_size * max(scale_estimate, 1.0) * 384.0
        memory_required += image.nelement() * element_size

        model_management.free_memory(memory_required, device)

        upscale_model.to(device)

        batch_size, in_h, in_w, _ = image.shape

        upscale_factor = float(upscale_factor)
        if upscale_factor < 1.0:
            upscale_factor = 1.0

        target_height = max(1, int(round(in_h * upscale_factor)))
        target_width = max(1, int(round(in_w * upscale_factor)))

        # IMAGE tensors are [B, H, W, C]; the model expects [B, C, H, W].
        input_image = image.movedim(-1, -3).to(device)

        current_tile_size = int(tile_size)
        minimum_tile_size = 64

        upscale_result = None
        # Accumulate the full-resolution canvas on the intermediate device so
        # a large output cannot itself exhaust VRAM during the retry loop.
        output_device = model_management.intermediate_device()

        oom = True
        last_exception = None

        # Retry with progressively smaller tiles on out-of-memory errors.
        while oom:
            try:
                steps = input_image.shape[0] * comfy.utils.get_tiled_scale_steps(
                    input_image.shape[3],
                    input_image.shape[2],
                    tile_x=current_tile_size,
                    tile_y=current_tile_size,
                    overlap=overlap,
                )
                progress = comfy.utils.ProgressBar(steps)

                upscale_result = dynamic_tiled_upscale_with_custom_feather(
                    samples=input_image,
                    function=lambda a: upscale_model(a),
                    tile_size=current_tile_size,
                    overlap=overlap,
                    output_device=output_device,
                    pbar=progress,
                    feather=feather,
                    target_height=target_height,
                    target_width=target_width,
                    resample_method=resample_method,
                )

                oom = False
            except model_management.OOM_EXCEPTION as exception:
                last_exception = exception
                current_tile_size //= 2
                if current_tile_size < minimum_tile_size:
                    upscale_model.to("cpu")
                    raise last_exception

        upscale_model.to("cpu")

        upscale_result = torch.clamp(
            upscale_result.movedim(-3, -1), min=0.0, max=1.0
        )

        return (upscale_result,)


NODE_CLASS_MAPPINGS = {
    "WASTiledImageUpscaleWithModel": WASTiledImageUpscaleWithModel,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "WASTiledImageUpscaleWithModel": "Tiled Image Upscale (With Model)",
}
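

# Minimal smoke test of the tiling helper, assuming an identity function as a
# hypothetical stand-in for a real upscale model (illustration only; running
# this file directly still requires the ComfyUI environment for the imports
# above). With matching input and target sizes every tile keeps its shape, so
# only the feather/normalization path is exercised and the output should
# round-trip to the input.
if __name__ == "__main__":
    x = torch.rand(1, 3, 96, 96)
    out = dynamic_tiled_upscale_with_custom_feather(
        samples=x,
        function=lambda t: t,  # identity stand-in for upscale_model(...)
        tile_size=64,
        overlap=16,
        target_height=96,
        target_width=96,
    )
    assert out.shape == (1, 3, 96, 96)
    assert torch.allclose(out, x, atol=1e-5)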