Commit 7bbd793

Fix some models treated as having size 0 in the model cache (#6571)
## Summary

This PR fixes a regression that caused the following models to be treated as having size 0 in the model cache: `TextualInversionModelRaw`, `IPAdapter`, and `LoRAModelRaw`.

Changes:
- Call the correct model size calculation for all supported model types.
- Log an error message if an unexpected model type is loaded, to prevent similar regressions in the future.

## QA Instructions

I tested the following features and verified that no models fell back to using a size of 0 unexpectedly:
- Text-to-image
- Textual Inversion
- LoRA
- IP-Adapter
- ControlNet

(All tested with both SD1.5 and SDXL.)

I compared the model cache switching behavior before and after this change with a large number of LoRAs (10). Since LoRAs are small compared to the main models, the changes in behavior are minimal. Nonetheless, it makes sense to get this in for correctness, and it might make a difference for some usage patterns with limited RAM.

## Merge Plan

No special instructions.

## Checklist

- [x] _The PR has a short but descriptive title, suitable for a changelog_
- [x] _Tests added / updated (if applicable)_
- [x] _Documentation added / updated (if applicable)_
2 parents 0fe92cd + 414750a commit 7bbd793
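To make the failure mode concrete, here is a toy sketch (not the actual InvokeAI cache; names and sizes are illustrative) of how a model whose size is misreported as 0 slips past the cache's size accounting, so the eviction logic never makes room for it:

```python
class ToyModelCache:
    """Illustrative stand-in for the model cache's size accounting."""

    def __init__(self, max_bytes: int) -> None:
        self.max_bytes = max_bytes
        self.used_bytes = 0
        self.models: dict[str, int] = {}

    def put(self, key: str, size: int) -> None:
        # With a misreported size of 0, this loop never evicts anything on
        # behalf of the new model, even though its real memory use is large.
        while self.models and self.used_bytes + size > self.max_bytes:
            evicted_key, evicted_size = self.models.popitem()
            self.used_bytes -= evicted_size
        self.models[key] = size
        self.used_bytes += size


cache = ToyModelCache(max_bytes=8 * 2**30)
cache.put("lora-1", 0)            # bug: a real LoRA accounted as 0 bytes
cache.put("lora-2", 50 * 2**20)   # fix: the actual size drives eviction
```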

File tree

4 files changed (+40 -10 lines)


invokeai/backend/ip_adapter/ip_adapter.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -136,11 +136,11 @@ def to(
         self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
         self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
 
-    def calc_size(self):
-        # workaround for circular import
-        from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
+    def calc_size(self) -> int:
+        # HACK(ryand): Fix this issue with circular imports.
+        from invokeai.backend.model_manager.load.model_util import calc_module_size
 
-        return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
+        return calc_module_size(self._image_proj_model) + calc_module_size(self.attn_weights)
 
     def _init_image_proj_model(
         self, state_dict: dict[str, torch.Tensor]
```

invokeai/backend/model_manager/load/model_cache/model_cache_default.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -160,7 +160,7 @@ def put(
         key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             return
-        size = calc_model_size_by_data(model)
+        size = calc_model_size_by_data(self.logger, model)
         self.make_room(size)
 
         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
```

invokeai/backend/model_manager/load/model_util.py

Lines changed: 27 additions & 5 deletions
```diff
@@ -2,25 +2,46 @@
 """Various utility functions needed by the loader and caching system."""
 
 import json
+import logging
 from pathlib import Path
 from typing import Optional
 
 import torch
-from diffusers import DiffusionPipeline
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from transformers import CLIPTokenizer
 
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.textual_inversion import TextualInversionModelRaw
 
 
-def calc_model_size_by_data(model: AnyModel) -> int:
+def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     """Get size of a model in memory in bytes."""
+    # TODO(ryand): We should create a CacheableModel interface for all models, and move the size calculations down to
+    # the models themselves.
     if isinstance(model, DiffusionPipeline):
         return _calc_pipeline_by_data(model)
     elif isinstance(model, torch.nn.Module):
-        return _calc_model_by_data(model)
+        return calc_module_size(model)
     elif isinstance(model, IAIOnnxRuntimeModel):
         return _calc_onnx_model_by_data(model)
+    elif isinstance(model, SchedulerMixin):
+        return 0
+    elif isinstance(model, CLIPTokenizer):
+        # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
+        return 0
+    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
+        return model.calc_size()
     else:
+        # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
+        # supported model types.
+        logger.error(
+            f"Failed to calculate model size for unexpected model type: {type(model)}. The model will be treated as "
+            "having size 0."
+        )
         return 0
 
 
@@ -30,11 +51,12 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
     for submodel_key in pipeline.components.keys():
         submodel = getattr(pipeline, submodel_key)
         if submodel is not None and isinstance(submodel, torch.nn.Module):
-            res += _calc_model_by_data(submodel)
+            res += calc_module_size(submodel)
     return res
 
 
-def _calc_model_by_data(model: torch.nn.Module) -> int:
+def calc_module_size(model: torch.nn.Module) -> int:
+    """Calculate the size (in bytes) of a torch.nn.Module."""
     mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
     mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
     mem: int = mem_params + mem_bufs  # in bytes
```
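The last hunk above is truncated before the function's return statement. For reference, a minimal standalone sketch of the same parameter-plus-buffer calculation, assuming the function simply returns the computed total (the final `return` is an assumption, not shown in the diff):

```python
import torch


def calc_module_size(model: torch.nn.Module) -> int:
    """Size of a torch.nn.Module in bytes: parameters plus registered buffers."""
    mem_params = sum(param.nelement() * param.element_size() for param in model.parameters())
    mem_bufs = sum(buf.nelement() * buf.element_size() for buf in model.buffers())
    return mem_params + mem_bufs  # assumed return; the hunk cuts off before it


# A float32 Linear(768, 768) holds 768*768 weights + 768 biases,
# so this reports (768*768 + 768) * 4 = 2,362,368 bytes.
print(calc_module_size(torch.nn.Linear(768, 768)))  # 2362368
```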

invokeai/backend/textual_inversion.py

Lines changed: 8 additions & 0 deletions
```diff
@@ -77,6 +77,14 @@ def to(
         if emb is not None:
             emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
 
+    def calc_size(self) -> int:
+        """Get the size of this model in bytes."""
+        embedding_size = self.embedding.element_size() * self.embedding.nelement()
+        embedding_2_size = 0
+        if self.embedding_2 is not None:
+            embedding_2_size = self.embedding_2.element_size() * self.embedding_2.nelement()
+        return embedding_size + embedding_2_size
+
 
 class TextualInversionManager(BaseTextualInversionManager):
     """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
```
