Skip to content

Commit b4cf78a

Browse files
fix: make DA Pipeline a subclass of RawModel
1 parent 18f89ed commit b4cf78a

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

invokeai/backend/image_util/depth_anything/depth_anything_pipeline.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
from typing import cast
1+
from typing import Optional, cast
22

3+
import torch
34
from PIL import Image
45
from transformers.pipelines import DepthEstimationPipeline
56

7+
from invokeai.backend.raw_model import RawModel
68

7-
class DepthAnythingPipeline:
9+
10+
class DepthAnythingPipeline(RawModel):
811
"""Custom wrapper for the Depth Estimation pipeline from transformers adding compatibility
912
for Invoke's Model Management System"""
1013

@@ -20,6 +23,9 @@ def generate_depth(self, image: Image.Image, resolution: int = 512):
2023
depth_map = depth_map.resize((resolution, new_height))
2124
return depth_map
2225

26+
def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None):
27+
pass
28+
2329
def calc_size(self) -> int:
2430
from invokeai.backend.model_manager.load.model_util import calc_module_size
2531

invokeai/backend/model_manager/config.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,13 @@
3131
from typing_extensions import Annotated, Any, Dict
3232

3333
from invokeai.app.util.misc import uuid_string
34-
from invokeai.backend.image_util.depth_anything.depth_anything_pipeline import DepthAnythingPipeline
3534
from invokeai.backend.model_hash.hash_validator import validate_hash
3635
from invokeai.backend.raw_model import RawModel
3736
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
3837

3938
# ModelMixin is the base class for all diffusers and transformers models
4039
# RawModel is the InvokeAI wrapper class for ip_adapters, loras, textual_inversion and onnx runtime
41-
AnyModel = Union[
42-
ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline, DepthAnythingPipeline
43-
]
40+
AnyModel = Union[ModelMixin, RawModel, torch.nn.Module, Dict[str, torch.Tensor], diffusers.DiffusionPipeline]
4441

4542

4643
class InvalidModelConfigException(Exception):

0 commit comments

Comments
 (0)