diff --git a/requirements.txt b/requirements.txt
index 2e6b0096..d12df2c6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,6 +6,7 @@ pillow==10.3.0
 pyparsing==3.1.2
 PySide6==6.7.1
 transformers==4.41.2
+gitpython==3.1.43
 
 # PyTorch
 torch==2.2.2; platform_system != "Windows"
diff --git a/taggui/auto_captioning/captioning_thread.py b/taggui/auto_captioning/captioning_thread.py
index e0ee8756..0f026d42 100644
--- a/taggui/auto_captioning/captioning_thread.py
+++ b/taggui/auto_captioning/captioning_thread.py
@@ -31,6 +31,7 @@
                                         get_xcomposer2_error_message,
                                         get_xcomposer2_inputs)
 from models.image_list_model import ImageListModel
+from widgets.history_list import HistoryListModel
 from utils.enums import CaptionDevice, CaptionModelType, CaptionPosition
 from utils.image import Image
 from utils.settings import get_tag_separator
@@ -143,10 +144,12 @@ class CaptioningThread(QThread):
 
     def __init__(self, parent, image_list_model: ImageListModel,
                  selected_image_indices: list[QModelIndex],
+                 history_list_model: HistoryListModel,
                  caption_settings: dict, tag_separator: str,
                  models_directory_path: Path | None):
         super().__init__(parent)
         self.image_list_model = image_list_model
+        self.history_list_model = history_list_model
         self.selected_image_indices = selected_image_indices
         self.caption_settings = caption_settings
         self.tag_separator = tag_separator
@@ -396,6 +399,11 @@ def run(self):
                 print(error_message)
                 return
         processor, model = self.load_processor_and_model(device, model_type)
+
+        self.history_list_model.append(
+            self.caption_settings, model, self.image_list_model,
+            self.selected_image_indices)
+
         # CogVLM and CogAgent have to be monkey patched every time because
         # `caption_start` might have changed.
         caption_start = self.caption_settings['caption_start']
diff --git a/taggui/utils/settings.py b/taggui/utils/settings.py
index 88e1f0f5..3faf5812 100644
--- a/taggui/utils/settings.py
+++ b/taggui/utils/settings.py
@@ -10,6 +10,8 @@
     'insert_space_after_tag_separator': True,
     'autocomplete_tags': True,
     'models_directory_path': ''
+    # directory_path: ''  # added by main_window.load_directory
+    # more added by auto_captioner.get_caption_settings
 }
 
 
diff --git a/taggui/utils/utils.py b/taggui/utils/utils.py
index cc62cb95..ba4cc6f3 100644
--- a/taggui/utils/utils.py
+++ b/taggui/utils/utils.py
@@ -39,3 +39,27 @@ def get_confirmation_dialog_reply(title: str, question: str) -> int:
                                            | QMessageBox.StandardButton.Cancel)
     confirmation_dialog.setDefaultButton(QMessageBox.StandardButton.Yes)
     return confirmation_dialog.exec()
+
+
+def get_repo_infos(path: str) -> dict[str, str]:
+    """
+    Return the origin URL and the HEAD revision of the Git repository that
+    contains `path`. A value that cannot be determined is an empty string,
+    for example when the application does not run from a Git checkout.
+    """
+    repo_infos = {'origin': '', 'revision': ''}
+    try:
+        # GitPython raises an `ImportError` on import if it cannot find the
+        # `git` executable, so import it only when it is needed.
+        import git
+    except ImportError:
+        return repo_infos
+    try:
+        repo = git.Repo(path, search_parent_directories=True)
+        repo_infos['revision'] = repo.head.commit.hexsha
+        repo_infos['origin'] = repo.remotes.origin.url
+    except (git.exc.GitError, AttributeError, ValueError):
+        # Not a repository, no commit yet or no `origin` remote. The history
+        # is still useful without this information.
+        pass
+    return repo_infos
diff --git a/taggui/widgets/auto_captioner.py b/taggui/widgets/auto_captioner.py
index 4bb1b5fd..c0c34415 100644
--- a/taggui/widgets/auto_captioner.py
+++ b/taggui/widgets/auto_captioner.py
@@ -11,6 +11,7 @@
 from auto_captioning.captioning_thread import CaptioningThread
 from auto_captioning.models import MODELS, get_model_type
 from models.image_list_model import ImageListModel
