Skip to content

Commit 17ec5d5

Browse files
authored
added download of fusecap for captioning (#27)
1 parent 2b0837a commit 17ec5d5

File tree

4 files changed

+80
-22
lines changed

4 files changed

+80
-22
lines changed

configs/download_list.json

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,5 +388,43 @@
388388
}
389389
]
390390
}
391+
],
392+
"captions_items": [
393+
{
394+
"title": "FuseCap",
395+
"description": "A framework designed to enhance image captioning by incorporating detailed visual information into traditional captions.",
396+
"destination_directory": "app_models",
397+
"destination_subdirectory": "captions/fusecap",
398+
"files": [
399+
{
400+
"file": "config.json",
401+
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/config.json?download=true"
402+
},
403+
{
404+
"file": "preprocessor_config.json",
405+
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/preprocessor_config.json?download=true"
406+
},
407+
{
408+
"file": "pytorch_model.bin",
409+
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/resolve/main/pytorch_model.bin?download=true"
410+
},
411+
{
412+
"file": "special_tokens_map.json",
413+
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/special_tokens_map.json?download=true"
414+
},
415+
{
416+
"file": "tokenizer.json",
417+
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/tokenizer.json?download=true"
418+
},
419+
{
420+
"file": "tokenizer_config.json",
421+
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/tokenizer_config.json?download=true"
422+
},
423+
{
424+
"file": "vocab.txt",
425+
"url": "https://huggingface.co/noamrot/FuseCap_Image_Captioning/raw/main/vocab.txt?download=true"
426+
}
427+
]
428+
}
391429
]
392430
}

src/iartisanxl/app/downloader_dialog.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,15 @@ def init_ui(self):
6767
self.t2i_items_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
6868
t2i_widget.setLayout(self.t2i_items_layout)
6969

70+
captions_widget = QWidget()
71+
self.captions_items_layout = QGridLayout()
72+
self.captions_items_layout.setAlignment(Qt.AlignmentFlag.AlignTop)
73+
captions_widget.setLayout(self.captions_items_layout)
74+
7075
tab_widget.addTab(essentials_widget, "Essentials")
7176
tab_widget.addTab(controlnets_widget, "ControlNet")
7277
tab_widget.addTab(t2i_widget, "T2I Adapters")
78+
tab_widget.addTab(captions_widget, "Captions")
7379
self.main_layout.addWidget(tab_widget)
7480

7581
sdxl_download_button = QPushButton("Download")
@@ -103,6 +109,7 @@ def load_items(self):
103109
"essential_items": self.essentials_items_layout,
104110
"controlnet_items": self.controlnets_items_layout,
105111
"t2i_items": self.t2i_items_layout,
112+
"captions_items": self.captions_items_layout,
106113
}
107114

108115
for category, layout in layouts.items():
@@ -156,7 +163,7 @@ def make_final_directory(self, destination_directory, destination_subdirectory):
156163
return final_directory
157164

158165
def on_start_download(self):
159-
layouts = [self.essentials_items_layout, self.controlnets_items_layout, self.t2i_items_layout]
166+
layouts = [self.essentials_items_layout, self.controlnets_items_layout, self.t2i_items_layout, self.captions_items_layout]
160167
for layout in layouts:
161168
for i in range(layout.count()):
162169
item = layout.itemAt(i).widget()

src/iartisanxl/modules/dataset/dataset_module.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def on_ai_caption(self):
190190
self.generate_captions_thread = GenerateCaptionsThread(self.device)
191191
self.generate_captions_thread.status_update.connect(self.update_status_bar)
192192
self.generate_captions_thread.caption_done.connect(self.on_ai_caption_done)
193+
self.generate_captions_thread.error.connect(self.ai_caption_error)
193194
else:
194195
try:
195196
self.generate_captions_thread.caption_done.disconnect(self.generate_item_ai_caption_done)
@@ -205,6 +206,11 @@ def on_ai_caption(self):
205206

