-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathpng_info_to_captions.py
More file actions
311 lines (251 loc) · 10.6 KB
/
png_info_to_captions.py
File metadata and controls
311 lines (251 loc) · 10.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
png_info_to_captions.py - rewritten to rely on read_write_metadata.py for all metadata I/O.
Key changes:
- Removed ad-hoc metadata parsing with Pillow. All metadata access now goes through read_write_metadata.read_metadata().
- Kept prompt parsing and simplification logic, but applied only to the 'parameters' field returned by read_metadata().
- Threaded processing preserved, with an explicit lock passed to worker to avoid globals.
- Safer handling of missing/invalid metadata; invisible characters stripped before parsing.
- CLI flags kept compatible with the previous script.
"""
import argparse
import os
import re
import string
import sys
import threading
from pathlib import Path
from typing import List
from tqdm import tqdm
# Local import of the unified metadata helper
# Make sure read_write_metadata.py is in the same folder or on PYTHONPATH.
import read_write_metadata as rwm
# ---------------- Utilities ----------------
# ASCII printable characters: letters, digits, punctuation and the whitespace
# characters " \t\n\r\x0b\x0c". The explicit "\t\n\r" is already contained in
# string.printable and is kept only for readability.
_allowed_chars = set(string.printable + "\t\n\r")


def remove_invisible_characters(s: str) -> str:
    """Return *s* with invisible and control characters removed.

    Kept: every ASCII character in ``_allowed_chars`` (including tab, newline
    and carriage return) plus every non-ASCII character that
    ``str.isprintable()`` accepts, so prompt text written with accented
    letters or non-Latin scripts survives.

    Dropped: control characters and Unicode "Other"/"Separator" characters
    such as NUL, zero-width spaces or a byte-order mark.
    """
    # Previously only the ASCII whitelist was kept, which silently deleted
    # every non-ASCII (but perfectly visible) character from prompts.
    return "".join(ch for ch in s if ch in _allowed_chars or ch.isprintable())
def flat_prompt(text: str) -> str:
    """Replace every newline ("\\n") in *text* with ", " and return the result."""
    pieces = text.split("\n")
    return ", ".join(pieces)
def parse_parameters(value: str) -> dict:
    """
    Parse Stable Diffusion style 'parameters' text into a dict.

    Expected layout::

        <prompt>
        Negative prompt: <negative prompt>      (optional)
        Steps: 20, Sampler: Euler a, ...

    The text is split at the last 'Steps:' occurrence. Everything before it is
    the prompt section, optionally followed by a 'Negative prompt:' section.
    Everything from 'Steps:' onward is a comma-separated list of 'Key: value'
    items. Commas inside a double-quoted value (e.g. 'Lora hashes: "a: 1, b: 2"')
    do not split items, and such values keep their quotes.

    Returns:
        dict with 'prompt', 'Negative prompt' ('' when that section is absent)
        and then one entry per 'Key: value' item.

    Raises:
        ValueError: if 'Steps:' is not found.
    """
    value = remove_invisible_characters(value)
    steps_index = value.rfind("Steps:")
    if steps_index < 0:
        raise ValueError("Invalid format: 'Steps:' keyword not found.")
    prompt_and_neg_prompt = value[:steps_index].strip()
    other_params = value[steps_index:]

    neg_prompt_index = prompt_and_neg_prompt.rfind("Negative prompt:")
    if neg_prompt_index == -1:
        # No negative-prompt line (e.g. it was empty at generation time):
        # the whole section is the positive prompt. Raising here used to make
        # callers fall back to the raw text, leaking 'Steps: ...' into captions.
        positive_section = prompt_and_neg_prompt
        negative_prompt = ""
    else:
        positive_section = prompt_and_neg_prompt[:neg_prompt_index]
        negative_prompt = (
            prompt_and_neg_prompt[neg_prompt_index:]
            .replace("Negative prompt:", "")
            .strip()
        )
    prompt = positive_section.strip().replace("parameters:", "").strip()

    kv_dict = {}
    # Each match is one item: a run of non-comma characters in which a
    # double-quoted span may itself contain commas. An unmatched quote is
    # treated as an ordinary character.
    for item in re.findall(r'(?:[^,"]|"[^"]*"|")+', other_params):
        item = item.strip()
        if not item:
            continue  # skip blank items (e.g. from a trailing ", ")
        key, _, val = item.partition(":")
        kv_dict[key.strip()] = val.strip()

    result = {"prompt": prompt, "Negative prompt": negative_prompt}
    result.update(kv_dict)
    return result
def simplify_prompt(prompt: str, exclude_patterns: List[str], is_flat_prompt: bool) -> str:
    """
    Reduce a prompt to a plain comma-separated list of tags.

    Steps, in order:
      1. Delete every match of each regex in exclude_patterns.
      2. Optionally replace newlines with ", " (is_flat_prompt).
      3. Strip emphasis/weight syntax: (A:1.2), (A:B:1.2), (A), [A|B], bare
         ':<number>' suffixes, leftover brackets and '|'.
      4. Remove any <...> span (e.g. LoRA tags) and the keyword 'BREAK'.
      5. Normalize comma spacing, drop stray numbers and collapse spaces.

    Backslash-escaped parentheses in the input are protected with placeholder
    words during step 3 and restored as literal '(' and ')' afterwards.

    The substitutions are order-dependent: later patterns assume earlier ones
    have already run, so do not reorder them casually.

    Returns:
        The simplified prompt with leading/trailing whitespace stripped.
    """
    # Remove unwanted text first
    for pattern in exclude_patterns:
        prompt = re.sub(pattern, "", prompt)
    if is_flat_prompt:
        prompt = flat_prompt(prompt)
    # Drop empty zero-weight groups: "(:0)" vanishes, "( :0)" becomes a space.
    prompt = re.sub(r"\(:0\)", "", prompt)
    prompt = re.sub(r"\( :0\)", " ", prompt)
    # Protect backslash-escaped parentheses with placeholder words so the
    # bracket stripping below leaves them alone; restored further down.
    prompt = re.sub(r"\\\(", "bracket_open", prompt)
    prompt = re.sub(r"\\\)", "bracket_close", prompt)
    # Insert comma after brackets if needed: after ")" or "]" when the next
    # character is not a comma, whitespace or "]".
    prompt = re.sub(r"(?<=\))([^,\s\]])", r",\1", prompt)
    prompt = re.sub(r"(?<=\])([^,\s\]])", r",\1", prompt)
    # Remove lora tags like <lora:name:1.0> (any <...> span is removed)
    prompt = re.sub(r"<[^>]+>", "", prompt)
    # (A:B:1.21) -> A, B
    prompt = re.sub(r"\(([^:]+):([^:]+):\d+(\.\d+)?\)", r"\1, \2", prompt)
    # (A:1.21) -> A
    prompt = re.sub(r"\(([^:]+):\d+(\.\d+)?\)", r"\1", prompt)
    # [A|B|C] -> A, B, C
    prompt = re.sub(r"\[([^\]]+)\]", lambda m: ", ".join(m.group(1).split("|")), prompt)
    # (A) -> A  (one pass; only groups whose content has no brackets)
    prompt = re.sub(r"\(([^()]+)\)", r"\1", prompt)
    # A: 1.21 -> A  (drops a ':<number>' suffix)
    prompt = re.sub(r"([^:]+):\s*\d+(\.\d+)?", r"\1", prompt)
    # Cleanup: Remove occurrences of ', ,'
    prompt = re.sub(r",[ ]+,", ", ", prompt)
    # Free '|' -> removed
    prompt = prompt.replace("|", "")
    # Remove remaining brackets
    prompt = prompt.replace("(", "").replace(")", "")
    prompt = prompt.replace("[", "").replace("]", "")
    # ':' not followed by a space or a digit -> removed
    prompt = re.sub(r":(?![ \d])", "", prompt)
    # Remove BREAK
    prompt = prompt.replace("BREAK", "")
    # Normalize commas spacing
    prompt = re.sub(r"\s*,\s*", ", ", prompt)
    # Restore the escaped parentheses protected above as literal "(" / ")".
    # NOTE(review): a prompt that literally contains "bracket_open" or
    # "bracket_close" is converted to brackets here as well.
    prompt = re.sub("bracket_open", r"(", prompt)
    prompt = re.sub("bracket_close", r")", prompt)
    # Cleanup odd patterns and numbers
    prompt = re.sub(r",\s*:\d+(\.\d+)?,", ", ", prompt)  # ', :0.2, '
    prompt = re.sub(r",\s*\d+(\.\d+)?,", ", ", prompt)  # ', 0.7, '
    prompt = re.sub(r"^,|,$", "", prompt)  # leading/trailing comma
    prompt = re.sub(r"^ *, *", "", prompt)
    prompt = re.sub(r" *, *$", "", prompt)
    prompt = re.sub(r"[\n\r]", " ", prompt)
    prompt = re.sub(r", ,", ", ", prompt)
    prompt = re.sub(r"\d+\.\d+", "", prompt)  # remove remaining decimals (even inside tags)
    prompt = re.sub(r" +", " ", prompt)  # collapse spaces
    prompt = re.sub(r",[ ]+,", ",", prompt)
    return prompt.strip()
def extract_prompt_from_parameters_text(parameters_text: str,
                                        exclude_patterns: List[str],
                                        use_original_prompt: bool,
                                        include_all_metadata: bool,
                                        is_flat_prompt: bool) -> str:
    """
    Given the full 'parameters' text, return the caption to write.

    Modes, checked in this order:
    - include_all_metadata: return the whole parameters text as-is.
    - use_original_prompt: return the original 'prompt' section only (no simplification).
    - otherwise: parse the text and simplify the prompt section.

    If the text cannot be parsed (parse_parameters raises ValueError) or the
    parsed prompt is empty, the whole parameters text is returned instead.
    is_flat_prompt (newlines -> ", ") applies to every mode. exclude_patterns
    are only used by the simplification mode.

    Returns "" for empty input.
    """
    if not parameters_text:
        return ""

    def _finish(text: str) -> str:
        # Apply --flat_prompt uniformly; it used to be skipped for the
        # use_original_prompt mode.
        return flat_prompt(text) if is_flat_prompt else text

    # Sanitize any headers or nulls possibly present in EXIF-encoded strings.
    parameters_text = rwm._strip_uc_header_and_nulls_str(parameters_text)

    if include_all_metadata:
        return _finish(parameters_text)

    try:
        parsed = parse_parameters(parameters_text)
    except ValueError:
        # Not SD-style parameters: treat everything as a raw prompt.
        return _finish(parameters_text)

    prompt = (parsed.get("prompt") or "").strip()
    if not prompt:
        # No explicit prompt: return the entire parameters text.
        return _finish(parameters_text)

    if use_original_prompt:
        return _finish(prompt)
    # simplify_prompt() applies the flattening itself, before its cleanup passes.
    return simplify_prompt(prompt, exclude_patterns, is_flat_prompt)
def get_caption_for_image(image_path: Path,
                          exclude_patterns: List[str],
                          use_original_prompt: bool,
                          include_all_metadata: bool,
                          is_flat_prompt: bool,
                          use_existing_txt: bool) -> str:
    """
    Build the caption text for one image.

    If use_existing_txt is set and the sibling .txt file exists and can be
    read as UTF-8, its contents are returned unchanged. Otherwise the image's
    'parameters' metadata is read via read_write_metadata.read_metadata() and
    converted with extract_prompt_from_parameters_text(). Returns "" when the
    metadata has no (or an empty) 'parameters' entry.
    """
    sidecar = image_path.with_suffix(".txt")
    if use_existing_txt and sidecar.exists():
        try:
            existing = sidecar.read_text(encoding="utf-8")
        except Exception:
            existing = None  # unreadable sidecar: fall back to image metadata
        if existing is not None:
            return existing

    # Use the unified metadata reader
    parameters_text = rwm.read_metadata(image_path).get("parameters") or ""
    if parameters_text:
        return extract_prompt_from_parameters_text(
            parameters_text=parameters_text,
            exclude_patterns=exclude_patterns,
            use_original_prompt=use_original_prompt,
            include_all_metadata=include_all_metadata,
            is_flat_prompt=is_flat_prompt,
        )
    return ""
def process_image(image_path: Path,
                  progress_bar,
                  lock: threading.Lock,
                  exclude_patterns: List[str],
                  use_original_prompt: bool,
                  include_all_metadata: bool,
                  is_flat_prompt: bool,
                  use_existing_txt: bool) -> None:
    """
    Worker: compute the caption for one image and write it next to the image.

    The caption goes to the image path with a .txt suffix (UTF-8), and only
    when it is non-empty. Any exception is reported and swallowed so one bad
    file does not stop the batch. The progress bar is advanced exactly once
    per image, under *lock*, whether or not processing succeeded.

    progress_bar: the tqdm bar from main() (needs update() and write()).
    """
    try:
        caption = get_caption_for_image(
            image_path=image_path,
            exclude_patterns=exclude_patterns,
            use_original_prompt=use_original_prompt,
            include_all_metadata=include_all_metadata,
            is_flat_prompt=is_flat_prompt,
            use_existing_txt=use_existing_txt,
        )
        if caption:
            image_path.with_suffix(".txt").write_text(caption, encoding="utf-8")
    except Exception as e:
        # A bare print() from a worker thread garbles the live progress bar;
        # tqdm's write() prints above the bar. Diagnostics go to stderr.
        progress_bar.write(f"Error processing {image_path}: {e}", file=sys.stderr)
    finally:
        with lock:
            progress_bar.update(1)
def gather_images(root: Path) -> List[Path]:
    """
    Recursively collect image files under *root*.

    Walks the tree with os.walk and keeps files whose lower-cased name ends
    in .png, .jpg, .jpeg or .webp. Returns an empty list if *root* yields
    nothing to walk.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".webp")
    return [
        Path(folder) / name
        for folder, _subdirs, names in os.walk(root)
        for name in names
        if name.lower().endswith(image_suffixes)
    ]
