Skip to content

Commit e8aed67

Browse files
feat(app): add workflow library get_counts method
Get the counts of workflows for the given tags and/or categories. Made a separate method bc get_many will deserialize all matching workflows, which is unnecessary for this use case.
1 parent f56dd01 commit e8aed67

File tree

3 files changed

+74
-6
lines changed

3 files changed

+74
-6
lines changed

invokeai/app/api/routers/workflows.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,13 @@ async def get_workflow_thumbnail(
219219
return response
220220
except Exception:
221221
raise HTTPException(status_code=404)
222+
223+
224+
@workflows_router.get("/counts", operation_id="get_counts")
225+
async def get_counts(
226+
tags: Optional[list[str]] = Query(default=None, description="The tags to include"),
227+
categories: Optional[list[WorkflowCategory]] = Query(default=None, description="The categories to include"),
228+
) -> int:
229+
"""Gets a the count of workflows that include the specified tags and categories"""
230+
231+
return ApiDependencies.invoker.services.workflow_records.get_counts(tags=tags, categories=categories)

invokeai/app/services/workflow_records/workflow_records_base.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,3 +49,12 @@ def get_many(
4949
) -> PaginatedResults[WorkflowRecordListItemDTO]:
5050
"""Gets many workflows."""
5151
pass
52+
53+
@abstractmethod
54+
def get_counts(
55+
self,
56+
tags: Optional[list[str]],
57+
categories: Optional[list[WorkflowCategory]],
58+
) -> int:
59+
"""Gets the count of workflows for the given tags and categories."""
60+
pass

invokeai/app/services/workflow_records/workflow_records_sqlite.py

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,55 @@ def get_many(
237237
total=total,
238238
)
239239

240+
def get_counts(
241+
self,
242+
tags: Optional[list[str]],
243+
categories: Optional[list[WorkflowCategory]],
244+
) -> int:
245+
cursor = self._conn.cursor()
246+
247+
# Start with an empty list of conditions and params
248+
conditions: list[str] = []
249+
params: list[str | int] = []
250+
251+
if tags:
252+
# Construct a list of conditions for each tag
253+
tags_conditions = ["tags LIKE ?" for _ in tags]
254+
tags_conditions_joined = " OR ".join(tags_conditions)
255+
tags_condition = f"({tags_conditions_joined})"
256+
257+
# And the params for the tags, case-insensitive
258+
tags_params = [f"%{t.strip()}%" for t in tags]
259+
260+
conditions.append(tags_condition)
261+
params.extend(tags_params)
262+
263+
if categories:
264+
# Ensure all categories are valid (is this necessary?)
265+
assert all(c in WorkflowCategory for c in categories)
266+
267+
# Construct a placeholder string for the number of categories
268+
placeholders = ", ".join("?" for _ in categories)
269+
270+
# Construct the condition string & params
271+
conditions.append(f"category IN ({placeholders})")
272+
params.extend([category.value for category in categories])
273+
274+
stmt = """--sql
275+
SELECT COUNT(*)
276+
FROM workflow_library
277+
"""
278+
279+
if conditions:
280+
# If there are conditions, add a WHERE clause and then join the conditions
281+
stmt += " WHERE "
282+
283+
all_conditions = " AND ".join(conditions)
284+
stmt += all_conditions
285+
286+
cursor.execute(stmt, tuple(params))
287+
return cursor.fetchone()[0]
288+
240289
def _sync_default_workflows(self) -> None:
241290
"""Syncs default workflows to the database. Internal use only."""
242291

@@ -261,13 +310,13 @@ def _sync_default_workflows(self) -> None:
261310
bytes_ = path.read_bytes()
262311
workflow_from_file = WorkflowValidator.validate_json(bytes_)
263312

264-
assert workflow_from_file.id.startswith("default_"), (
265-
f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
266-
)
313+
assert workflow_from_file.id.startswith(
314+
"default_"
315+
), f'Invalid default workflow ID (must start with "default_"): {workflow_from_file.id}'
267316

268-
assert workflow_from_file.meta.category is WorkflowCategory.Default, (
269-
f"Invalid default workflow category: {workflow_from_file.meta.category}"
270-
)
317+
assert (
318+
workflow_from_file.meta.category is WorkflowCategory.Default
319+
), f"Invalid default workflow category: {workflow_from_file.meta.category}"
271320

272321
workflows_from_file.append(workflow_from_file)
273322

0 commit comments

Comments
 (0)