+from widgets.history_list import HistoryListModel
 from utils.big_widgets import TallPushButton
 from utils.enums import CaptionDevice, CaptionModelType, CaptionPosition
 from utils.settings import DEFAULT_SETTINGS, get_settings, get_tag_separator
@@ -37,7 +38,6 @@ def set_text_edit_height(text_edit: QPlainTextEdit, line_count: int):
               + text_edit.frameWidth() * 2)
     text_edit.setFixedHeight(height)
 
-
 class HorizontalLine(QFrame):
     def __init__(self):
         super().__init__()
@@ -340,6 +340,60 @@ def get_caption_settings(self) -> dict:
             }
         }
 
+    def set_captions_settings(self, caption_settings: dict):
+        """
+        Load the values in `caption_settings` into the form widgets.
+
+        Keys that are missing (from the top level or from the nested
+        `generation_parameters` and `wd_tagger_settings` dictionaries) are
+        skipped, so the corresponding widgets keep their current values.
+        """
+        setters = {
+            'model': self.model_combo_box.setCurrentText,
+            'prompt': self.prompt_text_edit.setPlainText,
+            'caption_start': self.caption_start_line_edit.setText,
+            'caption_position':
+                self.caption_position_combo_box.setCurrentText,
+            'device': self.device_combo_box.setCurrentText,
+            'gpu_index': self.gpu_index_spin_box.setValue,
+            'load_in_4_bit': self.load_in_4_bit_check_box.setChecked,
+            'remove_tag_separators':
+                self.remove_tag_separators_check_box.setChecked,
+            'bad_words': self.bad_words_line_edit.setText,
+            'forced_words': self.forced_words_line_edit.setText
+        }
+        generation_parameter_setters = {
+            'min_new_tokens': self.min_new_token_count_spin_box.setValue,
+            'max_new_tokens': self.max_new_token_count_spin_box.setValue,
+            'num_beams': self.beam_count_spin_box.setValue,
+            'length_penalty': self.length_penalty_spin_box.setValue,
+            'do_sample': self.use_sampling_check_box.setChecked,
+            'temperature': self.temperature_spin_box.setValue,
+            'top_k': self.top_k_spin_box.setValue,
+            'top_p': self.top_p_spin_box.setValue,
+            'repetition_penalty': self.repetition_penalty_spin_box.setValue,
+            'no_repeat_ngram_size':
+                self.no_repeat_ngram_size_spin_box.setValue
+        }
+        wd_tagger_setters = {
+            'show_probabilities':
+                self.show_probabilities_check_box.setChecked,
+            'min_probability': self.min_probability_spin_box.setValue,
+            'max_tags': self.max_tags_spin_box.setValue,
+            'tags_to_exclude': self.tags_to_exclude_text_edit.setPlainText
+        }
+        setter_groups = (
+            (caption_settings, setters),
+            (caption_settings.get('generation_parameters', {}),
+             generation_parameter_setters),
+            (caption_settings.get('wd_tagger_settings', {}),
+             wd_tagger_setters)
+        )
+        for settings, group_setters in setter_groups:
+            for key, setter in group_setters.items():
+                if key in settings:
+                    setter(settings[key])
+
 
 @Slot()
 def restore_stdout_and_stderr():
@@ -351,10 +405,11 @@ class AutoCaptioner(QDockWidget):
     caption_generated = Signal(QModelIndex, str, list)
 
     def __init__(self, image_list_model: ImageListModel,
-                 image_list: ImageList):
+                 image_list: ImageList, history_list_model: HistoryListModel):
         super().__init__()
         self.image_list_model = image_list_model
         self.image_list = image_list
+        self.history_list_model = history_list_model
         self.settings = get_settings()
         self.is_captioning = False
         self.captioning_thread = None
@@ -467,7 +522,8 @@ def generate_captions(self):
         models_directory_path = (Path(models_directory_path)
                                  if models_directory_path else None)
         self.captioning_thread = CaptioningThread(
             self, self.image_list_model, selected_image_indices,
-            caption_settings, tag_separator, models_directory_path)
+            self.history_list_model, caption_settings, tag_separator,
+            models_directory_path)
         self.captioning_thread.text_outputted.connect(
             self.update_console_text_edit)
