diff --git a/wd14tagger.py b/wd14tagger.py index c88f583..dca934d 100644 --- a/wd14tagger.py +++ b/wd14tagger.py @@ -25,6 +25,7 @@ "replace_underscore": False, "trailing_comma": False, "exclude_tags": "", + "show_confidence": False, "ortProviders": ["CUDAExecutionProvider", "CPUExecutionProvider"], "HF_ENDPOINT": "https://huggingface.co" } @@ -47,7 +48,7 @@ def get_installed_models(): return models -async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclude_tags="", replace_underscore=True, trailing_comma=False, client_id=None, node=None): +async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclude_tags="", replace_underscore=True, trailing_comma=False, show_confidence=False, client_id=None, node=None): if model_name.endswith(".onnx"): model_name = model_name[0:-5] installed = list(get_installed_models()) @@ -97,11 +98,14 @@ async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclu general = [item for item in result[general_index:character_index] if item[1] > threshold] character = [item for item in result[character_index:] if item[1] > character_threshold] - all = character + general + all = sorted(general + character, key=lambda x: x[1], reverse=True) remove = [s.strip() for s in exclude_tags.lower().split(",")] all = [tag for tag in all if tag[0] not in remove] - res = ("" if trailing_comma else ", ").join((item[0].replace("(", "\\(").replace(")", "\\)") + (", " if trailing_comma else "") for item in all)) + if show_confidence: + res = ("" if trailing_comma else ", ").join((f"{item[0].replace('(', '\\(').replace(')', '\\)')}: {item[1]:.2f}" + ("," if trailing_comma else "")) for item in all) + else: + res = ("" if trailing_comma else ", ").join((item[0].replace("(", "\\(").replace(")", "\\)") + (", " if trailing_comma else "") for item in all)) print(res) return res @@ -180,6 +184,7 @@ def INPUT_TYPES(s): "replace_underscore": ("BOOLEAN", {"default": defaults["replace_underscore"]}), "trailing_comma": ("BOOLEAN", {"default": defaults["trailing_comma"]}), "exclude_tags": ("STRING", {"default": defaults["exclude_tags"]}), + "show_confidence": ("BOOLEAN", {"default": defaults["show_confidence"]}), }} RETURN_TYPES = ("STRING",) @@ -189,7 +194,7 @@ def INPUT_TYPES(s): CATEGORY = "image" - def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False): + def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False, show_confidence=False): tensor = image*255 tensor = np.array(tensor, dtype=np.uint8) @@ -197,7 +202,7 @@ def tag(self, image, model, threshold, character_threshold, exclude_tags="", rep tags = [] for i in range(tensor.shape[0]): image = Image.fromarray(tensor[i]) - tags.append(wait_for_async(lambda: tag(image, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma))) + tags.append(wait_for_async(lambda: tag(image, model, threshold, character_threshold, exclude_tags, replace_underscore, trailing_comma, show_confidence))) pbar.update(1) return {"ui": {"tags": tags}, "result": (tags,)}