Skip to content

Commit b39774a

Browse files
feat(app): add searching by tags to workflow library APIs
1 parent 8988539 commit b39774a

File tree

4 files changed

+72
-33
lines changed

4 files changed

+72
-33
lines changed

invokeai/app/api/routers/workflows.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,13 @@ async def list_workflows(
109109
"""Gets a page of workflows"""
110110
workflows_with_thumbnails: list[WorkflowRecordListItemWithThumbnailDTO] = []
111111
workflows = ApiDependencies.invoker.services.workflow_records.get_many(
112-
order_by=order_by, direction=direction, page=page, per_page=per_page, query=query, categories=categories
112+
order_by=order_by,
113+
direction=direction,
114+
page=page,
115+
per_page=per_page,
116+
query=query,
117+
categories=categories,
118+
tags=tags,
113119
)
114120
for workflow in workflows.items:
115121
workflows_with_thumbnails.append(

invokeai/app/services/workflow_records/workflow_records_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def get_many(
4545
page: int,
4646
per_page: Optional[int],
4747
query: Optional[str],
48+
tags: Optional[list[str]],
4849
) -> PaginatedResults[WorkflowRecordListItemDTO]:
4950
"""Gets many workflows."""
5051
pass

invokeai/app/services/workflow_records/workflow_records_common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def from_dict(cls, data: dict[str, Any]) -> "WorkflowRecordDTO":
116116
class WorkflowRecordListItemDTO(WorkflowRecordDTOBase):
117117
description: str = Field(description="The description of the workflow.")
118118
category: WorkflowCategory = Field(description="The description of the workflow.")
119+
tags: str = Field(description="The tags of the workflow.")
119120

120121

121122
WorkflowRecordListItemDTOValidator = TypeAdapter(WorkflowRecordListItemDTO)

invokeai/app/services/workflow_records/workflow_records_sqlite.py

Lines changed: 63 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -124,64 +124,95 @@ def get_many(
124124
page: int = 0,
125125
per_page: Optional[int] = None,
126126
query: Optional[str] = None,
127+
tags: Optional[list[str]] = None,
127128
) -> PaginatedResults[WorkflowRecordListItemDTO]:
128129
# sanitize!
129130
assert order_by in WorkflowRecordOrderBy
130131
assert direction in SQLiteDirection
131132

132-
main_params: list[int | str] = []
133-
count_params: list[int | str] = []
133+
# We will construct the query dynamically based on the query params
134134

135-
if categories:
136-
assert all(c in WorkflowCategory for c in categories)
137-
question_marks = ", ".join("?" for _ in categories)
138-
count_query = f"SELECT COUNT(*) FROM workflow_library WHERE category IN ({question_marks})"
139-
main_query = f"""
140-
SELECT
141-
workflow_id,
142-
category,
143-
name,
144-
description,
145-
created_at,
146-
updated_at,
147-
opened_at
148-
FROM workflow_library
149-
WHERE category IN ({question_marks})
150-
"""
151-
main_params.extend([category.value for category in categories])
152-
count_params.extend([category.value for category in categories])
153-
else:
154-
count_query = "SELECT COUNT(*) FROM workflow_library"
155-
main_query = """
135+
# The main query to get the workflows / counts
136+
main_query = """
156137
SELECT
157138
workflow_id,
158139
category,
159140
name,
160141
description,
161142
created_at,
162143
updated_at,
163-
opened_at
144+
opened_at,
145+
tags
164146
FROM workflow_library
165147
"""
148+
count_query = "SELECT COUNT(*) FROM workflow_library"
149+
150+
# Start with an empty list of conditions and params
151+
conditions: list[str] = []
152+
params: list[str | int] = []
153+
154+
if categories:
155+
# Categories is a list of WorkflowCategory enum values, and a single string in the DB
156+
157+
# Ensure all categories are valid (is this necessary?)
158+
assert all(c in WorkflowCategory for c in categories)
159+
160+
# Construct a placeholder string for the number of categories
161+
placeholders = ", ".join("?" for _ in categories)
162+
163+
# Construct the condition string & params
164+
category_condition = f"category IN ({placeholders})"
165+
category_params = [category.value for category in categories]
166+
167+
conditions.append(category_condition)
168+
params.extend(category_params)
169+
170+
if tags:
171+
# Tags is a list of strings, and a single string in the DB
172+
# The string in the DB has no guaranteed format
166173

174+
# Construct a list of conditions for each tag
175+
tags_conditions = ["tags LIKE ?" for _ in tags]
176+
tags_conditions_joined = " OR ".join(tags_conditions)
177+
tags_condition = f"({tags_conditions_joined})"
178+
179+
# And the params for the tags, case-insensitive
180+
tags_params = [f"%{t.strip()}%" for t in tags]
181+
182+
conditions.append(tags_condition)
183+
params.extend(tags_params)
184+
185+
# Ignore whitespace in the query
167186
stripped_query = query.strip() if query else None
168187
if stripped_query:
188+
# Construct a wildcard query for the name, description, and tags
169189
wildcard_query = "%" + stripped_query + "%"
170-
if categories:
171-
main_query += " AND (name LIKE ? OR description LIKE ?) "
172-
count_query += " AND (name LIKE ? OR description LIKE ?)"
173-
else:
174-
main_query += " WHERE name LIKE ? OR description LIKE ? "
175-
count_query += " WHERE name LIKE ? OR description LIKE ?"
176-
main_params.extend([wildcard_query, wildcard_query])
177-
count_params.extend([wildcard_query, wildcard_query])
190+
query_condition = "(name LIKE ? OR description LIKE ? OR tags LIKE ?)"
191+
192+
conditions.append(query_condition)
193+
params.extend([wildcard_query])
194+
195+
if conditions:
196+
# If there are conditions, add a WHERE clause and then join the conditions
197+
main_query += " WHERE "
198+
count_query += " WHERE "
199+
200+
all_conditions = " AND ".join(conditions)
201+
main_query += all_conditions
202+
count_query += all_conditions
203+
204+
# After this point, the query and params differ for the main query and the count query
205+
main_params = params.copy()
206+
count_params = params.copy()
178207

208+
# Main query also gets ORDER BY and LIMIT/OFFSET
179209
main_query += f" ORDER BY {order_by.value} {direction.value}"
180210

181211
if per_page:
182212
main_query += " LIMIT ? OFFSET ?"
183213
main_params.extend([per_page, page * per_page])
184214

215+
# Put a ring on it
185216
main_query += ";"
186217
count_query += ";"
187218

0 commit comments

Comments
 (0)