|
1 | 1 | # Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) |
2 | 2 |
|
| 3 | +import sqlite3 |
3 | 4 | from logging import Logger |
4 | 5 |
|
5 | 6 | from invokeai.app.services.board_image_record_storage import SqliteBoardImageRecordStorage |
6 | | -from invokeai.app.services.board_images import BoardImagesService |
| 7 | +from invokeai.app.services.board_images import BoardImagesService, BoardImagesServiceDependencies |
7 | 8 | from invokeai.app.services.board_record_storage import SqliteBoardRecordStorage |
8 | | -from invokeai.app.services.boards import BoardService |
| 9 | +from invokeai.app.services.boards import BoardService, BoardServiceDependencies |
9 | 10 | from invokeai.app.services.config import InvokeAIAppConfig |
10 | 11 | from invokeai.app.services.image_record_storage import SqliteImageRecordStorage |
11 | | -from invokeai.app.services.images import ImageService |
| 12 | +from invokeai.app.services.images import ImageService, ImageServiceDependencies |
12 | 13 | from invokeai.app.services.invocation_cache.invocation_cache_memory import MemoryInvocationCache |
13 | 14 | from invokeai.app.services.resource_name import SimpleNameService |
14 | 15 | from invokeai.app.services.session_processor.session_processor_default import DefaultSessionProcessor |
15 | 16 | from invokeai.app.services.session_queue.session_queue_sqlite import SqliteSessionQueue |
16 | | -from invokeai.app.services.shared.db import SqliteDatabase |
17 | 17 | from invokeai.app.services.urls import LocalUrlService |
18 | 18 | from invokeai.backend.util.logging import InvokeAILogger |
19 | 19 | from invokeai.version.invokeai_version import __version__ |
|
29 | 29 | from ..services.model_manager_service import ModelManagerService |
30 | 30 | from ..services.processor import DefaultInvocationProcessor |
31 | 31 | from ..services.sqlite import SqliteItemStorage |
| 32 | +from ..services.thread import lock |
32 | 33 | from .events import FastAPIEventService |
33 | 34 |
|
34 | 35 |
|
@@ -62,64 +63,100 @@ def initialize(config: InvokeAIAppConfig, event_handler_id: int, logger: Logger |
62 | 63 | logger.info(f"Root directory = {str(config.root_path)}") |
63 | 64 | logger.debug(f"Internet connectivity is {config.internet_available}") |
64 | 65 |
|
| 66 | + events = FastAPIEventService(event_handler_id) |
| 67 | + |
65 | 68 | output_folder = config.output_path |
66 | 69 |
|
67 | | - db = SqliteDatabase(config, logger) |
| 70 | + # TODO: build a file/path manager? |
| 71 | + if config.use_memory_db: |
| 72 | + db_location = ":memory:" |
| 73 | + else: |
| 74 | + db_path = config.db_path |
| 75 | + db_path.parent.mkdir(parents=True, exist_ok=True) |
| 76 | + db_location = str(db_path) |
68 | 77 |
|
69 | | - configuration = config |
70 | | - logger = logger |
| 78 | + logger.info(f"Using database at {db_location}") |
| 79 | + db_conn = sqlite3.connect(db_location, check_same_thread=False) # TODO: figure out a better threading solution |
| 80 | + |
| 81 | + if config.log_sql: |
| 82 | + db_conn.set_trace_callback(print) |
| 83 | + db_conn.execute("PRAGMA foreign_keys = ON;") |
| 84 | + |
| 85 | + graph_execution_manager = SqliteItemStorage[GraphExecutionState]( |
| 86 | + conn=db_conn, table_name="graph_executions", lock=lock |
| 87 | + ) |
71 | 88 |
|
72 | | - board_image_records = SqliteBoardImageRecordStorage(db=db) |
73 | | - board_images = BoardImagesService() |
74 | | - board_records = SqliteBoardRecordStorage(db=db) |
75 | | - boards = BoardService() |
76 | | - events = FastAPIEventService(event_handler_id) |
77 | | - graph_execution_manager = SqliteItemStorage[GraphExecutionState](db=db, table_name="graph_executions") |
78 | | - graph_library = SqliteItemStorage[LibraryGraph](db=db, table_name="graphs") |
79 | | - image_files = DiskImageFileStorage(f"{output_folder}/images") |
80 | | - image_records = SqliteImageRecordStorage(db=db) |
81 | | - images = ImageService() |
82 | | - invocation_cache = MemoryInvocationCache(max_cache_size=config.node_cache_size) |
83 | | - latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) |
84 | | - model_manager = ModelManagerService(config, logger) |
85 | | - names = SimpleNameService() |
86 | | - performance_statistics = InvocationStatsService() |
87 | | - processor = DefaultInvocationProcessor() |
88 | | - queue = MemoryInvocationQueue() |
89 | | - session_processor = DefaultSessionProcessor() |
90 | | - session_queue = SqliteSessionQueue(db=db) |
91 | 89 | urls = LocalUrlService() |
| 90 | + image_record_storage = SqliteImageRecordStorage(conn=db_conn, lock=lock) |
| 91 | + image_file_storage = DiskImageFileStorage(f"{output_folder}/images") |
| 92 | + names = SimpleNameService() |
| 93 | + latents = ForwardCacheLatentsStorage(DiskLatentsStorage(f"{output_folder}/latents")) |
| 94 | + |
| 95 | + board_record_storage = SqliteBoardRecordStorage(conn=db_conn, lock=lock) |
| 96 | + board_image_record_storage = SqliteBoardImageRecordStorage(conn=db_conn, lock=lock) |
| 97 | + |
| 98 | + boards = BoardService( |
| 99 | + services=BoardServiceDependencies( |
| 100 | + board_image_record_storage=board_image_record_storage, |
| 101 | + board_record_storage=board_record_storage, |
| 102 | + image_record_storage=image_record_storage, |
| 103 | + url=urls, |
| 104 | + logger=logger, |
| 105 | + ) |
| 106 | + ) |
| 107 | + |
| 108 | + board_images = BoardImagesService( |
| 109 | + services=BoardImagesServiceDependencies( |
| 110 | + board_image_record_storage=board_image_record_storage, |
| 111 | + board_record_storage=board_record_storage, |
| 112 | + image_record_storage=image_record_storage, |
| 113 | + url=urls, |
| 114 | + logger=logger, |
| 115 | + ) |
| 116 | + ) |
| 117 | + |
| 118 | + images = ImageService( |
| 119 | + services=ImageServiceDependencies( |
| 120 | + board_image_record_storage=board_image_record_storage, |
| 121 | + image_record_storage=image_record_storage, |
| 122 | + image_file_storage=image_file_storage, |
| 123 | + url=urls, |
| 124 | + logger=logger, |
| 125 | + names=names, |
| 126 | + graph_execution_manager=graph_execution_manager, |
| 127 | + ) |
| 128 | + ) |
92 | 129 |
|
93 | 130 | services = InvocationServices( |
94 | | - board_image_records=board_image_records, |
95 | | - board_images=board_images, |
96 | | - board_records=board_records, |
97 | | - boards=boards, |
98 | | - configuration=configuration, |
| 131 | + model_manager=ModelManagerService(config, logger), |
99 | 132 | events=events, |
100 | | - graph_execution_manager=graph_execution_manager, |
101 | | - graph_library=graph_library, |
102 | | - image_files=image_files, |
103 | | - image_records=image_records, |
104 | | - images=images, |
105 | | - invocation_cache=invocation_cache, |
106 | 133 | latents=latents, |
| 134 | + images=images, |
| 135 | + boards=boards, |
| 136 | + board_images=board_images, |
| 137 | + queue=MemoryInvocationQueue(), |
| 138 | + graph_library=SqliteItemStorage[LibraryGraph](conn=db_conn, lock=lock, table_name="graphs"), |
| 139 | + graph_execution_manager=graph_execution_manager, |
| 140 | + processor=DefaultInvocationProcessor(), |
| 141 | + configuration=config, |
| 142 | + performance_statistics=InvocationStatsService(graph_execution_manager), |
107 | 143 | logger=logger, |
108 | | - model_manager=model_manager, |
109 | | - names=names, |
110 | | - performance_statistics=performance_statistics, |
111 | | - processor=processor, |
112 | | - queue=queue, |
113 | | - session_processor=session_processor, |
114 | | - session_queue=session_queue, |
115 | | - urls=urls, |
| 144 | + session_queue=SqliteSessionQueue(conn=db_conn, lock=lock), |
| 145 | + session_processor=DefaultSessionProcessor(), |
| 146 | + invocation_cache=MemoryInvocationCache(max_cache_size=config.node_cache_size), |
116 | 147 | ) |
117 | 148 |
|
118 | 149 | create_system_graphs(services.graph_library) |
119 | 150 |
|
120 | 151 | ApiDependencies.invoker = Invoker(services) |
121 | 152 |
|
122 | | - db.clean() |
| 153 | + try: |
| 154 | + lock.acquire() |
| 155 | + db_conn.execute("VACUUM;") |
| 156 | + db_conn.commit() |
| 157 | + logger.info("Cleaned database") |
| 158 | + finally: |
| 159 | + lock.release() |
123 | 160 |
|
124 | 161 | @staticmethod |
125 | 162 | def shutdown(): |
|
0 commit comments