From 59ed595fcc984c019caae8bfa9d580d0b26ce6ba Mon Sep 17 00:00:00 2001
From: Jonas Krauss
Date: Sat, 11 May 2024 21:18:05 +0200
Subject: [PATCH 1/5] support batch processing

---
 wd14tagger.py | 63 ++++++++++++++++++++++++++++++++-------------------
 1 file changed, 40 insertions(+), 23 deletions(-)

diff --git a/wd14tagger.py b/wd14tagger.py
index c88f583..fafb2df 100644
--- a/wd14tagger.py
+++ b/wd14tagger.py
@@ -47,7 +47,7 @@ def get_installed_models():
     return models


-async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclude_tags="", replace_underscore=True, trailing_comma=False, client_id=None, node=None):
+async def tag(batch, model_name, threshold=0.35, character_threshold=0.85, exclude_tags="", replace_underscore=True, trailing_comma=False, client_id=None, node=None):
     if model_name.endswith(".onnx"):
         model_name = model_name[0:-5]
     installed = list(get_installed_models())
@@ -60,16 +60,17 @@ async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclu
     input = model.get_inputs()[0]
     height = input.shape[1]

-    # Reduce to max size and pad with white
-    ratio = float(height)/max(image.size)
-    new_size = tuple([int(x*ratio) for x in image.size])
-    image = image.resize(new_size, Image.LANCZOS)
-    square = Image.new("RGB", (height, height), (255, 255, 255))
-    square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2))
+    for i in range(len(batch)):
+        # Reduce to max size and pad with white
+        ratio = float(height)/max(batch[i].size)
+        new_size = tuple([int(x*ratio) for x in batch[i].size])
+        batch[i] = batch[i].resize(new_size, Image.LANCZOS)
+        square = Image.new("RGB", (height, height), (255, 255, 255))
+        square.paste(batch[i], ((height-new_size[0])//2, (height-new_size[1])//2))

-    image = np.array(square).astype(np.float32)
-    image = image[:, :, ::-1]  # RGB -> BGR
-    image = np.expand_dims(image, 0)
+        batch[i] = np.array(square).astype(np.float32)
+        batch[i] = batch[i][:, :, ::-1]  # RGB -> BGR
+        batch[i] = np.expand_dims(batch[i], 0)

     # Read all tags from csv and locate start of each category
     tags = []
@@ -88,22 +89,32 @@ async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclu
         else:
             tags.append(row[1])

+    # imgs = np.array([im for im in batch])
+
+    probs = []
     label_name = model.get_outputs()[0].name
-    probs = model.run([label_name], {input.name: image})[0]
+    for img in batch:
+        probs.append(model.run([label_name], {input.name: img})[0])
+    # probs = probs[: len(batch)]
+    # probs = model.run([label_name], {input.name: imgs})[0]
+
+    # print(probs)
+
+    res = []

-    result = list(zip(tags, probs[0]))
+    for i in range(len(batch)):
+        result = list(zip(tags, probs[i][0]))

-    # rating = max(result[:general_index], key=lambda x: x[1])
-    general = [item for item in result[general_index:character_index] if item[1] > threshold]
-    character = [item for item in result[character_index:] if item[1] > character_threshold]
+        # rating = max(result[:general_index], key=lambda x: x[1])
+        general = [item for item in result[general_index:character_index] if item[1] > threshold]
+        character = [item for item in result[character_index:] if item[1] > character_threshold]

-    all = character + general
-    remove = [s.strip() for s in exclude_tags.lower().split(",")]
-    all = [tag for tag in all if tag[0] not in remove]
+        all = character + general
+        remove = [s.strip() for s in exclude_tags.lower().split(",")]
+        all = [tag for tag in all if tag[0] not in remove]

-    res = ("" if trailing_comma else ", ").join((item[0].replace("(", "\\(").replace(")", "\\)") + (", " if trailing_comma else "") for item in all))
+        res.append(("" if trailing_comma else ", ").join((item[0].replace("(", "\\(").replace(")", "\\)") + (", " if trailing_comma else "") for item in all)))

-    print(res)
     return res


@@ -180,6 +191,7 @@ def INPUT_TYPES(s):
             "replace_underscore": ("BOOLEAN", {"default": defaults["replace_underscore"]}),
             "trailing_comma": ("BOOLEAN", {"default": defaults["trailing_comma"]}),
             "exclude_tags": ("STRING", {"default": defaults["exclude_tags"]}),
+            "batch_size": ("INT", {"default": 1, "min": 1, "max": 16}),
         }}

     RETURN_TYPES = ("STRING",)
@@ -189,16 +201,21 @@ def INPUT_TYPES(s):

     CATEGORY = "image"

-    def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False):
+    def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False, batch_size=1):
         tensor = image*255
         tensor = np.array(tensor, dtype=np.uint8)

         pbar = comfy.utils.ProgressBar(tensor.shape[0])
         tags = []
+        batch = []
         for i in range(tensor.shape[0]):
             image = Image.fromarray(tensor[i])
-            tags.append(wait_for_async(lambda: tag(image, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma)))
-            pbar.update(1)
+            batch.append(image)
+            if len(batch) == batch_size or i == tensor.shape[0] -1:
+                tags = tags + wait_for_async(lambda: tag(batch, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma))
+                pbar.update(len(batch))
+                batch = []

+        print(tags)
         return {"ui": {"tags": tags}, "result": (tags,)}
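Note that after this first patch the model is still invoked once per image inside the `for img in batch:` loop, so batching only amortizes the Python-side preprocessing; the commented-out `imgs = np.array(...)` and `model.run(..., {input.name: imgs})` lines show the single-call variant being explored. A minimal sketch of that variant, assuming the ONNX graph was exported with a dynamic batch axis (`run_batched` is a hypothetical helper, not part of the patch):

```python
import numpy as np

# Sketch: stack the per-image (1, H, W, 3) float32 arrays produced by the
# preprocessing loop into one (N, H, W, 3) array and call the session once.
# Assumes the model's first input dimension is dynamic; a fixed-batch
# export will reject N != 1.
def run_batched(model, input_name, label_name, batch):
    stacked = np.concatenate(batch, axis=0)                   # (N, H, W, 3)
    return model.run([label_name], {input_name: stacked})[0]  # (N, num_tags)
```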
(", " if trailing_comma else "") for item in all)) + res.append(("" if trailing_comma else ", ").join((item[0].replace("(", "\\(").replace(")", "\\)") + (", " if trailing_comma else "") for item in all))) - print(res) return res @@ -180,6 +191,7 @@ def INPUT_TYPES(s): "replace_underscore": ("BOOLEAN", {"default": defaults["replace_underscore"]}), "trailing_comma": ("BOOLEAN", {"default": defaults["trailing_comma"]}), "exclude_tags": ("STRING", {"default": defaults["exclude_tags"]}), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 16}), }} RETURN_TYPES = ("STRING",) @@ -189,16 +201,21 @@ def INPUT_TYPES(s): CATEGORY = "image" - def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False): + def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False, batch_size=1): tensor = image*255 tensor = np.array(tensor, dtype=np.uint8) pbar = comfy.utils.ProgressBar(tensor.shape[0]) tags = [] + batch = [] for i in range(tensor.shape[0]): image = Image.fromarray(tensor[i]) - tags.append(wait_for_async(lambda: tag(image, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma))) - pbar.update(1) + batch.append(image) + if len(batch) == batch_size or i == tensor.shape[0] -1: + tags = tags + wait_for_async(lambda: tag(batch, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma)) + pbar.update(len(batch)) + batch = [] + print(tags) return {"ui": {"tags": tags}, "result": (tags,)} From 7594ff65c762076f37bd2e25600bbf8001846a45 Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Thu, 26 Dec 2024 21:12:43 +0100 Subject: [PATCH 2/5] working batch support --- wd14tagger.py | 46 ++++++++++++++++++++++++++++++++++++---------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/wd14tagger.py b/wd14tagger.py index fafb2df..53f45d8 100644 --- a/wd14tagger.py +++ b/wd14tagger.py @@ -202,22 +202,48 @@ def INPUT_TYPES(s): CATEGORY = "image" def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False, batch_size=1): - tensor = image*255 - tensor = np.array(tensor, dtype=np.uint8) + if not isinstance(image, list): + images = [image] + else: + images = image - pbar = comfy.utils.ProgressBar(tensor.shape[0]) + pbar = comfy.utils.ProgressBar(len(images)) tags = [] batch = [] - for i in range(tensor.shape[0]): - image = Image.fromarray(tensor[i]) - batch.append(image) - if len(batch) == batch_size or i == tensor.shape[0] -1: - tags = tags + wait_for_async(lambda: tag(batch, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma)) - pbar.update(len(batch)) - batch = [] + + for image in images: + tensor = image*255 + tensor = np.array(tensor, dtype=np.uint8) + + for i in range(tensor.shape[0]): + image = Image.fromarray(tensor[i]) + batch.append(image) + if len(batch) == batch_size or i == tensor.shape[0] -1: + tags = tags + wait_for_async(lambda: tag(batch, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma)) + pbar.update(len(batch)) + batch = [] + print(tags) + return {"ui": {"tags": tags}, "result": (tags,)} + # def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False): + # if not isinstance(image, list): + # images = [image] + # else: + # images = image + # pbar = comfy.utils.ProgressBar(len(images)) + # for image in 
From 6650cb27c2bd331b4ad945f9430c6b544bc7bea4 Mon Sep 17 00:00:00 2001
From: Jonas Krauss
Date: Thu, 26 Dec 2024 21:19:13 +0100
Subject: [PATCH 3/5] remove commented code

---
 wd14tagger.py | 18 ------------------
 1 file changed, 18 deletions(-)

diff --git a/wd14tagger.py b/wd14tagger.py
index 53f45d8..70aa3d5 100644
--- a/wd14tagger.py
+++ b/wd14tagger.py
@@ -227,24 +227,6 @@ def tag(self, image, model, threshold, character_threshold, exclude_tags="", rep

         return {"ui": {"tags": tags}, "result": (tags,)}

-    # def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False):
-    #     if not isinstance(image, list):
-    #         images = [image]
-    #     else:
-    #         images = image
-    #     pbar = comfy.utils.ProgressBar(len(images))
-    #     for image in images:
-    #         tensor = image*255
-    #         tensor = np.array(tensor, dtype=np.uint8)
-
-    #         tags = []
-    #         for i in range(tensor.shape[0]):
-    #             image = Image.fromarray(tensor[i])
-    #             tags.append(wait_for_async(lambda: tag(image, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma)))
-    #             pbar.update(1)
-    #     return {"ui": {"tags": tags}, "result": (tags,)}
-
-
 NODE_CLASS_MAPPINGS = {
     "WD14Tagger|pysssss": WD14Tagger,
 }

From bf9813a2134698ef3dd971137ee4bd9a39003c5e Mon Sep 17 00:00:00 2001
From: Jonas Krauss
Date: Thu, 26 Dec 2024 21:38:10 +0100
Subject: [PATCH 4/5] move new batch_size parameter to optional parameters to
 not break legacy workflows

---
 wd14tagger.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/wd14tagger.py b/wd14tagger.py
index 70aa3d5..48a5995 100644
--- a/wd14tagger.py
+++ b/wd14tagger.py
@@ -183,16 +183,21 @@ class WD14Tagger:
     def INPUT_TYPES(s):
         extra = [name for name, _ in (os.path.splitext(m) for m in get_installed_models()) if name not in known_models]
         models = known_models + extra
-        return {"required": {
-            "image": ("IMAGE", ),
-            "model": (models, { "default": defaults["model"] }),
-            "threshold": ("FLOAT", {"default": defaults["threshold"], "min": 0.0, "max": 1, "step": 0.05}),
-            "character_threshold": ("FLOAT", {"default": defaults["character_threshold"], "min": 0.0, "max": 1, "step": 0.05}),
-            "replace_underscore": ("BOOLEAN", {"default": defaults["replace_underscore"]}),
-            "trailing_comma": ("BOOLEAN", {"default": defaults["trailing_comma"]}),
-            "exclude_tags": ("STRING", {"default": defaults["exclude_tags"]}),
-            "batch_size": ("INT", {"default": 1, "min": 1, "max": 16}),
-        }}
+        return {
+            "required": {
+                "image": ("IMAGE", ),
+                "model": (models, { "default": defaults["model"] }),
+                "threshold": ("FLOAT", {"default": defaults["threshold"], "min": 0.0, "max": 1, "step": 0.05}),
+                "character_threshold": ("FLOAT", {"default": defaults["character_threshold"], "min": 0.0, "max": 1, "step": 0.05}),
+                "replace_underscore": ("BOOLEAN", {"default": defaults["replace_underscore"]}),
+                "trailing_comma": ("BOOLEAN", {"default": defaults["trailing_comma"]}),
+                "exclude_tags": ("STRING", {"default": defaults["exclude_tags"]}),
+
+            },
+            "optional": {
+                "batch_size": ("INT", {"default": 1, "min": 1, "max": 128}),
+            }
+        }

     RETURN_TYPES = ("STRING",)
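Moving `batch_size` from `"required"` to `"optional"` is what keeps workflows saved before this change loadable, as the commit message notes: ComfyUI does not demand a value for optional inputs, so older graphs that never serialized the widget fall back to the `batch_size=1` default in the method signature. Stripped down to the pattern (a sketch, not the full node):

```python
class ExampleNode:
    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"image": ("IMAGE",)},
            # Widgets added after release go under "optional" so that
            # previously saved workflows, which lack them, still validate.
            "optional": {"batch_size": ("INT", {"default": 1, "min": 1, "max": 128})},
        }

    def tag(self, image, batch_size=1):  # the default covers legacy workflows
        ...
```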
"max": 128}), + } + } RETURN_TYPES = ("STRING",) OUTPUT_IS_LIST = (True,) From 018e3af6bf1414cff2a35546c2198b02d0b4154b Mon Sep 17 00:00:00 2001 From: Jonas Krauss Date: Sun, 23 Mar 2025 11:51:31 +0100 Subject: [PATCH 5/5] working cache mechanism, eva02 default tagger, print memory usage for debug, fix issue when running with --gpu-only --- pysssss.json | 4 +-- wd14tagger.py | 73 ++++++++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 69 insertions(+), 8 deletions(-) diff --git a/pysssss.json b/pysssss.json index 3f3d291..666cd27 100644 --- a/pysssss.json +++ b/pysssss.json @@ -2,11 +2,11 @@ "name": "WD14Tagger", "logging": false, "settings": { - "model": "wd-v1-4-moat-tagger-v2", + "model": "wd-eva02-large-tagger-v3", "threshold": 0.35, "character_threshold": 0.85, "exclude_tags": "", - "ortProviders": ["CUDAExecutionProvider", "CPUExecutionProvider"], + "ortProviders": ["CUDAExecutionProvider","CPUExecutionProvider"], "HF_ENDPOINT": "https://huggingface.co" }, "models": { diff --git a/wd14tagger.py b/wd14tagger.py index 48a5995..782872b 100644 --- a/wd14tagger.py +++ b/wd14tagger.py @@ -1,6 +1,7 @@ # https://huggingface.co/spaces/SmilingWolf/wd-v1-4-tags import comfy.utils +import comfy.model_management import asyncio import aiohttp import numpy as np @@ -10,9 +11,11 @@ import onnxruntime as ort from onnxruntime import InferenceSession from PIL import Image +import hashlib from server import PromptServer from aiohttp import web import folder_paths +import torch from .pysssss import get_ext_dir, get_comfy_dir, download_to_file, update_node_status, wait_for_async, get_extension_config, log sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy")) @@ -54,6 +57,13 @@ async def tag(batch, model_name, threshold=0.35, character_threshold=0.85, exclu if not any(model_name + ".onnx" in s for s in installed): await download_model(model_name, client_id, node) + # unloaded = comfy.model_management.free_memory(1e30, torch.device(torch.cuda.current_device())) + # if unloaded is not None and len(unloaded) > 0: + # torch.cuda.empty_cache() + # torch.cuda.ipc_collect() + unloaded = comfy.model_management.unload_all_models() + print(f"Unloaded models: {unloaded}") + name = os.path.join(models_dir, model_name + ".onnx") model = InferenceSession(name, providers=defaults["ortProviders"]) @@ -179,6 +189,16 @@ async def get_tags(request): class WD14Tagger: + def __init__(self): + self.hash = {} # settings hash --> list of tuples (hash of images, tags) + self.max_cached = 100 # avoid oom + + def get_cache_size(self): + items = 0 + for settings_hash in self.hash: + items += len(self.hash[settings_hash]) + return items + @classmethod def INPUT_TYPES(s): extra = [name for name, _ in (os.path.splitext(m) for m in get_installed_models()) if name not in known_models] @@ -192,7 +212,6 @@ def INPUT_TYPES(s): "replace_underscore": ("BOOLEAN", {"default": defaults["replace_underscore"]}), "trailing_comma": ("BOOLEAN", {"default": defaults["trailing_comma"]}), "exclude_tags": ("STRING", {"default": defaults["exclude_tags"]}), - }, "optional": { "batch_size": ("INT", {"default": 1, "min": 1, "max": 128}), @@ -207,29 +226,71 @@ def INPUT_TYPES(s): CATEGORY = "image" def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False, batch_size=1): + if not isinstance(image, list): images = [image] else: images = image - pbar = comfy.utils.ProgressBar(len(images)) - tags = [] + batches = [] batch = [] + mem = 
+        total_vram = mem[0] / (1024 * 1024)
+        total_vram_torch = mem[1] / (1024 * 1024)
+        print("Total VRAM {:0.0f} MB, total Torch VRAM {:0.0f} MB".format(total_vram, total_vram_torch))
+
+        # build hash for cache
+        settings_hash = f'{len(model)}{hash(model)}-{threshold}-{character_threshold}-{len(exclude_tags)}{hash(exclude_tags)}-{replace_underscore}-{trailing_comma}-{batch_size}'
+        img_hashes = []
+
         for image in images:
             tensor = image*255
-            tensor = np.array(tensor, dtype=np.uint8)
+            tensor = np.array(tensor.cpu(), dtype=np.uint8)

             for i in range(tensor.shape[0]):
                 image = Image.fromarray(tensor[i])
+                img_hashes.append(hashlib.md5(image.tobytes()).hexdigest())
                 batch.append(image)
                 if len(batch) == batch_size or i == tensor.shape[0] -1:
-                    tags = tags + wait_for_async(lambda: tag(batch, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma))
-                    pbar.update(len(batch))
+                    batches.append(batch)
                     batch = []

+        img_hash = "-".join(img_hashes)
+
+        # check cache for entry
+        if settings_hash in self.hash:
+            for stored_tags in self.hash[settings_hash]:
+                if stored_tags[0] == img_hash:
+                    print(f'hashed tags: {stored_tags[1]}')
+                    return {"ui": {"tags": stored_tags[1]}, "result": (stored_tags[1],)}
+
+        pbar = comfy.utils.ProgressBar(len(images))
+        tags = []
+        for batch in batches:
+            tags = tags + wait_for_async(lambda: tag(batch, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma))
+            pbar.update(len(batch))
+
         print(tags)

+        # store tags in cache
+        if settings_hash in self.hash:
+            self.hash[settings_hash].insert(0, (img_hash, tags))
+        else:
+            self.hash[settings_hash] = [(img_hash, tags)]
+
+        # prune cache to avoid oom
+        while self.get_cache_size() > self.max_cached:
+            # TODO: improve by using LRU mechanism
+            for settings_hash in self.hash:
+                if len(self.hash[settings_hash]) > 0: del self.hash[settings_hash][-1]
+                if self.get_cache_size() <= self.max_cached: break
+
+        mem = comfy.model_management.get_total_memory(torch_total_too=True)
+        total_vram = mem[0] / (1024 * 1024)
+        total_vram_torch = mem[1] / (1024 * 1024)
+        print("Total VRAM {:0.0f} MB, total Torch VRAM {:0.0f} MB".format(total_vram, total_vram_torch))
+
         return {"ui": {"tags": tags}, "result": (tags,)}

 NODE_CLASS_MAPPINGS = {
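A caveat on the cache key: `settings_hash` is built with Python's built-in `hash()`, which is salted per process for strings (`PYTHONHASHSEED`), so keys are stable within one ComfyUI session, which is all this in-memory cache needs, but would not survive a restart if the cache were ever persisted. A deterministic alternative, sketched with a hypothetical `settings_key` helper:

```python
import hashlib

def settings_key(model, threshold, character_threshold, exclude_tags,
                 replace_underscore, trailing_comma, batch_size):
    # sha256 over the joined settings is reproducible across processes,
    # unlike hash(str), which is salted anew on every interpreter start.
    raw = "|".join(map(str, (model, threshold, character_threshold, exclude_tags,
                             replace_underscore, trailing_comma, batch_size)))
    return hashlib.sha256(raw.encode("utf-8")).hexdigest()
```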
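The pruning loop drops the oldest entry of whichever settings bucket iteration happens to visit first; the `TODO: improve by using LRU mechanism` comment acknowledges this. One way to realize that TODO, sketched on `collections.OrderedDict` (the `TagCache` class is hypothetical, not part of the patch):

```python
from collections import OrderedDict

class TagCache:
    """LRU cache keyed by (settings_hash, img_hash): hits and inserts move an
    entry to the back; eviction removes the front (least recently used)."""

    def __init__(self, max_cached=100):
        self.max_cached = max_cached
        self._store = OrderedDict()

    def get(self, key):
        if key not in self._store:
            return None
        self._store.move_to_end(key)         # mark as most recently used
        return self._store[key]

    def put(self, key, tags):
        self._store[key] = tags
        self._store.move_to_end(key)
        while len(self._store) > self.max_cached:
            self._store.popitem(last=False)  # evict least recently used
```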