diff --git a/taggui/widgets/history_list.py b/taggui/widgets/history_list.py
new file mode 100644
index 00000000..f235a6dc
--- /dev/null
+++ b/taggui/widgets/history_list.py
@@ -0,0 +1,168 @@
+import copy
+import json
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Callable
+
+from PySide6.QtCore import (QAbstractListModel, QModelIndex, Qt, Signal,
+                            Slot)
+from PySide6.QtWidgets import QDockWidget, QListView, QVBoxLayout, QWidget
+from transformers import AutoModel
+
+from auto_captioning.models import get_model_type
+from models.image_list_model import ImageListModel
+from utils.enums import CaptionModelType
+
+HISTORY_FILE_NAME = '!0_history.jsonl'
+# Settings that describe the machine rather than the captioning run.
+MACHINE_SPECIFIC_SETTINGS = ('device', 'gpu_index')
+# Generation parameters that only have an effect when sampling is enabled.
+SAMPLING_ONLY_PARAMETERS = ('temperature', 'top_k', 'top_p')
+
+
+class HistoryListModel(QAbstractListModel):
+    """List model of the captioning runs recorded for an image directory."""
+
+    # Emitted by `append()` with the new entry. Going through a signal lets Qt
+    # queue the insertion to the thread the model lives in when `append()` is
+    # called from a worker thread.
+    entry_created = Signal(object)
+
+    def __init__(self, repo_infos: dict[str, str]):
+        super().__init__()
+        self.history_list: list[dict] = []
+        self.app_infos = repo_infos
+        self.image_directory_path: Path | None = None
+        self.entry_created.connect(self.insert_entry)
+
+    def data(self, index: QModelIndex,
+             role: int = Qt.ItemDataRole.DisplayRole):
+        if not index.isValid():
+            return None
+        item = self.history_list[index.row()]
+        if role == Qt.ItemDataRole.UserRole:
+            return item
+        if role == Qt.ItemDataRole.DisplayRole:
+            settings = item.get('app', {}).get('settings', {})
+            prompt = settings.get('prompt', '')
+            return f"{item.get('date', '')} '{prompt[:20]}'"
+        return None
+
+    def rowCount(self, parent: QModelIndex | None = None) -> int:
+        # The list is flat, so valid parents have no rows.
+        if parent is not None and parent.isValid():
+            return 0
+        return len(self.history_list)
+
+    def load_directory(self, image_directory_path: Path):
+        """Replace the entries with the history stored in the directory."""
+        history_path = image_directory_path / HISTORY_FILE_NAME
+        history_list = []
+        if history_path.exists():
+            with history_path.open(encoding='utf-8') as file:
+                for line in file:
+                    if not line.strip():
+                        continue
+                    try:
+                        history_list.append(json.loads(line))
+                    except json.JSONDecodeError:
+                        print('Skipping an invalid entry in '
+                              f'{history_path}.')
+        # Reset the model only after the file has been read so that a read
+        # error cannot leave the reset unfinished.
+        self.beginResetModel()
+        self.image_directory_path = image_directory_path
+        self.history_list = history_list
+        self.endResetModel()
+
+    def append(self, caption_settings: dict, model: AutoModel,
+               image_list_model: ImageListModel,
+               selected_image_indices: list[QModelIndex]) -> None:
+        """
+        Record a captioning run in the list and in the history file of the
+        image directory. `caption_settings` is not modified.
+        """
+        model_id = caption_settings['model']
+        model_type = get_model_type(model_id)
+
+        # Work on a deep copy because the caller keeps using the settings.
+        settings = copy.deepcopy(caption_settings)
+        for key in MACHINE_SPECIFIC_SETTINGS:
+            settings.pop(key, None)
+        if model_type == CaptionModelType.WD_TAGGER:
+            settings.pop('generation_parameters', None)
+        else:
+            settings.pop('wd_tagger_settings', None)
+            generation_parameters = settings.get('generation_parameters', {})
+            if not generation_parameters.get('do_sample'):
+                for parameter in SAMPLING_ONLY_PARAMETERS:
+                    generation_parameters.pop(parameter, None)
+
+        # Not every model object has a Transformers config, so fall back to
+        # the values derived from the settings.
+        model_config = getattr(model, 'config', None)
+        model_infos = {
+            'name': getattr(model_config, 'name_or_path', model_id),
+            'type': str(getattr(model_config, 'model_type', model_type))
+        }
+
+        directory_path = self.image_directory_path
+        images = []
+        if directory_path is not None:
+            images = sorted(
+                str(image_list_model.images[index.row()].path
+                    .relative_to(directory_path))
+                for index in selected_image_indices)
+
+        entry = {
+            'date': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+            'history_version': 0,
+            'app': {**self.app_infos, 'settings': settings},
+            'model': model_infos,
+            'images': images
+        }
+        self.entry_created.emit(entry)
+        if directory_path is not None:
+            history_path = directory_path / HISTORY_FILE_NAME
+            with history_path.open('a', encoding='utf-8') as file:
+                # Text mode already converts '\n' to the platform line ending.
+                file.write(json.dumps(entry, separators=(',', ':')) + '\n')
+
+    @Slot(object)
+    def insert_entry(self, entry: dict):
+        """Append an entry to the list and notify the attached views."""
+        row = len(self.history_list)
+        self.beginInsertRows(QModelIndex(), row, row)
+        self.history_list.append(entry)
+        self.endInsertRows()
+
+
+class HistoryList(QDockWidget):
+    """Dock widget that lists the history of an image directory."""
+
+    def __init__(self, model: HistoryListModel):
+        super().__init__()
+        # Called with the caption settings of a clicked entry. Assigned by
+        # the owner of the widget.
+        self.set_captions_settings: (
+            Callable[[dict[str, Any]], None] | None) = None
+        self.setObjectName('history_list')
+        self.setWindowTitle('History')
+        self.setAllowedAreas(Qt.DockWidgetArea.LeftDockWidgetArea
+                             | Qt.DockWidgetArea.RightDockWidgetArea)
+        self.listView = QListView()
+        self.listView.setModel(model)
+        self.listView.clicked.connect(self.item_clicked)
+        container = QWidget()
+        layout = QVBoxLayout(container)
+        layout.addWidget(self.listView)
+        self.setWidget(container)
+
+    @Slot(QModelIndex)
+    def item_clicked(self, current: QModelIndex):
+        """Pass the caption settings of the clicked entry to the callback."""
+        if not current.isValid() or self.set_captions_settings is None:
+            return
+        entry = current.data(Qt.ItemDataRole.UserRole)
+        caption_settings = entry.get('app', {}).get('settings', {})
+        self.set_captions_settings(caption_settings)
diff --git a/taggui/widgets/main_window.py b/taggui/widgets/main_window.py
index d9a4ef2a..452118fd 100644
--- a/taggui/widgets/main_window.py
+++ b/taggui/widgets/main_window.py
@@ -20,12 +20,13 @@
 from utils.key_press_forwarder import KeyPressForwarder
 from utils.settings import DEFAULT_SETTINGS, get_settings, get_tag_separator
 from utils.shortcut_remover import ShortcutRemover
