Skip to content

Commit d374680

Browse files
committed
Merge branch 'main' into copilot/add-command-line-script-orphaned-models
2 parents 4eeb3dd + 8cf83a9 commit d374680

File tree

17 files changed

+637
-65
lines changed

17 files changed

+637
-65
lines changed

invokeai/app/api/routers/model_manager.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,28 @@ async def list_model_records(
149149
return ModelsList(models=found_models)
150150

151151

152+
@model_manager_router.get(
153+
"/missing",
154+
operation_id="list_missing_models",
155+
responses={200: {"description": "List of models with missing files"}},
156+
)
157+
async def list_missing_models() -> ModelsList:
158+
"""Get models whose files are missing from disk.
159+
160+
These are models that have database entries but their corresponding
161+
weight files have been deleted externally (not via Model Manager).
162+
"""
163+
record_store = ApiDependencies.invoker.services.model_manager.store
164+
models_path = ApiDependencies.invoker.services.configuration.models_path
165+
166+
missing_models: list[AnyModelConfig] = []
167+
for model_config in record_store.all_models():
168+
if not (models_path / model_config.path).resolve().exists():
169+
missing_models.append(model_config)
170+
171+
return ModelsList(models=missing_models)
172+
173+
152174
@model_manager_router.get(
153175
"/get_by_attrs",
154176
operation_id="get_model_records_by_attrs",

invokeai/app/services/shared/graph.py

Lines changed: 99 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -124,38 +124,36 @@ def is_any(t: Any) -> bool:
124124

125125

126126
def are_connection_types_compatible(from_type: Any, to_type: Any) -> bool:
127-
if not from_type:
128-
return False
129-
if not to_type:
127+
if not from_type or not to_type:
130128
return False
131129

132-
# TODO: this is pretty forgiving on generic types. Clean that up (need to handle optionals and such)
133-
if from_type and to_type:
134-
# Ports are compatible
135-
if from_type == to_type or is_any(from_type) or is_any(to_type):
136-
return True
130+
# Ports are compatible
131+
if from_type == to_type or is_any(from_type) or is_any(to_type):
132+
return True
137133

138-
if from_type in get_args(to_type):
139-
return True
134+
if from_type in get_args(to_type):
135+
return True
140136

141-
if to_type in get_args(from_type):
142-
return True
137+
if to_type in get_args(from_type):
138+
return True
143139

144-
# allow int -> float, pydantic will cast for us
145-
if from_type is int and to_type is float:
146-
return True
140+
# allow int -> float, pydantic will cast for us
141+
if from_type is int and to_type is float:
142+
return True
147143

148-
# allow int|float -> str, pydantic will cast for us
149-
if (from_type is int or from_type is float) and to_type is str:
150-
return True
144+
# allow int|float -> str, pydantic will cast for us
145+
if (from_type is int or from_type is float) and to_type is str:
146+
return True
151147

152-
# if not issubclass(from_type, to_type):
153-
if not is_union_subtype(from_type, to_type):
154-
return False
155-
else:
156-
return False
148+
# Prefer issubclass when both are real classes
149+
try:
150+
if isinstance(from_type, type) and isinstance(to_type, type):
151+
return issubclass(from_type, to_type)
152+
except TypeError:
153+
pass
157154

158-
return True
155+
# Union-to-Union (or Union-to-non-Union) handling
156+
return is_union_subtype(from_type, to_type)
159157

160158

161159
def are_connections_compatible(
@@ -654,6 +652,9 @@ def _is_iterator_connection_valid(
654652
if new_output is not None:
655653
outputs.append(new_output)
656654

655+
if len(inputs) == 0:
656+
return "Iterator must have a collection input edge"
657+
657658
# Only one input is allowed for iterators
658659
if len(inputs) > 1:
659660
return "Iterator may only have one input edge"
@@ -675,9 +676,13 @@ def _is_iterator_connection_valid(
675676

676677
# Collector input type must match all iterator output types
677678
if isinstance(input_node, CollectInvocation):
679+
collector_inputs = self._get_input_edges(input_node.id, ITEM_FIELD)
680+
if len(collector_inputs) == 0:
681+
return "Iterator input collector must have at least one item input edge"
682+
678683
# Traverse the graph to find the first collector input edge. Collectors validate that their collection
679684
# inputs are all of the same type, so we can use the first input edge to determine the collector's type
680-
first_collector_input_edge = self._get_input_edges(input_node.id, ITEM_FIELD)[0]
685+
first_collector_input_edge = collector_inputs[0]
681686
first_collector_input_type = get_output_field_type(
682687
self.get_node(first_collector_input_edge.source.node_id), first_collector_input_edge.source.field
683688
)
@@ -751,21 +756,12 @@ def nx_graph(self) -> nx.DiGraph:
751756
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
752757
return g
753758

754-
def nx_graph_with_data(self) -> nx.DiGraph:
755-
"""Returns a NetworkX DiGraph representing the data and layout of this graph"""
756-
g = nx.DiGraph()
757-
g.add_nodes_from(list(self.nodes.items()))
758-
g.add_edges_from({(e.source.node_id, e.destination.node_id) for e in self.edges})
759-
return g
760-
761759
def nx_graph_flat(self, nx_graph: Optional[nx.DiGraph] = None) -> nx.DiGraph:
762760
"""Returns a flattened NetworkX DiGraph, including all subgraphs (but not with iterations expanded)"""
763761
g = nx_graph or nx.DiGraph()
764762

765763
# Add all nodes from this graph except graph/iteration nodes
766-
g.add_nodes_from([n.id for n in self.nodes.values() if not isinstance(n, IterateInvocation)])
767-
768-
# TODO: figure out if iteration nodes need to be expanded
764+
g.add_nodes_from([n.id for n in self.nodes.values()])
769765

770766
unique_edges = {(e.source.node_id, e.destination.node_id) for e in self.edges}
771767
g.add_edges_from(unique_edges)
@@ -816,10 +812,57 @@ class GraphExecutionState(BaseModel):
816812
# Optional priority; others follow in name order
817813
ready_order: list[str] = Field(default_factory=list)
818814
indegree: dict[str, int] = Field(default_factory=dict, description="Remaining unmet input count for exec nodes")
815+
_iteration_path_cache: dict[str, tuple[int, ...]] = PrivateAttr(default_factory=dict)
819816

820817
def _type_key(self, node_obj: BaseInvocation) -> str:
821818
return node_obj.__class__.__name__
822819

820+
def _get_iteration_path(self, exec_node_id: str) -> tuple[int, ...]:
821+
"""Best-effort outer->inner iteration indices for an execution node, stopping at collectors."""
822+
cached = self._iteration_path_cache.get(exec_node_id)
823+
if cached is not None:
824+
return cached
825+
826+
# Only prepared execution nodes participate; otherwise treat as non-iterated.
827+
source_node_id = self.prepared_source_mapping.get(exec_node_id)
828+
if source_node_id is None:
829+
self._iteration_path_cache[exec_node_id] = ()
830+
return ()
831+
832+
# Source-graph iterator ancestry, with edges into collectors removed so iteration context doesn't leak.
833+
it_g = self._iterator_graph(self.graph.nx_graph())
834+
iterator_sources = [
835+
n for n in nx.ancestors(it_g, source_node_id) if isinstance(self.graph.get_node(n), IterateInvocation)
836+
]
837+
838+
# Order iterators outer->inner via topo order of the iterator graph.
839+
topo = list(nx.topological_sort(it_g))
840+
topo_index = {n: i for i, n in enumerate(topo)}
841+
iterator_sources.sort(key=lambda n: topo_index.get(n, 0))
842+
843+
# Map iterator source nodes to the prepared iterator exec nodes that are ancestors of exec_node_id.
844+
eg = self.execution_graph.nx_graph()
845+
path: list[int] = []
846+
for it_src in iterator_sources:
847+
prepared = self.source_prepared_mapping.get(it_src)
848+
if not prepared:
849+
continue
850+
it_exec = next((p for p in prepared if nx.has_path(eg, p, exec_node_id)), None)
851+
if it_exec is None:
852+
continue
853+
it_node = self.execution_graph.nodes.get(it_exec)
854+
if isinstance(it_node, IterateInvocation):
855+
path.append(it_node.index)
856+
857+
# If this exec node is itself an iterator, include its own index as the innermost element.
858+
node_obj = self.execution_graph.nodes.get(exec_node_id)
859+
if isinstance(node_obj, IterateInvocation):
860+
path.append(node_obj.index)
861+
862+
result = tuple(path)
863+
self._iteration_path_cache[exec_node_id] = result
864+
return result
865+
823866
def _queue_for(self, cls_name: str) -> Deque[str]:
824867
q = self._ready_queues.get(cls_name)
825868
if q is None:
@@ -843,7 +886,15 @@ def _enqueue_if_ready(self, nid: str) -> None:
843886
if self.indegree[nid] != 0 or nid in self.executed:
844887
return
845888
node_obj = self.execution_graph.nodes[nid]
846-
self._queue_for(self._type_key(node_obj)).append(nid)
889+
q = self._queue_for(self._type_key(node_obj))
890+
nid_path = self._get_iteration_path(nid)
891+
# Insert in lexicographic outer->inner order; preserve FIFO for equal paths.
892+
for i, existing in enumerate(q):
893+
if self._get_iteration_path(existing) > nid_path:
894+
q.insert(i, nid)
895+
break
896+
else:
897+
q.append(nid)
847898

848899
model_config = ConfigDict(
849900
json_schema_extra={
@@ -1083,12 +1134,12 @@ def no_unexecuted_iter_ancestors(n: str) -> bool:
10831134

10841135
# Select the correct prepared parents for each iteration
10851136
# For every iterator, the parent must either not be a child of that iterator, or must match the prepared iteration for that iterator
1086-
# TODO: Handle a node mapping to none
10871137
eg = self.execution_graph.nx_graph_flat()
10881138
prepared_parent_mappings = [
10891139
[(n, self._get_iteration_node(n, g, eg, it)) for n in next_node_parents]
10901140
for it in iterator_node_prepared_combinations
10911141
] # type: ignore
1142+
prepared_parent_mappings = [m for m in prepared_parent_mappings if all(p[1] is not None for p in m)]
10921143

10931144
# Create execution node for each iteration
10941145
for iteration_mappings in prepared_parent_mappings:
@@ -1110,15 +1161,17 @@ def _get_iteration_node(
11101161
if len(prepared_nodes) == 1:
11111162
return next(iter(prepared_nodes))
11121163

1113-
# Check if the requested node is an iterator
1114-
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
1115-
if prepared_iterator is not None:
1116-
return prepared_iterator
1117-
11181164
# Filter to only iterator nodes that are a parent of the specified node, in tuple format (prepared, source)
11191165
iterator_source_node_mapping = [(n, self.prepared_source_mapping[n]) for n in prepared_iterator_nodes]
11201166
parent_iterators = [itn for itn in iterator_source_node_mapping if nx.has_path(graph, itn[1], source_node_id)]
11211167

1168+
# If the requested node is an iterator, only accept it if it is compatible with all parent iterators
1169+
prepared_iterator = next((n for n in prepared_nodes if n in prepared_iterator_nodes), None)
1170+
if prepared_iterator is not None:
1171+
if all(nx.has_path(execution_graph, pit[0], prepared_iterator) for pit in parent_iterators):
1172+
return prepared_iterator
1173+
return None
1174+
11221175
return next(
11231176
(n for n in prepared_nodes if all(nx.has_path(execution_graph, pit[0], n) for pit in parent_iterators)),
11241177
None,
@@ -1156,11 +1209,10 @@ def _prepare_inputs(self, node: BaseInvocation):
11561209
# Inputs must be deep-copied, else if a node mutates the object, other nodes that get the same input
11571210
# will see the mutation.
11581211
if isinstance(node, CollectInvocation):
1159-
output_collection = [
1160-
copydeep(getattr(self.results[edge.source.node_id], edge.source.field))
1161-
for edge in input_edges
1162-
if edge.destination.field == ITEM_FIELD
1163-
]
1212+
item_edges = [e for e in input_edges if e.destination.field == ITEM_FIELD]
1213+
item_edges.sort(key=lambda e: (self._get_iteration_path(e.source.node_id), e.source.node_id))
1214+
1215+
output_collection = [copydeep(getattr(self.results[e.source.node_id], e.source.field)) for e in item_edges]
11641216
node.collection = output_collection
11651217
else:
11661218
for edge in input_edges:

invokeai/frontend/web/public/locales/en.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -974,6 +974,8 @@
974974
"loraModels": "LoRAs",
975975
"main": "Main",
976976
"metadata": "Metadata",
977+
"missingFiles": "Missing Files",
978+
"missingFilesTooltip": "Model files are missing from disk",
977979
"model": "Model",
978980
"modelConversionFailed": "Model Conversion Failed",
979981
"modelConverted": "Model Converted",

invokeai/frontend/web/src/features/modelManagerV2/models.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ import {
2222
} from 'services/api/types';
2323
import { objectEntries } from 'tsafe';
2424

25-
import type { FilterableModelType } from './store/modelManagerV2Slice';
25+
import type { ModelCategoryType } from './store/modelManagerV2Slice';
2626

2727
export type ModelCategoryData = {
28-
category: FilterableModelType;
28+
category: ModelCategoryType;
2929
i18nKey: string;
3030
filter: (config: AnyModelConfig) => boolean;
3131
};
3232

33-
export const MODEL_CATEGORIES: Record<FilterableModelType, ModelCategoryData> = {
33+
export const MODEL_CATEGORIES: Record<ModelCategoryType, ModelCategoryData> = {
3434
unknown: {
3535
category: 'unknown',
3636
i18nKey: 'common.unknown',

invokeai/frontend/web/src/features/modelManagerV2/store/modelManagerV2Slice.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ import { zModelType } from 'features/nodes/types/common';
77
import { assert } from 'tsafe';
88
import z from 'zod';
99

10-
const zFilterableModelType = zModelType.exclude(['onnx']).or(z.literal('refiner'));
10+
const zModelCategoryType = zModelType.exclude(['onnx']).or(z.literal('refiner'));
11+
export type ModelCategoryType = z.infer<typeof zModelCategoryType>;
12+
13+
const zFilterableModelType = zModelCategoryType.or(z.literal('missing'));
1114
export type FilterableModelType = z.infer<typeof zFilterableModelType>;
1215

1316
const zModelManagerState = z.object({
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import type { PropsWithChildren } from 'react';
2+
import { createContext, useContext, useMemo } from 'react';
3+
import { modelConfigsAdapterSelectors, useGetMissingModelsQuery } from 'services/api/endpoints/models';
4+
5+
type MissingModelsContextValue = {
6+
missingModelKeys: Set<string>;
7+
isLoading: boolean;
8+
};
9+
10+
const MissingModelsContext = createContext<MissingModelsContextValue>({
11+
missingModelKeys: new Set(),
12+
isLoading: false,
13+
});
14+
15+
export const MissingModelsProvider = ({ children }: PropsWithChildren) => {
16+
const { data, isLoading } = useGetMissingModelsQuery();
17+
18+
const value = useMemo(() => {
19+
const missingModels = modelConfigsAdapterSelectors.selectAll(data ?? { ids: [], entities: {} });
20+
const missingModelKeys = new Set(missingModels.map((m) => m.key));
21+
return { missingModelKeys, isLoading };
22+
}, [data, isLoading]);
23+
24+
return <MissingModelsContext.Provider value={value}>{children}</MissingModelsContext.Provider>;
25+
};
26+
27+
const useMissingModels = () => useContext(MissingModelsContext);
28+
29+
export const useIsModelMissing = (modelKey: string) => {
30+
const { missingModelKeys } = useMissingModels();
31+
return missingModelKeys.has(modelKey);
32+
};

0 commit comments

Comments
 (0)