
Commit 414750a

Update calc_model_size_by_data(...) to handle all expected model types, and to log an error if an unexpected model type is received.
1 parent 0fe92cd commit 414750a


4 files changed: +40 -10 lines changed


invokeai/backend/ip_adapter/ip_adapter.py

Lines changed: 4 additions & 4 deletions
@@ -136,11 +136,11 @@ def to(
         self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
         self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)

-    def calc_size(self):
-        # workaround for circular import
-        from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data
+    def calc_size(self) -> int:
+        # HACK(ryand): Fix this issue with circular imports.
+        from invokeai.backend.model_manager.load.model_util import calc_module_size

-        return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
+        return calc_module_size(self._image_proj_model) + calc_module_size(self.attn_weights)

     def _init_image_proj_model(
         self, state_dict: dict[str, torch.Tensor]
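For illustration, a minimal sketch of what the reworked calc_size() computes: the sum of calc_module_size over the adapter's two torch submodules. The Linear/ModuleList stand-ins below are toys; the real IPAdapter holds an image-projection model and per-layer attention weights.

```python
import torch

# calc_module_size is the renamed helper introduced in model_util.py (see the diff further down).
from invokeai.backend.model_manager.load.model_util import calc_module_size

# Toy stand-ins for self._image_proj_model and self.attn_weights (illustrative shapes only).
image_proj_model = torch.nn.Linear(16, 16)
attn_weights = torch.nn.ModuleList([torch.nn.Linear(16, 16) for _ in range(2)])

total_bytes = calc_module_size(image_proj_model) + calc_module_size(attn_weights)
print(total_bytes)
```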

invokeai/backend/model_manager/load/model_cache/model_cache_default.py

Lines changed: 1 addition & 1 deletion
@@ -160,7 +160,7 @@ def put(
         key = self._make_cache_key(key, submodel_type)
         if key in self._cached_models:
             return
-        size = calc_model_size_by_data(model)
+        size = calc_model_size_by_data(self.logger, model)
         self.make_room(size)

         state_dict = model.state_dict() if isinstance(model, torch.nn.Module) else None
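The call site now passes the cache's logger through. A minimal usage sketch of the updated signature (logger first, then model); per the new behavior in model_util.py below, an unrecognized type is logged and treated as 0 bytes rather than raising:

```python
import logging

import torch

from invokeai.backend.model_manager.load.model_util import calc_model_size_by_data

logger = logging.getLogger(__name__)

# An nn.Module dispatches to calc_module_size and reports its parameter + buffer bytes.
size = calc_model_size_by_data(logger, torch.nn.Linear(8, 8))

# An unexpected type falls through to the else branch: logger.error(...) is called and 0 is returned.
size_unknown = calc_model_size_by_data(logger, object())
```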

invokeai/backend/model_manager/load/model_util.py

Lines changed: 27 additions & 5 deletions
@@ -2,25 +2,46 @@
 """Various utility functions needed by the loader and caching system."""

 import json
+import logging
 from pathlib import Path
 from typing import Optional

 import torch
-from diffusers import DiffusionPipeline
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
+from transformers import CLIPTokenizer

+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
+from invokeai.backend.lora import LoRAModelRaw
 from invokeai.backend.model_manager.config import AnyModel
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.textual_inversion import TextualInversionModelRaw


-def calc_model_size_by_data(model: AnyModel) -> int:
+def calc_model_size_by_data(logger: logging.Logger, model: AnyModel) -> int:
     """Get size of a model in memory in bytes."""
+    # TODO(ryand): We should create a CacheableModel interface for all models, and move the size calculations down to
+    # the models themselves.
     if isinstance(model, DiffusionPipeline):
         return _calc_pipeline_by_data(model)
     elif isinstance(model, torch.nn.Module):
-        return _calc_model_by_data(model)
+        return calc_module_size(model)
     elif isinstance(model, IAIOnnxRuntimeModel):
         return _calc_onnx_model_by_data(model)
+    elif isinstance(model, SchedulerMixin):
+        return 0
+    elif isinstance(model, CLIPTokenizer):
+        # TODO(ryand): Accurately calculate the tokenizer's size. It's small enough that it shouldn't matter for now.
+        return 0
+    elif isinstance(model, (TextualInversionModelRaw, IPAdapter, LoRAModelRaw)):
+        return model.calc_size()
     else:
+        # TODO(ryand): Promote this from a log to an exception once we are confident that we are handling all of the
+        # supported model types.
+        logger.error(
+            f"Failed to calculate model size for unexpected model type: {type(model)}. The model will be treated as "
+            "having size 0."
+        )
         return 0


@@ -30,11 +51,12 @@ def _calc_pipeline_by_data(pipeline: DiffusionPipeline) -> int:
     for submodel_key in pipeline.components.keys():
         submodel = getattr(pipeline, submodel_key)
         if submodel is not None and isinstance(submodel, torch.nn.Module):
-            res += _calc_model_by_data(submodel)
+            res += calc_module_size(submodel)
     return res


-def _calc_model_by_data(model: torch.nn.Module) -> int:
+def calc_module_size(model: torch.nn.Module) -> int:
+    """Calculate the size (in bytes) of a torch.nn.Module."""
     mem_params = sum([param.nelement() * param.element_size() for param in model.parameters()])
     mem_bufs = sum([buf.nelement() * buf.element_size() for buf in model.buffers()])
     mem: int = mem_params + mem_bufs  # in bytes
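As a quick sanity check of the parameter/buffer arithmetic in calc_module_size, here is a self-contained sketch that re-implements the same sum locally (so it runs without InvokeAI installed):

```python
import torch


def module_size_bytes(model: torch.nn.Module) -> int:
    """Same parameter + buffer byte sum as calc_module_size above."""
    mem_params = sum(p.nelement() * p.element_size() for p in model.parameters())
    mem_bufs = sum(b.nelement() * b.element_size() for b in model.buffers())
    return mem_params + mem_bufs


layer = torch.nn.Linear(4, 2)  # 4*2 weights + 2 biases = 10 float32 values
print(module_size_bytes(layer))  # 10 * 4 bytes = 40
```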

invokeai/backend/textual_inversion.py

Lines changed: 8 additions & 0 deletions
@@ -77,6 +77,14 @@ def to(
             if emb is not None:
                 emb.to(device=device, dtype=dtype, non_blocking=non_blocking)

+    def calc_size(self) -> int:
+        """Get the size of this model in bytes."""
+        embedding_size = self.embedding.element_size() * self.embedding.nelement()
+        embedding_2_size = 0
+        if self.embedding_2 is not None:
+            embedding_2_size = self.embedding_2.element_size() * self.embedding_2.nelement()
+        return embedding_size + embedding_2_size
+

 class TextualInversionManager(BaseTextualInversionManager):
     """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