-from utils.utils import get_resource_path, pluralize
+from utils.utils import get_repo_infos, get_resource_path, pluralize
 from widgets.all_tags_editor import AllTagsEditor
 from widgets.auto_captioner import AutoCaptioner
+from widgets.history_list import HistoryList, HistoryListModel
 from widgets.image_list import ImageList
 from widgets.image_tags_editor import ImageTagsEditor
 from widgets.image_viewer import ImageViewer
 
 ICON_PATH = Path('images/icon.ico')
 GITHUB_REPOSITORY_URL = 'https://github.com/jhc13/taggui'
@@ -49,6 +50,7 @@ def __init__(self, app: QApplication):
             self.image_list_model, tokenizer, tag_separator)
         self.image_list_model.proxy_image_list_model = (
             self.proxy_image_list_model)
+        self.history_list_model = HistoryListModel(get_repo_infos(__file__))
         self.tag_counter_model = TagCounterModel()
         self.image_tag_list_model = ImageTagListModel()
 
@@ -64,6 +66,10 @@ def __init__(self, app: QApplication):
                                     tag_separator, image_list_image_width)
         self.addDockWidget(Qt.DockWidgetArea.LeftDockWidgetArea,
                            self.image_list)
+        self.history_list = HistoryList(self.history_list_model)
+        self.addDockWidget(Qt.DockWidgetArea.LeftDockWidgetArea,
+                           self.history_list)
+        self.tabifyDockWidget(self.image_list, self.history_list)
         self.image_tags_editor = ImageTagsEditor(
             self.proxy_image_list_model, self.tag_counter_model,
             self.image_tag_list_model, self.image_list, tokenizer,
@@ -76,7 +82,7 @@ def __init__(self, app: QApplication):
         self.addDockWidget(Qt.DockWidgetArea.RightDockWidgetArea,
                            self.all_tags_editor)
-        self.auto_captioner = AutoCaptioner(self.image_list_model,
-                                            self.image_list)
+        self.auto_captioner = AutoCaptioner(
+            self.image_list_model, self.image_list, self.history_list_model)
         self.addDockWidget(Qt.DockWidgetArea.RightDockWidgetArea,
                            self.auto_captioner)
         self.tabifyDockWidget(self.all_tags_editor, self.auto_captioner)
@@ -99,6 +105,7 @@ def __init__(self, app: QApplication):
         self.undo_action = QAction('Undo', parent=self)
         self.redo_action = QAction('Redo', parent=self)
         self.toggle_image_list_action = QAction('Images', parent=self)
+        self.toggle_history_list_action = QAction('History', parent=self)
         self.toggle_image_tags_editor_action = QAction('Image Tags',
                                                        parent=self)
         self.toggle_all_tags_editor_action = QAction('All Tags', parent=self)
@@ -110,6 +117,8 @@ def __init__(self, app: QApplication):
                                            .selectionModel())
         self.image_list_model.image_list_selection_model = (
             self.image_list_selection_model)
