
Commit 1a37d14

refactor(6): fully batched initial scan
1 parent f960245 commit 1a37d14

8 files changed: +255 -220 lines


app/_assets_helpers.py

Lines changed: 1 addition & 1 deletion
@@ -97,7 +97,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
     root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
     p = Path(some_path)
     parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
-    return p.name, normalize_tags([root_category, *parent_parts])
+    return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
 
 
 def normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
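The new return expression dedupes tags while preserving their first-seen order. A minimal sketch of the dict.fromkeys idiom in isolation (tag values are illustrative):

# dict keys are unique and, since Python 3.7, preserve insertion order,
# so round-tripping through dict.fromkeys dedupes without reordering.
tags = ["checkpoints", "sdxl", "checkpoints", "sdxl", "photo"]
assert list(dict.fromkeys(tags)) == ["checkpoints", "sdxl", "photo"]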

app/assets_scanner.py

Lines changed: 19 additions & 41 deletions
@@ -12,6 +12,7 @@
 
 from ._assets_helpers import (
     collect_models_files,
+    compute_relative_filename,
     get_comfy_models_folders,
     get_name_and_tags_from_asset_path,
     list_tree,
@@ -26,9 +27,8 @@
     ensure_tags_exist,
     escape_like_prefix,
     fast_asset_file_check,
-    insert_meta_from_batch,
-    insert_tags_from_batch,
     remove_missing_tag_for_asset_id,
+    seed_from_paths_batch,
 )
 from .database.models import Asset, AssetCacheState, AssetInfo
 from .database.services import (
@@ -37,7 +37,6 @@
     list_cache_states_with_asset_under_prefixes,
     list_unhashed_candidates_under_prefixes,
     list_verify_candidates_under_prefixes,
-    seed_from_path,
 )
 
 LOGGER = logging.getLogger(__name__)
@@ -121,62 +120,41 @@ async def sync_seed_assets(roots: list[schemas_in.RootType]) -> None:
         if "output" in roots:
             paths.extend(list_tree(folder_paths.get_output_directory()))
 
-        new_specs: list[tuple[str, int, int, str, list[str]]] = []
+        specs: list[dict] = []
         tag_pool: set[str] = set()
         for p in paths:
            ap = os.path.abspath(p)
            if ap in existing_paths:
                skipped_existing += 1
                continue
            try:
-                st = os.stat(p, follow_symlinks=True)
+                st = os.stat(ap, follow_symlinks=True)
            except OSError:
                continue
-            if not int(st.st_size or 0):
+            if not st.st_size:
                continue
            name, tags = get_name_and_tags_from_asset_path(ap)
-            new_specs.append((
-                ap,
-                int(st.st_size),
-                getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
-                name,
-                tags,
-            ))
+            specs.append(
+                {
+                    "abs_path": ap,
+                    "size_bytes": st.st_size,
+                    "mtime_ns": getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)),
+                    "info_name": name,
+                    "tags": tags,
+                    "fname": compute_relative_filename(ap),
+                }
+            )
            for t in tags:
                tag_pool.add(t)
 
+        if not specs:
+            return
         async with await create_session() as sess:
            if tag_pool:
                await ensure_tags_exist(sess, tag_pool, tag_type="user")
 
-            pending_tag_links: list[dict] = []
-            pending_meta_rows: list[dict] = []
-            for ap, sz, mt, name, tags in new_specs:
-                await seed_from_path(
-                    sess,
-                    abs_path=ap,
-                    size_bytes=sz,
-                    mtime_ns=mt,
-                    info_name=name,
-                    tags=tags,
-                    owner_id="",
-                    collected_tag_rows=pending_tag_links,
-                    collected_meta_rows=pending_meta_rows,
-                )
-
-                created += 1
-                if created % 500 == 0:
-                    if pending_tag_links:
-                        await insert_tags_from_batch(sess, tag_rows=pending_tag_links)
-                        pending_tag_links.clear()
-                    if pending_meta_rows:
-                        await insert_meta_from_batch(sess, rows=pending_meta_rows)
-                        pending_meta_rows.clear()
-                    await sess.commit()
-            if pending_tag_links:
-                await insert_tags_from_batch(sess, tag_rows=pending_tag_links)
-            if pending_meta_rows:
-                await insert_meta_from_batch(sess, rows=pending_meta_rows)
+            result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
+            created += result["inserted_infos"]
            await sess.commit()
     finally:
         LOGGER.info(
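With this change the scanner makes one stat pass to build plain dict specs, then hands the whole list to the database layer in a single call instead of seeding one path at a time. A hedged sketch of the calling convention (values are illustrative; the expected keys match the seed_from_paths_batch docstring in bulk_ops.py below):

specs = [
    {
        "abs_path": "/models/checkpoints/example.safetensors",  # illustrative path
        "size_bytes": 1024,
        "mtime_ns": 1_700_000_000_000_000_000,
        "info_name": "example.safetensors",
        "tags": ["checkpoints"],
        "fname": "checkpoints/example.safetensors",
    },
]
result = await seed_from_paths_batch(sess, specs=specs, owner_id="")
# result is a summary dict: {"inserted_infos": ..., "won_states": ..., "lost_states": ...}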

app/database/helpers/__init__.py

Lines changed: 2 additions & 4 deletions
@@ -1,13 +1,12 @@
+from .bulk_ops import seed_from_paths_batch
 from .escape_like import escape_like_prefix
 from .fast_check import fast_asset_file_check
 from .filters import apply_metadata_filter, apply_tag_filters
-from .meta import insert_meta_from_batch
 from .ownership import visible_owner_clause
 from .projection import is_scalar, project_kv
 from .tags import (
     add_missing_tag_for_asset_id,
     ensure_tags_exist,
-    insert_tags_from_batch,
     remove_missing_tag_for_asset_id,
 )
 
@@ -21,7 +20,6 @@
     "ensure_tags_exist",
     "add_missing_tag_for_asset_id",
     "remove_missing_tag_for_asset_id",
-    "insert_meta_from_batch",
-    "insert_tags_from_batch",
+    "seed_from_paths_batch",
     "visible_owner_clause",
 ]
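Because the batch seeder is re-exported from the package __init__, callers import it alongside the other helpers rather than reaching into bulk_ops directly. Assuming the package resolves as app.database.helpers (the path shown in this commit):

from app.database.helpers import seed_from_paths_batch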

app/database/helpers/bulk_ops.py

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
+import os
+import uuid
+from typing import Iterable, Sequence
+
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql as d_pg
+from sqlalchemy.dialects import sqlite as d_sqlite
+from sqlalchemy.ext.asyncio import AsyncSession
+
+from ..models import Asset, AssetCacheState, AssetInfo, AssetInfoMeta, AssetInfoTag
+from ..timeutil import utcnow
+
+
+MAX_BIND_PARAMS = 800
+
+
+async def seed_from_paths_batch(
+    session: AsyncSession,
+    *,
+    specs: Sequence[dict],
+    owner_id: str = "",
+) -> dict:
+    """Each spec is a dict with keys:
+    - abs_path: str
+    - size_bytes: int
+    - mtime_ns: int
+    - info_name: str
+    - tags: list[str]
+    - fname: Optional[str]
+    """
+    if not specs:
+        return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
+
+    now = utcnow()
+    dialect = session.bind.dialect.name
+    if dialect not in ("sqlite", "postgresql"):
+        raise NotImplementedError(f"Unsupported database dialect: {dialect}")
+
+    asset_rows: list[dict] = []
+    state_rows: list[dict] = []
+    path_to_asset: dict[str, str] = {}
+    asset_to_info: dict[str, dict] = {}  # asset_id -> prepared info row
+    path_list: list[str] = []
+
+    for sp in specs:
+        ap = os.path.abspath(sp["abs_path"])
+        aid = str(uuid.uuid4())
+        iid = str(uuid.uuid4())
+        path_list.append(ap)
+        path_to_asset[ap] = aid
+
+        asset_rows.append(
+            {
+                "id": aid,
+                "hash": None,
+                "size_bytes": sp["size_bytes"],
+                "mime_type": None,
+                "created_at": now,
+            }
+        )
+        state_rows.append(
+            {
+                "asset_id": aid,
+                "file_path": ap,
+                "mtime_ns": sp["mtime_ns"],
+            }
+        )
+        asset_to_info[aid] = {
+            "id": iid,
+            "owner_id": owner_id,
+            "name": sp["info_name"],
+            "asset_id": aid,
+            "preview_id": None,
+            "user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
+            "created_at": now,
+            "updated_at": now,
+            "last_access_time": now,
+            "_tags": sp["tags"],
+            "_filename": sp["fname"],
+        }
+
+    # insert all seed Assets (hash=NULL)
+    ins_asset = d_sqlite.insert(Asset) if dialect == "sqlite" else d_pg.insert(Asset)
+    for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
+        await session.execute(ins_asset, chunk)
+
+    # try to claim AssetCacheState (file_path)
+    winners_by_path: set[str] = set()
+    if dialect == "sqlite":
+        ins_state = (
+            d_sqlite.insert(AssetCacheState)
+            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
+            .returning(AssetCacheState.file_path)
+        )
+    else:
+        ins_state = (
+            d_pg.insert(AssetCacheState)
+            .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
+            .returning(AssetCacheState.file_path)
+        )
+    for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
+        winners_by_path.update((await session.execute(ins_state, chunk)).scalars().all())
+
+    all_paths_set = set(path_list)
+    losers_by_path = all_paths_set - winners_by_path
+    lost_assets = [path_to_asset[p] for p in losers_by_path]
+    if lost_assets:  # losers get their Asset removed
+        for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
+            await session.execute(sa.delete(Asset).where(Asset.id.in_(id_chunk)))
+
+    if not winners_by_path:
+        return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
+
+    # insert AssetInfo only for winners
+    winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
+    if dialect == "sqlite":
+        ins_info = (
+            d_sqlite.insert(AssetInfo)
+            .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
+            .returning(AssetInfo.id)
+        )
+    else:
+        ins_info = (
+            d_pg.insert(AssetInfo)
+            .on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
+            .returning(AssetInfo.id)
+        )
+
+    inserted_info_ids: set[str] = set()
+    for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
+        # drop the "_tags"/"_filename" helper keys: they are not AssetInfo
+        # columns, and executemany rejects unconsumed parameter names
+        db_rows = [{k: v for k, v in r.items() if not k.startswith("_")} for r in chunk]
+        inserted_info_ids.update((await session.execute(ins_info, db_rows)).scalars().all())
+
+    # build and insert tag + meta rows for the AssetInfo
+    tag_rows: list[dict] = []
+    meta_rows: list[dict] = []
+    if inserted_info_ids:
+        for row in winner_info_rows:
+            iid = row["id"]
+            if iid not in inserted_info_ids:
+                continue
+            for t in row["_tags"]:
+                tag_rows.append({
+                    "asset_info_id": iid,
+                    "tag_name": t,
+                    "origin": "automatic",
+                    "added_at": now,
+                })
+            if row["_filename"]:
+                meta_rows.append(
+                    {
+                        "asset_info_id": iid,
+                        "key": "filename",
+                        "ordinal": 0,
+                        "val_str": row["_filename"],
+                        "val_num": None,
+                        "val_bool": None,
+                        "val_json": None,
+                    }
+                )
+
+    await bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
+    return {
+        "inserted_infos": len(inserted_info_ids),
+        "won_states": len(winners_by_path),
+        "lost_states": len(losers_by_path),
+    }
+
+
+async def bulk_insert_tags_and_meta(
+    session: AsyncSession,
+    *,
+    tag_rows: list[dict],
+    meta_rows: list[dict],
+    max_bind_params: int,
+) -> None:
+    """Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
+    - tag_rows keys: asset_info_id, tag_name, origin, added_at
+    - meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
+    """
+    dialect = session.bind.dialect.name
+    if tag_rows:
+        if dialect == "sqlite":
+            ins_links = (
+                d_sqlite.insert(AssetInfoTag)
+                .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
+            )
+        elif dialect == "postgresql":
+            ins_links = (
+                d_pg.insert(AssetInfoTag)
+                .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
+            )
+        else:
+            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
+        for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
+            await session.execute(ins_links, chunk)
+    if meta_rows:
+        if dialect == "sqlite":
+            ins_meta = (
+                d_sqlite.insert(AssetInfoMeta)
+                .on_conflict_do_nothing(
+                    index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
+                )
+            )
+        elif dialect == "postgresql":
+            ins_meta = (
+                d_pg.insert(AssetInfoMeta)
+                .on_conflict_do_nothing(
+                    index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
+                )
+            )
+        else:
+            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
+        for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
+            await session.execute(ins_meta, chunk)
+
+
+def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
+    if not rows:
+        return []
+    rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
+    for i in range(0, len(rows), rows_per_stmt):
+        yield rows[i:i + rows_per_stmt]
+
+
+def _iter_chunks(seq, n: int):
+    for i in range(0, len(seq), n):
+        yield seq[i:i + n]
+
+
+def _rows_per_stmt(cols: int) -> int:
+    return max(1, MAX_BIND_PARAMS // max(1, cols))
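The concurrency story in seed_from_paths_batch rests on one trick: INSERT ... ON CONFLICT DO NOTHING ... RETURNING acts as an atomic claim on the unique file_path column, so a writer "wins" a path exactly when RETURNING hands its row back, and losers silently get nothing. A minimal self-contained sketch of that pattern (table and column names are illustrative stand-ins, not the project's models; assumes the aiosqlite driver and SQLite 3.35+ for RETURNING support):

import asyncio

import sqlalchemy as sa
from sqlalchemy.dialects import sqlite as d_sqlite
from sqlalchemy.ext.asyncio import create_async_engine

metadata = sa.MetaData()
# Illustrative stand-in for AssetCacheState: UNIQUE file_path means
# at most one inserted row can ever exist per path.
states = sa.Table(
    "states",
    metadata,
    sa.Column("asset_id", sa.String, nullable=False),
    sa.Column("file_path", sa.String, nullable=False, unique=True),
)


async def claim(conn, rows):
    stmt = (
        d_sqlite.insert(states)
        .on_conflict_do_nothing(index_elements=[states.c.file_path])
        .returning(states.c.file_path)
    )
    # RETURNING emits only the rows actually inserted, so the result
    # is exactly the set of paths this caller "won".
    return set((await conn.execute(stmt, rows)).scalars().all())


async def main():
    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
    async with engine.begin() as conn:
        await conn.run_sync(metadata.create_all)
        first = await claim(conn, [{"asset_id": "a1", "file_path": "/x"}])
        second = await claim(conn, [{"asset_id": "a2", "file_path": "/x"}])
        print(first, second)  # {'/x'} set() -- the second writer lost the race


asyncio.run(main())

The chunking helpers exist to keep each statement under driver bind-parameter limits: with MAX_BIND_PARAMS = 800 and three bound columns per cache-state row, _rows_per_stmt(3) sends at most 266 rows per INSERT, comfortably below SQLite's historical default cap of 999 variables per statement.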
