
Commit 7d86f00

feat(mm): implement working memory estimation for VAE encode for all models
Tell the model manager that we need some extra working memory for VAE encoding operations, to prevent OOMs. See the previous commit for the investigation and determination of the magic numbers used. This safety measure is especially relevant now that we have FLUX Kontext and may be encoding rather large reference images. Without the working memory estimation, we can OOM while preparing for denoising. See #8405 for an example of this issue on a very low-VRAM system. The same issue could occur on any GPU, though - it's just a matter of loading an unlucky combination of models.
1 parent 7785061 · commit 7d86f00
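
For a sense of scale, here is a minimal back-of-the-envelope sketch of the estimate this commit introduces. The 1024x1024 image size and the fp16 element size are example values for a typical Kontext reference image, not figures taken from the commit:

# Sketch only: example numbers, not part of the diff.
h, w = 1024, 1024        # example reference image height/width in pixels
element_size = 2         # bytes per element for an fp16 VAE
scaling_constant = 1100  # encode constant: ~50% of the decode scaling constant (2200)

estimated_bytes = h * w * element_size * scaling_constant
print(f"{estimated_bytes / 2**30:.2f} GiB")  # ~2.15 GiB requested as working memory

Reserving a couple of GiB up front for an image of this size is what gives the model manager room to avoid the OOM described above.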

File tree: 5 files changed, +93 -13 lines


invokeai/app/invocations/cogview4_image_to_latents.py

Lines changed: 16 additions & 3 deletions
@@ -36,9 +36,19 @@ class CogView4ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode.")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
 
+    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        working_memory = h * w * element_size * scaling_constant
+        return int(working_memory)
+
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
-        with vae_info as vae:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)
 
             vae.disable_tiling()
@@ -62,7 +72,10 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
         vae_info = context.models.load(self.vae.vae)
-        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        assert isinstance(vae_info.model, AutoencoderKL)
+
+        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)

invokeai/app/invocations/flux_vae_encode.py

Lines changed: 14 additions & 3 deletions
@@ -35,14 +35,24 @@ class FluxVaeEncodeInvocation(BaseInvocation):
         input=Input.Connection,
     )
 
+    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoEncoder) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        working_memory = h * w * element_size * scaling_constant
+        return int(working_memory)
+
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
         # TODO(ryand): Expose seed parameter at the invocation level.
         # TODO(ryand): Write a util function for generating random tensors that is consistent across devices / dtypes.
         # There's a starting point in get_noise(...), but it needs to be extracted and generalized. This function
         # should be used for VAE encode sampling.
         generator = torch.Generator(device=TorchDevice.choose_torch_device()).manual_seed(0)
-        with vae_info as vae:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
             vae_dtype = next(iter(vae.parameters())).dtype
             image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
@@ -60,7 +70,8 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
         context.util.signal_progress("Running VAE")
-        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)

invokeai/app/invocations/image_to_latents.py

Lines changed: 39 additions & 3 deletions
@@ -52,11 +52,43 @@ class ImageToLatentsInvocation(BaseInvocation):
     tile_size: int = InputField(default=0, multiple_of=8, description=FieldDescriptions.vae_tile_size)
     fp32: bool = InputField(default=False, description=FieldDescriptions.fp32)
 
+    def _estimate_working_memory(
+        self, image_tensor: torch.Tensor, use_tiling: bool, vae: AutoencoderKL | AutoencoderTiny
+    ) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        element_size = 4 if self.fp32 else 2
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+
+        if use_tiling:
+            tile_size = self.tile_size
+            if tile_size == 0:
+                tile_size = vae.tile_sample_min_size
+                assert isinstance(tile_size, int)
+            h = tile_size
+            w = tile_size
+            working_memory = h * w * element_size * scaling_constant
+
+            # We add 25% to the working memory estimate when tiling is enabled to account for factors like tile overlap
+            # and number of tiles. We could make this more precise in the future, but this should be good enough for
+            # most use cases.
+            working_memory = working_memory * 1.25
+        else:
+            h = image_tensor.shape[-2]
+            w = image_tensor.shape[-1]
+            working_memory = h * w * element_size * scaling_constant
+
+        if self.fp32:
+            # If we are running in FP32, then we should account for the likely increase in model size (~250MB).
+            working_memory += 250 * 2**20
+
+        return int(working_memory)
+
     @staticmethod
     def vae_encode(
-        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0
+        vae_info: LoadedModel, upcast: bool, tiled: bool, image_tensor: torch.Tensor, tile_size: int = 0, estimated_working_memory: int = 0
     ) -> torch.Tensor:
-        with vae_info as vae:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, (AutoencoderKL, AutoencoderTiny))
             orig_dtype = vae.dtype
             if upcast:
@@ -113,14 +145,18 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
         image = context.images.get_pil(self.image.image_name)
 
         vae_info = context.models.load(self.vae.vae)
+        assert isinstance(vae_info.model, (AutoencoderKL, AutoencoderTiny))
 
        image_tensor = image_resized_to_grid_as_tensor(image.convert("RGB"))
         if image_tensor.dim() == 3:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
+        use_tiling = self.tiled or context.config.get().force_tiled_decode
+        estimated_working_memory = self._estimate_working_memory(image_tensor, use_tiling, vae_info.model)
+
         context.util.signal_progress("Running VAE encoder")
         latents = self.vae_encode(
-            vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size
+            vae_info=vae_info, upcast=self.fp32, tiled=self.tiled, image_tensor=image_tensor, tile_size=self.tile_size, estimated_working_memory=estimated_working_memory
         )
 
         latents = latents.to("cpu")
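
To illustrate why the tiled path above keys the estimate off the tile size rather than the full image, here is a rough sketch. The 2048x2048 input and the 512-pixel tile size are assumed example values, not taken from the diff:

# Sketch only: assumed example values.
element_size = 2         # fp16
scaling_constant = 1100  # encode scaling constant from the diff

h, w = 2048, 2048        # assumed large input image
untiled_estimate = h * w * element_size * scaling_constant              # ~8.6 GiB

tile = 512               # assumed tile size
tiled_estimate = tile * tile * element_size * scaling_constant * 1.25   # +25% for tile overlap/count, ~0.67 GiB

print(f"untiled: {untiled_estimate / 2**30:.2f} GiB, tiled: {tiled_estimate / 2**30:.2f} GiB")

With tiling enabled the reservation is bounded by the tile size, so a very large input no longer forces a multi-GiB working-memory reservation; the fp32 path additionally adds roughly 250 MiB for the larger model.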

invokeai/app/invocations/sd3_image_to_latents.py

Lines changed: 16 additions & 3 deletions
@@ -32,9 +32,19 @@ class SD3ImageToLatentsInvocation(BaseInvocation, WithMetadata, WithBoard):
     image: ImageField = InputField(description="The image to encode")
     vae: VAEField = InputField(description=FieldDescriptions.vae, input=Input.Connection)
 
+    def _estimate_working_memory(self, image_tensor: torch.Tensor, vae: AutoencoderKL) -> int:
+        """Estimate the working memory required by the invocation in bytes."""
+        # Encode operations use approximately 50% of the memory required for decode operations
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        working_memory = h * w * element_size * scaling_constant
+        return int(working_memory)
+
     @staticmethod
-    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor) -> torch.Tensor:
-        with vae_info as vae:
+    def vae_encode(vae_info: LoadedModel, image_tensor: torch.Tensor, estimated_working_memory: int) -> torch.Tensor:
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoencoderKL)
 
             vae.disable_tiling()
@@ -58,7 +68,10 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
             image_tensor = einops.rearrange(image_tensor, "c h w -> 1 c h w")
 
         vae_info = context.models.load(self.vae.vae)
-        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor)
+        assert isinstance(vae_info.model, AutoencoderKL)
+
+        estimated_working_memory = self._estimate_working_memory(image_tensor, vae_info.model)
+        latents = self.vae_encode(vae_info=vae_info, image_tensor=image_tensor, estimated_working_memory=estimated_working_memory)
 
         latents = latents.to("cpu")
         name = context.tensors.save(tensor=latents)

invokeai/backend/flux/extensions/kontext_extension.py

Lines changed: 8 additions & 1 deletion
@@ -131,7 +131,14 @@ def _prepare_kontext(self) -> tuple[torch.Tensor, torch.Tensor]:
 
         # Continue with VAE encoding
         # Don't sample from the distribution for reference images - use the mean (matching ComfyUI)
-        with vae_info as vae:
+        # Estimate working memory for encode operation (50% of decode memory requirements)
+        h = image_tensor.shape[-2]
+        w = image_tensor.shape[-1]
+        element_size = next(vae_info.model.parameters()).element_size()
+        scaling_constant = 1100  # 50% of decode scaling constant (2200)
+        estimated_working_memory = int(h * w * element_size * scaling_constant)
+
+        with vae_info.model_on_device(working_mem_bytes=estimated_working_memory) as (_, vae):
             assert isinstance(vae, AutoEncoder)
             vae_dtype = next(iter(vae.parameters())).dtype
             image_tensor = image_tensor.to(device=TorchDevice.choose_torch_device(), dtype=vae_dtype)
