diff --git a/comfy_extras/nodes_wan.py b/comfy_extras/nodes_wan.py
index b0bd471bfb42..779b19bd5074 100644
--- a/comfy_extras/nodes_wan.py
+++ b/comfy_extras/nodes_wan.py
@@ -26,6 +26,7 @@ def define_schema(cls):
                 io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                 io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                 io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Int.Input("vae_tile_size", default=0, min=0, optional=True, tooltip="VAE encode tile size, 0 means untiled (default)"),
                 io.ClipVisionOutput.Input("clip_vision_output", optional=True),
                 io.Image.Input("start_image", optional=True),
             ],
@@ -37,14 +38,18 @@ def define_schema(cls):
     )

     @classmethod
-    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, clip_vision_output=None) -> io.NodeOutput:
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, vae_tile_size=0, start_image=None, clip_vision_output=None) -> io.NodeOutput:
         latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
         if start_image is not None:
             start_image = comfy.utils.common_upscale(start_image[:length].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
             image = torch.ones((length, height, width, start_image.shape[-1]), device=start_image.device, dtype=start_image.dtype) * 0.5
             image[:start_image.shape[0]] = start_image
-            concat_latent_image = vae.encode(image[:, :, :, :3])
+            if vae_tile_size == 0:
+                concat_latent_image = vae.encode(image[:, :, :, :3])
+            else:
+                concat_latent_image = vae.encode_tiled(image[:, :, :, :3], tile_x=vae_tile_size, tile_y=vae_tile_size, overlap=32, tile_t=256, overlap_t=8)
+
             mask = torch.ones((1, 1, latent.shape[2], concat_latent_image.shape[-2], concat_latent_image.shape[-1]), device=start_image.device, dtype=start_image.dtype)
             mask[:, :, :((start_image.shape[0] - 1) // 4) + 1] = 0.0
@@ -192,6 +197,7 @@ def define_schema(cls):
                 io.Int.Input("height", default=480, min=16, max=nodes.MAX_RESOLUTION, step=16),
                 io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
                 io.Int.Input("batch_size", default=1, min=1, max=4096),
+                io.Int.Input("vae_tile_size", default=0, min=0, optional=True, tooltip="VAE encode tile size, 0 means untiled (default)"),
                 io.ClipVisionOutput.Input("clip_vision_start_image", optional=True),
                 io.ClipVisionOutput.Input("clip_vision_end_image", optional=True),
                 io.Image.Input("start_image", optional=True),
@@ -205,7 +211,7 @@ def define_schema(cls):
     )

     @classmethod
-    def execute(cls, positive, negative, vae, width, height, length, batch_size, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
+    def execute(cls, positive, negative, vae, width, height, length, batch_size, vae_tile_size=0, start_image=None, end_image=None, clip_vision_start_image=None, clip_vision_end_image=None) -> io.NodeOutput:
         spacial_scale = vae.spacial_compression_encode()
         latent = torch.zeros([batch_size, vae.latent_channels, ((length - 1) // 4) + 1, height // spacial_scale, width // spacial_scale], device=comfy.model_management.intermediate_device())
         if start_image is not None:
@@ -224,7 +230,11 @@ def execute(cls, positive, negative, vae, width, height, length, batch_size, sta
             image[-end_image.shape[0]:] = end_image
             mask[:, :, -end_image.shape[0]:] = 0.0

-        concat_latent_image = vae.encode(image[:, :, :, :3])
+        if vae_tile_size == 0:
+            concat_latent_image = vae.encode(image[:, :, :, :3])
+        else:
+            concat_latent_image = vae.encode_tiled(image[:, :, :, :3], tile_x=vae_tile_size, tile_y=vae_tile_size, overlap=32, tile_t=256, overlap_t=8)
+
         mask = mask.view(1, mask.shape[2] // 4, 4, mask.shape[3], mask.shape[4]).transpose(1, 2)
         positive = node_helpers.conditioning_set_values(positive, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
         negative = node_helpers.conditioning_set_values(negative, {"concat_latent_image": concat_latent_image, "concat_mask": mask})
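
Both hunks add the same untiled-vs-tiled dispatch. A minimal sketch of how that duplicated branch could be factored into a shared helper; the name encode_maybe_tiled is hypothetical and not part of this diff, while the encode/encode_tiled calls and their arguments are copied verbatim from the hunks above:

def encode_maybe_tiled(vae, pixels, tile_size):
    # Hypothetical helper (not part of this diff): factors out the
    # branch that both execute() methods above duplicate.
    if tile_size == 0:
        # Untiled path: encode the full RGB pixel batch in one pass.
        return vae.encode(pixels[:, :, :, :3])
    # Tiled path: encode in tile_size x tile_size spatial tiles and
    # 256-frame temporal chunks to bound peak VRAM, using the same
    # overlap values the diff passes to encode_tiled().
    return vae.encode_tiled(pixels[:, :, :, :3],
                            tile_x=tile_size, tile_y=tile_size,
                            overlap=32, tile_t=256, overlap_t=8)

With such a helper, both call sites would reduce to concat_latent_image = encode_maybe_tiled(vae, image, vae_tile_size), keeping the tiling parameters in one place if they ever need tuning.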