
Commit 1f5421e

Merge pull request #23 from LykosAI/fix-dep

Fix huggingface_hub import conflict & ComfyUI model management

2 parents: e94fb0a + dd68969

3 files changed: 60 additions, 65 deletions

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ dependencies = [
     "trimesh[easy]",
     "albumentations",
     "scikit-learn",
-    "diffusers>=0.25.0"
+    "diffusers>=0.29.0"
 ]

 [project.optional-dependencies]
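
The version floor on diffusers is what resolves the huggingface_hub side of the commit title: old diffusers releases import helpers at module load time that newer huggingface_hub versions have deleted, so merely importing the package can fail. A minimal reproduction sketch, assuming the conflict is the widely reported removal of huggingface_hub.cached_download (the diff itself does not name the symbol):

import importlib

# Hypothetical repro: with a recent huggingface_hub installed alongside an
# old diffusers (< 0.26), the import below fails because diffusers executes
# "from huggingface_hub import cached_download" at module load time.
try:
    importlib.import_module("diffusers")
except ImportError as exc:
    # Typically: "cannot import name 'cached_download' from 'huggingface_hub'"
    print(f"diffusers/huggingface_hub import conflict: {exc}")

Bumping the floor to diffusers>=0.29.0 selects releases that no longer reference the removed helper.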

src/inference_core_nodes/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 __all__ = ("__version__", "NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS")

-__version__ = "0.4.1"
+__version__ = "0.4.2"


 def _get_node_mappings():

src/inference_core_nodes/prompt_expansion/prompt_expansion.py

Lines changed: 58 additions & 63 deletions

@@ -9,6 +9,9 @@
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed

+import comfy.model_management
+import comfy.model_base
+from comfy.model_base import ModelType
 import folder_paths
 from comfy import model_management
 from comfy.model_patcher import ModelPatcher
@@ -38,76 +41,68 @@ def remove_pattern(x, pattern):
     return x


-class FooocusExpansion:
-    def __init__(self, model_directory: str):
-        model_directory = Path(model_directory)
-        if not model_directory.exists() or not model_directory.is_dir():
-            raise ValueError(f"Model directory {model_directory} does not exist")
+class ComfyTransformerModel(comfy.model_base.BaseModel):
+    def __init__(self, model_name: str, model_type=ModelType.EPS, device=None, *args, **kwargs):
+        # Find the full path to the model
+        model_path = folder_paths.get_full_path("prompt_expansion", model_name)
+        if model_path is None:
+            raise ValueError(f"Model {model_name} not found in prompt_expansion folder.")

-        self.tokenizer = AutoTokenizer.from_pretrained(model_directory)
-        self.model = AutoModelForCausalLM.from_pretrained(model_directory)
+        # If model is a file, use the parent directory
+        if Path(model_path).is_file():
+            model_path = str(Path(model_path).parent)

-        positive_tokens = (
-            model_directory.joinpath("positive.txt").read_text().splitlines()
-        )
+        class MinimalConfig:
+            def __init__(self):
+                self.unet_config = {"disable_unet_model_creation": True}
+                self.latent_format = None
+                self.custom_operations = None
+                self.scaled_fp8 = None
+                self.memory_usage_factor = 1.0
+                self.manual_cast_dtype = None
+                self.optimizations = {}
+                self.sampling_settings = {}

-        positive_tokens = []
+        config = MinimalConfig()

-        self.model.eval()
+        super().__init__(config, model_type=model_type, device=device)
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+        self.model = AutoModelForCausalLM.from_pretrained(model_path)

-        load_device = model_management.text_encoder_device()
+        self.load_device = comfy.model_management.text_encoder_device()
+        self.offload_device = comfy.model_management.text_encoder_offload_device()

-        if "mps" in load_device.type:
-            load_device = torch.device("cpu")
+        if "mps" in self.load_device.type:
+            self.load_device = torch.device("cpu")

-        if "cpu" not in load_device.type and model_management.should_use_fp16():
+        if "cpu" not in self.load_device.type and comfy.model_management.should_use_fp16():
             self.model.half()

-        offload_device = model_management.text_encoder_offload_device()
-        self.patcher = ModelPatcher(
-            self.model, load_device=load_device, offload_device=offload_device
-        )
-
-    def __call__(self, prompt: str, seed: int) -> str:
-        model_management.load_model_gpu(self.patcher)
-        set_seed(seed)
-        origin = safe_str(prompt)
-        prompt = origin + fooocus_magic_split[seed % len(fooocus_magic_split)]
-
-        tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
-        tokenized_kwargs.data["input_ids"] = tokenized_kwargs.data["input_ids"].to(
-            self.patcher.load_device
-        )
-        tokenized_kwargs.data["attention_mask"] = tokenized_kwargs.data[
-            "attention_mask"
-        ].to(self.patcher.load_device)
-
-        # https://huggingface.co/blog/introducing-csearch
-        # https://huggingface.co/docs/transformers/generation_strategies
-        features = self.model.generate(
-            **tokenized_kwargs, num_beams=1, max_new_tokens=256, do_sample=True
-        )
-
-        response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
-        result = response[0][len(origin) :]
-        result = safe_str(result)
-        result = result.translate(disallowed_chars_table)
-        return result
-
-    def expand_and_join(self, prompt: str, seed: int) -> str:
-        expansion = self(prompt, seed)
-        return join_prompts(prompt, expansion)
-
-
-@cache
-def load_expansion_runner(model_name: str):
-    model_path = folder_paths.get_full_path(MODEL_FOLDER_NAME, model_name)
-
-    # If model is a file, use the parent directory
-    if Path(model_path).is_file():
-        model_path = str(Path(model_path).parent)
-
-    return FooocusExpansion(model_path)
+        self.model.eval()
+        self.model.to(self.load_device)
+        self.device = self.load_device
+
+    def apply_model(self, prompt: str, seed: int) -> str:
+        with torch.no_grad():
+            origin = safe_str(prompt)
+            prompt = origin + fooocus_magic_split[seed % len(fooocus_magic_split)]
+
+            tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt")
+            tokenized_kwargs.data["input_ids"] = tokenized_kwargs.data["input_ids"].to(self.load_device)
+            tokenized_kwargs.data["attention_mask"] = tokenized_kwargs.data["attention_mask"].to(self.load_device)
+
+            # https://huggingface.co/blog/introducing-csearch
+            # https://huggingface.co/docs/transformers/generation_strategies
+            features = self.model.generate(
+                **tokenized_kwargs, num_beams=1, max_new_tokens=256, do_sample=True
+            )
+
+            response = self.tokenizer.batch_decode(features, skip_special_tokens=True)
+            result = response[0][len(origin):]
+            result = safe_str(result)
+            result = result.translate(disallowed_chars_table)
+            return result


 class PromptExpansion:
@@ -138,7 +133,7 @@ def INPUT_TYPES(s):
     @staticmethod
     @torch.no_grad()
     def expand_prompt(model_name: str, text: str, seed: int, log_prompt: bool):
-        expansion = load_expansion_runner(model_name)
+        expansion_model = ComfyTransformerModel(model_name)

         prompt = remove_empty_str([safe_str(text)], default="")[0]

@@ -171,7 +166,7 @@ def expand_prompt(model_name: str, text: str, seed: int, log_prompt: bool):
         prompt_parts = [prompt]

         for i, part in enumerate(prompt_parts):
-            expansion_part = expansion(part, seed)
+            expansion_part = expansion_model.apply_model(part, seed)
             full_part = join_prompts(part, expansion_part)
             expanded_parts.append(full_part)
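
Net effect of the rewrite: the cached FooocusExpansion/ModelPatcher pair is replaced by a subclass of comfy.model_base.BaseModel, so the transformer is visible to ComfyUI's model management like any other model, and the MinimalConfig stub supplies the attributes BaseModel expects while "disable_unet_model_creation" skips building a UNet. A minimal usage sketch of the new class, assuming a ComfyUI environment with an expansion model installed under the prompt_expansion models folder (the model name below is illustrative, not pinned by this commit):

from inference_core_nodes.prompt_expansion.prompt_expansion import ComfyTransformerModel

# "fooocus_expansion" is an assumed model name; any GPT-2-style model
# registered in ComfyUI's prompt_expansion folder would do.
expansion_model = ComfyTransformerModel("fooocus_expansion")

# apply_model returns only the generated continuation, without the
# original prompt (the node joins the two afterwards via join_prompts).
expansion = expansion_model.apply_model("a photo of a cat", seed=42)
print(expansion)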
