9 | 9 | import torch |
10 | 10 | from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed |
11 | 11 |
| 12 | +import comfy.model_management |
| 13 | +import comfy.model_base |
| 14 | +from comfy.model_base import ModelType |
12 | 15 | import folder_paths |
13 | 16 | from comfy import model_management |
14 | 17 | from comfy.model_patcher import ModelPatcher |
@@ -38,76 +41,68 @@ def remove_pattern(x, pattern): |
38 | 41 | return x |
39 | 42 |
40 | 43 |
41 | | -class FooocusExpansion: |
42 | | - def __init__(self, model_directory: str): |
43 | | - model_directory = Path(model_directory) |
44 | | - if not model_directory.exists() or not model_directory.is_dir(): |
45 | | - raise ValueError(f"Model directory {model_directory} does not exist") |
| 44 | +class ComfyTransformerModel(comfy.model_base.BaseModel): |
| 45 | + def __init__(self, model_name: str, model_type=ModelType.EPS, device=None, *args, **kwargs): |
| 46 | + # Find the full path to the model |
| 47 | + model_path = folder_paths.get_full_path("prompt_expansion", model_name) |
| 48 | + if model_path is None: |
| 49 | + raise ValueError(f"Model {model_name} not found in prompt_expansion folder.") |
46 | 50 |
47 | | - self.tokenizer = AutoTokenizer.from_pretrained(model_directory) |
48 | | - self.model = AutoModelForCausalLM.from_pretrained(model_directory) |
| 51 | + # If model is a file, use the parent directory |
| 52 | + if Path(model_path).is_file(): |
| 53 | + model_path = str(Path(model_path).parent) |
49 | 54 |
50 | | - positive_tokens = ( |
51 | | - model_directory.joinpath("positive.txt").read_text().splitlines() |
52 | | - ) |
| 55 | + class MinimalConfig: |
| 56 | + def __init__(self): |
| 57 | + self.unet_config = {"disable_unet_model_creation": True} |
| 58 | + self.latent_format = None |
| 59 | + self.custom_operations = None |
| 60 | + self.scaled_fp8 = None |
| 61 | + self.memory_usage_factor = 1.0 |
| 62 | + self.manual_cast_dtype = None |
| 63 | + self.optimizations = {} |
| 64 | + self.sampling_settings = {} |
53 | 65 |
54 | | - positive_tokens = [] |
| 66 | + config = MinimalConfig() |
55 | 67 |
56 | | - self.model.eval() |
| 68 | + super().__init__(config, model_type=model_type, device=device) |
| 69 | + |
| 70 | + self.tokenizer = AutoTokenizer.from_pretrained(model_path) |
| 71 | + self.model = AutoModelForCausalLM.from_pretrained(model_path) |
57 | 72 |
58 | | - load_device = model_management.text_encoder_device() |
| 73 | + self.load_device = comfy.model_management.text_encoder_device() |
| 74 | + self.offload_device = comfy.model_management.text_encoder_offload_device() |
59 | 75 |
60 | | - if "mps" in load_device.type: |
61 | | - load_device = torch.device("cpu") |
| 76 | + if "mps" in self.load_device.type: |
| 77 | + self.load_device = torch.device("cpu") |
62 | 78 |
63 | | - if "cpu" not in load_device.type and model_management.should_use_fp16(): |
| 79 | + if "cpu" not in self.load_device.type and comfy.model_management.should_use_fp16(): |
64 | 80 | self.model.half() |
65 | 81 |
66 | | - offload_device = model_management.text_encoder_offload_device() |
67 | | - self.patcher = ModelPatcher( |
68 | | - self.model, load_device=load_device, offload_device=offload_device |
69 | | - ) |
70 | | - |
71 | | - def __call__(self, prompt: str, seed: int) -> str: |
72 | | - model_management.load_model_gpu(self.patcher) |
73 | | - set_seed(seed) |
74 | | - origin = safe_str(prompt) |
75 | | - prompt = origin + fooocus_magic_split[seed % len(fooocus_magic_split)] |
76 | | - |
77 | | - tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt") |
78 | | - tokenized_kwargs.data["input_ids"] = tokenized_kwargs.data["input_ids"].to( |
79 | | - self.patcher.load_device |
80 | | - ) |
81 | | - tokenized_kwargs.data["attention_mask"] = tokenized_kwargs.data[ |
82 | | - "attention_mask" |
83 | | - ].to(self.patcher.load_device) |
84 | | - |
85 | | - # https://huggingface.co/blog/introducing-csearch |
86 | | - # https://huggingface.co/docs/transformers/generation_strategies |
87 | | - features = self.model.generate( |
88 | | - **tokenized_kwargs, num_beams=1, max_new_tokens=256, do_sample=True |
89 | | - ) |
90 | | - |
91 | | - response = self.tokenizer.batch_decode(features, skip_special_tokens=True) |
92 | | - result = response[0][len(origin) :] |
93 | | - result = safe_str(result) |
94 | | - result = result.translate(disallowed_chars_table) |
95 | | - return result |
96 | | - |
97 | | - def expand_and_join(self, prompt: str, seed: int) -> str: |
98 | | - expansion = self(prompt, seed) |
99 | | - return join_prompts(prompt, expansion) |
100 | | - |
101 | | - |
102 | | -@cache |
103 | | -def load_expansion_runner(model_name: str): |
104 | | - model_path = folder_paths.get_full_path(MODEL_FOLDER_NAME, model_name) |
105 | | - |
106 | | - # If model is a file, use the parent directory |
107 | | - if Path(model_path).is_file(): |
108 | | - model_path = str(Path(model_path).parent) |
109 | | - |
110 | | - return FooocusExpansion(model_path) |
| 82 | + self.model.eval() |
| 83 | + self.model.to(self.load_device) |
| 84 | + self.device = self.load_device |
| 85 | + |
| 86 | + def apply_model(self, prompt: str, seed: int) -> str: |
| 87 | + with torch.no_grad(): |
| 88 | + origin = safe_str(prompt) |
| 89 | + prompt = origin + fooocus_magic_split[seed % len(fooocus_magic_split)] |
| 90 | + |
| 91 | + tokenized_kwargs = self.tokenizer(prompt, return_tensors="pt") |
| 92 | + tokenized_kwargs.data["input_ids"] = tokenized_kwargs.data["input_ids"].to(self.load_device) |
| 93 | + tokenized_kwargs.data["attention_mask"] = tokenized_kwargs.data["attention_mask"].to(self.load_device) |
| 94 | + |
| 95 | + # https://huggingface.co/blog/introducing-csearch |
| 96 | + # https://huggingface.co/docs/transformers/generation_strategies |
| 97 | + features = self.model.generate( |
| 98 | + **tokenized_kwargs, num_beams=1, max_new_tokens=256, do_sample=True |
| 99 | + ) |
| 100 | + |
| 101 | + response = self.tokenizer.batch_decode(features, skip_special_tokens=True) |
| 102 | + result = response[0][len(origin):] |
| 103 | + result = safe_str(result) |
| 104 | + result = result.translate(disallowed_chars_table) |
| 105 | + return result |
111 | 106 |
112 | 107 |
113 | 108 | class PromptExpansion: |
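Taken on its own, the new `ComfyTransformerModel` can be smoke-tested outside the node graph. A minimal sketch, assuming a working ComfyUI environment on the import path and a GPT-2-style expansion model installed under the `prompt_expansion` models folder; the model name `fooocus_expansion` is only a placeholder:

```python
from transformers import set_seed

seed = 42
set_seed(seed)  # seed transformers' RNG, as the removed FooocusExpansion.__call__ did

# "fooocus_expansion" is a placeholder; use whatever name
# folder_paths.get_full_path("prompt_expansion", ...) can resolve.
expansion = ComfyTransformerModel("fooocus_expansion")
print(expansion.apply_model("a photo of a cat", seed))
```

`apply_model` returns only the generated continuation (the original prompt is stripped and the text sanitized), so callers still combine it with the input via `join_prompts`, as `expand_prompt` does below.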
@@ -138,7 +133,7 @@ def INPUT_TYPES(s): |
138 | 133 | @staticmethod |
139 | 134 | @torch.no_grad() |
140 | 135 | def expand_prompt(model_name: str, text: str, seed: int, log_prompt: bool): |
141 | | - expansion = load_expansion_runner(model_name) |
| 136 | + expansion_model = ComfyTransformerModel(model_name) |
142 | 137 |
143 | 138 | prompt = remove_empty_str([safe_str(text)], default="")[0] |
144 | 139 |
@@ -171,7 +166,7 @@ def expand_prompt(model_name: str, text: str, seed: int, log_prompt: bool): |
171 | 166 | prompt_parts = [prompt] |
172 | 167 |
173 | 168 | for i, part in enumerate(prompt_parts): |
174 | | - expansion_part = expansion(part, seed) |
| 169 | + expansion_part = expansion_model.apply_model(part, seed) |
175 | 170 | full_part = join_prompts(part, expansion_part) |
176 | 171 | expanded_parts.append(full_part) |
177 | 172 |
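The removed `load_expansion_runner` was memoized with `@cache`, whereas `expand_prompt` now constructs a `ComfyTransformerModel` (and therefore reloads the tokenizer and weights) on every invocation. If that load proves costly, a cached loader in the spirit of the old helper could sit in front of the constructor; the sketch below is hypothetical and not part of this change:

```python
from functools import cache

@cache
def load_transformer_model(model_name: str):
    # Hypothetical helper: keep one ComfyTransformerModel per model name,
    # mirroring the @cache behaviour of the removed load_expansion_runner.
    return ComfyTransformerModel(model_name)
```

`expand_prompt` would then call `load_transformer_model(model_name)` instead of instantiating the class directly.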