diff --git a/requirements.txt b/requirements.txt index 51decf8..8013d76 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1 @@ -onnxruntime \ No newline at end of file +onnxruntime-gpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/ diff --git a/wd14tagger.py b/wd14tagger.py index c88f583..131dc66 100644 --- a/wd14tagger.py +++ b/wd14tagger.py @@ -38,14 +38,23 @@ models_dir = get_ext_dir("models", mkdir=True) known_models = list(config["models"].keys()) -log("Available ORT providers: " + ", ".join(ort.get_available_providers()), "DEBUG", True) -log("Using ORT providers: " + ", ".join(defaults["ortProviders"]), "DEBUG", True) +log(f"Available ORT providers: {', '.join(ort.get_available_providers())}", "DEBUG", True) +log(f"Using ORT providers: {', '.join(defaults['ortProviders'])}", "DEBUG", True) + +session_options = ort.SessionOptions() +session_options.log_severity_level = 1 +session_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + +def create_session(model_path): + return ort.InferenceSession(model_path, sess_options=session_options, providers=defaults["ortProviders"]) def get_installed_models(): models = filter(lambda x: x.endswith(".onnx"), os.listdir(models_dir)) models = [m for m in models if os.path.exists(os.path.join(models_dir, os.path.splitext(m)[0] + ".csv"))] return models +model_path = os.path.join(models_dir, defaults["model"] + ".onnx") +session = create_session(model_path) async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclude_tags="", replace_underscore=True, trailing_comma=False, client_id=None, node=None): if model_name.endswith(".onnx"): @@ -57,20 +66,26 @@ async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclu name = os.path.join(models_dir, model_name + ".onnx") model = InferenceSession(name, providers=defaults["ortProviders"]) - input = model.get_inputs()[0] - height = input.shape[1] + input_name = model.get_inputs()[0].name + height = model.get_inputs()[0].shape[1] # Reduce to max size and pad with white - ratio = float(height)/max(image.size) - new_size = tuple([int(x*ratio) for x in image.size]) + ratio = float(height) / max(image.size) + new_size = tuple([int(x * ratio) for x in image.size]) image = image.resize(new_size, Image.LANCZOS) square = Image.new("RGB", (height, height), (255, 255, 255)) - square.paste(image, ((height-new_size[0])//2, (height-new_size[1])//2)) + square.paste(image, ((height - new_size[0]) // 2, (height - new_size[1]) // 2)) image = np.array(square).astype(np.float32) image = image[:, :, ::-1] # RGB -> BGR image = np.expand_dims(image, 0) + # Ensure input data is on GPU + ort_inputs = {input_name: ort.OrtValue.ortvalue_from_numpy(image, 'cuda')} + + label_name = model.get_outputs()[0].name + probs = model.run([label_name], ort_inputs)[0] + # Read all tags from csv and locate start of each category tags = [] general_index = None @@ -88,25 +103,20 @@ async def tag(image, model_name, threshold=0.35, character_threshold=0.85, exclu else: tags.append(row[1]) - label_name = model.get_outputs()[0].name - probs = model.run([label_name], {input.name: image})[0] - result = list(zip(tags, probs[0])) # rating = max(result[:general_index], key=lambda x: x[1]) general = [item for item in result[general_index:character_index] if item[1] > threshold] character = [item for item in result[character_index:] if item[1] > character_threshold] - all = character + general + all_tags = character + general remove = [s.strip() for s in exclude_tags.lower().split(",")] - all = [tag for tag in all if tag[0] not in remove] + all_tags = [tag for tag in all_tags if tag[0] not in remove] - res = ("" if trailing_comma else ", ").join((item[0].replace("(", "\\(").replace(")", "\\)") + (", " if trailing_comma else "") for item in all)) + res = ("" if trailing_comma else ", ").join((item[0].replace("(", "\\(").replace(")", "\\)") + (", " if trailing_comma else "") for item in all_tags)) - print(res) return res - async def download_model(model, client_id, node): hf_endpoint = os.getenv("HF_ENDPOINT", defaults["HF_ENDPOINT"]) if not hf_endpoint.startswith("https://"): @@ -127,9 +137,9 @@ async def update_callback(perc): try: await download_to_file( - f"{url}model.onnx", os.path.join(models_dir,f"{model}.onnx"), update_callback, session=session) + f"{url}model.onnx", os.path.join(models_dir, f"{model}.onnx"), update_callback, session=session) await download_to_file( - f"{url}selected_tags.csv", os.path.join(models_dir,f"{model}.csv"), update_callback, session=session) + f"{url}selected_tags.csv", os.path.join(models_dir, f"{model}.csv"), update_callback, session=session) except aiohttp.client_exceptions.ClientConnectorError as err: log("Unable to download model. Download files manually or try using a HF mirror/proxy website by setting the environment variable HF_ENDPOINT=https://.....", "ERROR", True) raise @@ -138,7 +148,6 @@ async def update_callback(perc): return web.Response(status=200) - @PromptServer.instance.routes.get("/pysssss/wd14tagger/tag") async def get_tags(request): if "filename" not in request.rel_url.query: @@ -166,7 +175,6 @@ async def get_tags(request): return web.json_response(await tag(image, model, client_id=request.rel_url.query.get("clientId", ""), node=request.rel_url.query.get("node", ""))) - class WD14Tagger: @classmethod def INPUT_TYPES(s): @@ -190,7 +198,7 @@ def INPUT_TYPES(s): CATEGORY = "image" def tag(self, image, model, threshold, character_threshold, exclude_tags="", replace_underscore=False, trailing_comma=False): - tensor = image*255 + tensor = image * 255 tensor = np.array(tensor, dtype=np.uint8) pbar = comfy.utils.ProgressBar(tensor.shape[0]) @@ -201,7 +209,6 @@ def tag(self, image, model, threshold, character_threshold, exclude_tags="", rep pbar.update(1) return {"ui": {"tags": tags}, "result": (tags,)} - NODE_CLASS_MAPPINGS = { "WD14Tagger|pysssss": WD14Tagger, }