+        self.history_list.set_captions_settings = (
+            self.auto_captioner.caption_settings_form.set_captions_settings)
         self.connect_image_list_signals()
         self.connect_image_tags_editor_signals()
         self.connect_all_tags_editor_signals()
@@ -207,6 +216,7 @@ def load_directory(self, path: Path, select_index: int = 0):
         self.settings.setValue('directory_path', str(path))
         self.setWindowTitle(path.name)
         self.image_list_model.load_directory(path)
+        self.history_list_model.load_directory(path)
         self.image_list.filter_line_edit.clear()
         self.all_tags_editor.filter_line_edit.clear()
         # Clear the current index first to make sure that the `currentChanged`
@@ -352,11 +362,14 @@ def create_menus(self):
 
         view_menu = menu_bar.addMenu('View')
         self.toggle_image_list_action.setCheckable(True)
+        self.toggle_history_list_action.setCheckable(True)
         self.toggle_image_tags_editor_action.setCheckable(True)
         self.toggle_all_tags_editor_action.setCheckable(True)
         self.toggle_auto_captioner_action.setCheckable(True)
         self.toggle_image_list_action.triggered.connect(
             lambda is_checked: self.image_list.setVisible(is_checked))
+        self.toggle_history_list_action.triggered.connect(
+            lambda is_checked: self.history_list.setVisible(is_checked))
         self.toggle_image_tags_editor_action.triggered.connect(
             lambda is_checked: self.image_tags_editor.setVisible(is_checked))
         self.toggle_all_tags_editor_action.triggered.connect(
@@ -364,6 +377,7 @@ def create_menus(self):
         self.toggle_auto_captioner_action.triggered.connect(
             lambda is_checked: self.auto_captioner.setVisible(is_checked))
         view_menu.addAction(self.toggle_image_list_action)
+        view_menu.addAction(self.toggle_history_list_action)
         view_menu.addAction(self.toggle_image_tags_editor_action)
         view_menu.addAction(self.toggle_all_tags_editor_action)
         view_menu.addAction(self.toggle_auto_captioner_action)
@@ -459,6 +473,9 @@ def connect_image_list_signals(self):
         self.image_list.visibilityChanged.connect(
             lambda: self.toggle_image_list_action.setChecked(
                 self.image_list.isVisible()))
+        self.history_list.visibilityChanged.connect(
+            lambda: self.toggle_history_list_action.setChecked(
+                self.history_list.isVisible()))
 
     @Slot()
     def update_image_tags(self):