206207
self.generate_captions_thread.start()
207208

209+
def ai_caption_error(self, text):
210+
self.enable_ui()
211+
self.show_snackbar(text)
212+
self.update_status_bar(text)
213+
208214
def on_ai_caption_done(self, text):
209215
self.image_caption_edit.setPlainText(text)
210216
self.enable_ui()
@@ -243,25 +249,27 @@ def on_mass_caption(self):
243249
break
244250

245251
def on_ai_mass_caption(self):
246-
self.disable_ui()
247-
248-
if self.generate_captions_thread is None:
249-
self.generate_captions_thread = GenerateCaptionsThread(self.device)
250-
self.generate_captions_thread.status_update.connect(self.update_status_bar)
251-
self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done)
252-
else:
253-
try:
254-
self.generate_captions_thread.caption_done.disconnect(self.on_ai_caption_done)
255-
except TypeError:
256-
pass
257-
self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done)
252+
if self.dataset_dir is not None and len(self.dataset_dir) > 0:
253+
self.disable_ui()
254+
255+
if self.generate_captions_thread is None:
256+
self.generate_captions_thread = GenerateCaptionsThread(self.device)
257+
self.generate_captions_thread.status_update.connect(self.update_status_bar)
258+
self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done)
259+
self.generate_captions_thread.error.connect(self.ai_caption_error)
260+
else:
261+
try:
262+
self.generate_captions_thread.caption_done.disconnect(self.on_ai_caption_done)
263+
except TypeError:
264+
pass
265+
self.generate_captions_thread.caption_done.connect(self.generate_item_ai_caption_done)
258266

259-
text = self.image_caption_edit.toPlainText()
267+
text = self.image_caption_edit.toPlainText()
260268

261-
self.progress_bar.setMaximum(self.dataset_items_view.item_count)
262-
self.dataset_items_view.get_first_item()
263-
self.update_status_bar("Generating captions...")
264-
self.generate_item_ai_caption(text)
269+
self.progress_bar.setMaximum(self.dataset_items_view.item_count)
270+
self.dataset_items_view.get_first_item()
271+
self.update_status_bar("Generating captions...")
272+
self.generate_item_ai_caption(text)
265273

266274
def generate_item_ai_caption(self, text):
267275
item = self.dataset_items_view.current_item

src/iartisanxl/threads/generate_captions_thread.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
class GenerateCaptionsThread(QThread):
99
status_update = pyqtSignal(str)
1010
caption_done = pyqtSignal(str)
11+
error = pyqtSignal(str)
1112

1213
def __init__(self, device):
1314
super().__init__()
@@ -21,8 +22,14 @@ def __init__(self, device):
2122
def run(self):
2223
if self.model is None:
2324
self.status_update.emit("Loading FuseCap model...")
24-
self.processor = BlipProcessor.from_pretrained("models/captions/fusecap")
25-
self.model = BlipForConditionalGeneration.from_pretrained("models/captions/fusecap").to(self.device)
25+
26+
try:
27+
self.processor = BlipProcessor.from_pretrained("models/captions/fusecap")
28+
self.model = BlipForConditionalGeneration.from_pretrained("models/captions/fusecap").to(self.device)
29+
except OSError:
30+
self.error.emit("Need to download the FuseCap model from the downloader menu.")
31+
return
32+
2633
self.status_update.emit("FuseCap loaded.")
2734

2835
self.status_update.emit("Generating AI caption...")
@@ -32,8 +39,6 @@ def run(self):
3239
buffer.open(QBuffer.ReadWrite)
3340
qimage.save(buffer, "PNG")
3441

35-
print(f"{self.text=}")
36-
3742
raw_image = Image.open(io.BytesIO(buffer.data()))
3843
inputs = self.processor(raw_image, self.text, return_tensors="pt").to(self.device)
3944

0 commit comments

Comments
 (0)