def main() -> int:
    """
    Command-line entry point: write a caption .txt next to every image under
    the given directory.

    Returns the process exit code: 0 when all images were handled, 1 when the
    directory argument is not an existing directory. Per-image failures are
    reported by process_image() and do not change the exit code.
    """
    from concurrent.futures import ThreadPoolExecutor

    parser = argparse.ArgumentParser(description="Extract captions from image metadata using read_write_metadata.py")
    parser.add_argument("directory", type=str, help="Directory containing images to process")
    parser.add_argument("--use_original_prompt", action="store_true", default=False, help="Use original 'prompt' section verbatim")
    parser.add_argument("--include_all_metadata", action="store_true", default=False, help="Write full 'parameters' text to the caption")
    parser.add_argument("--flat_prompt", action="store_true", default=False, help="Flatten newlines into commas")
    parser.add_argument("--exclude_patterns", type=str, nargs="*", default=[], help="Regex patterns to remove from prompts before simplification")
    parser.add_argument("--use_existing_txt", action="store_true", default=False, help="If a .txt exists next to the image, reuse it")
    parser.add_argument("--max_workers", type=int, default=5, help="Number of images processed concurrently (default: 5)")
    args = parser.parse_args()
    if args.max_workers < 1:
        parser.error("--max_workers must be at least 1")

    root = Path(args.directory)
    if not root.is_dir():
        print(f"Error: {root} is not a valid directory.", file=sys.stderr)
        return 1

    images = gather_images(root)
    lock = threading.Lock()
    with tqdm(total=len(images), desc="Processing images") as progress_bar:
        # A pool keeps max_workers images in flight continuously. The old code
        # started 5 threads and joined all of them before starting more, so
        # one slow image stalled its whole batch. Leaving the executor block
        # waits for every task, so the bar is closed only after all finish.
        with ThreadPoolExecutor(max_workers=args.max_workers,
                                thread_name_prefix="process_image") as executor:
            futures = [
                executor.submit(
                    process_image,
                    image_path,
                    progress_bar,
                    lock,
                    args.exclude_patterns,
                    args.use_original_prompt,
                    args.include_all_metadata,
                    args.flat_prompt,
                    args.use_existing_txt,
                )
                for image_path in images
            ]
            for future in futures:
                # process_image handles its own errors; this only surfaces
                # anything unexpected instead of losing it in a worker thread.
                future.result()
    return 0
# Script entry point: the exit status is whatever main() returns.
if __name__ == "__main__":
    sys.exit(main())