Skip to content

Commit eadfdda

Browse files
authored
Fix memory leak (#103)
* set start method when init ParaModel * fast safetensors load tensor to device & fix unload lora
1 parent b39dbd6 commit eadfdda

File tree

4 files changed

+9
-4
lines changed

4 files changed

+9
-4
lines changed

diffsynth_engine/models/basic/lora.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,7 @@ def clear(self):
131131
self._lora_dict.clear()
132132
self._frozen_lora_list = []
133133
if self._original_weight is not None:
134-
self.weight.data = self._original_weight
135-
self._original_weight = None
134+
self.weight.data.copy_(self._original_weight)
136135

137136
def forward(self, x):
138137
w_x = super().forward(x)

diffsynth_engine/pipelines/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def from_state_dict(
6262
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
6363
for lora_path, lora_scale in lora_list:
6464
logger.info(f"loading lora from {lora_path} with scale {lora_scale}")
65-
state_dict = load_file(lora_path, device="cpu")
65+
state_dict = load_file(lora_path, device=self.device)
6666
lora_state_dict = self.lora_converter.convert(state_dict)
6767
for model_name, state_dict in lora_state_dict.items():
6868
model = getattr(self, model_name)

diffsynth_engine/utils/loader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def load_file(path: str | os.PathLike, device: str = "cpu"):
2424
direct_io=(os.environ.get("FAST_SAFETENSORS_DIRECT_IO", "False").upper() == "TRUE"),
2525
)
2626
logger.info(f"FastSafetensors Load Model End. Time: {time.time() - start_time:.2f}s")
27-
return result
27+
return {k: v.to(device) for k, v in result.items()}
2828
else:
2929
logger.info(f"Safetensors load model from {path}")
3030
start_time = time.time()

diffsynth_engine/utils/parallel.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,12 @@ def __init__(
329329
master_port: int = 29500,
330330
device: str = "cuda",
331331
):
332+
current_method = mp.get_start_method(allow_none=True)
333+
if current_method is None or current_method != 'spawn':
334+
try:
335+
mp.set_start_method('spawn')
336+
except RuntimeError as e:
337+
raise RuntimeError("Failed to set start method to spawn:", e)
332338
super().__init__()
333339
self.world_size = cfg_degree * sp_ulysses_degree * sp_ring_degree * tp_degree
334340
self.queue_in = mp.Queue()

0 commit comments

Comments
 (0)