Skip to content

Commit 249173f

Browse files
committed
remove extraneous warnings about overwriting trigger terms
1 parent 794ef86 commit 249173f

File tree

2 files changed

+8
-16
lines changed

2 files changed

+8
-16
lines changed

ldm/generate.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -976,17 +976,18 @@ def set_model(self, model_name):
976976
self.generators = {}
977977

978978
seed_everything(random.randrange(0, np.iinfo(np.uint32).max))
979-
if self.embedding_path is not None:
979+
if self.embedding_path and not model_data.get("ti_embeddings_loaded"):
980980
print(f'>> Loading embeddings from {self.embedding_path}')
981981
for root, _, files in os.walk(self.embedding_path):
982982
for name in files:
983983
ti_path = os.path.join(root, name)
984984
self.model.textual_inversion_manager.load_textual_inversion(
985985
ti_path, defer_injecting_tokens=True
986986
)
987-
print(
988-
f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}'
989-
)
987+
model_data["ti_embeddings_loaded"] = True
988+
print(
989+
f'>> Textual inversion triggers: {", ".join(sorted(self.model.textual_inversion_manager.get_all_trigger_strings()))}'
990+
)
990991

991992
self.model_name = model_name
992993
self._set_sampler() # requires self.model_name to be set first

ldm/invoke/model_manager.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,7 @@ def get_model(self, model_name: str):
103103
requested_model = self.models[model_name]["model"]
104104
print(f">> Retrieving model {model_name} from system RAM cache")
105105
self.models[model_name]["model"] = self._model_from_cpu(requested_model)
106-
width = self.models[model_name]["width"]
107-
height = self.models[model_name]["height"]
108-
hash = self.models[model_name]["hash"]
109-
110-
else: # we're about to load a new model, so potentially offload the least recently used one
106+
else:
111107
requested_model, width, height, hash = self._load_model(model_name)
112108
self.models[model_name] = {
113109
"model": requested_model,
@@ -118,13 +114,8 @@ def get_model(self, model_name: str):
118114

119115
self.current_model = model_name
120116
self._push_newest_model(model_name)
121-
return {
122-
"model": requested_model,
123-
"width": width,
124-
"height": height,
125-
"hash": hash,
126-
}
127-
117+
return self.models[model_name]
118+
128119
def default_model(self) -> str | None:
129120
"""
130121
Returns the name of the default model, or None

0 commit comments

Comments (0)