Skip to content

Commit 237daa2

Browse files
authored
Merge pull request #87 from Lupino/main
pass device to processors Annotator
2 parents 996515c + e9af28e commit 237daa2

File tree

3 files changed

+10
-10
lines changed

3 files changed

+10
-10
lines changed

diffsynth/controlnets/processors.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,24 +12,24 @@
1212
]
1313

1414
class Annotator:
15-
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None):
15+
def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'):
1616
if processor_id == "canny":
1717
self.processor = CannyDetector()
1818
elif processor_id == "depth":
19-
self.processor = MidasDetector.from_pretrained(model_path).to("cuda")
19+
self.processor = MidasDetector.from_pretrained(model_path).to(device)
2020
elif processor_id == "softedge":
21-
self.processor = HEDdetector.from_pretrained(model_path).to("cuda")
21+
self.processor = HEDdetector.from_pretrained(model_path).to(device)
2222
elif processor_id == "lineart":
23-
self.processor = LineartDetector.from_pretrained(model_path).to("cuda")
23+
self.processor = LineartDetector.from_pretrained(model_path).to(device)
2424
elif processor_id == "lineart_anime":
25-
self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda")
25+
self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device)
2626
elif processor_id == "openpose":
27-
self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda")
27+
self.processor = OpenposeDetector.from_pretrained(model_path).to(device)
2828
elif processor_id == "tile":
2929
self.processor = None
3030
else:
3131
raise ValueError(f"Unsupported processor_id: {processor_id}")
32-
32+
3333
self.processor_id = processor_id
3434
self.detect_resolution = detect_resolution
3535

diffsynth/pipelines/stable_diffusion.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,14 @@ def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config
3939
controlnet_units = []
4040
for config in controlnet_config_units:
4141
controlnet_unit = ControlNetUnit(
42-
Annotator(config.processor_id),
42+
Annotator(config.processor_id, device=self.device),
4343
model_manager.get_model_with_model_path(config.model_path),
4444
config.scale
4545
)
4646
controlnet_units.append(controlnet_unit)
4747
self.controlnet = MultiControlNetManager(controlnet_units)
4848

49-
49+
5050
def fetch_ipadapter(self, model_manager: ModelManager):
5151
if "ipadapter" in model_manager.model:
5252
self.ipadapter = model_manager.ipadapter

diffsynth/pipelines/stable_diffusion_video.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def fetch_controlnet_models(self, model_manager: ModelManager, controlnet_config
8989
controlnet_units = []
9090
for config in controlnet_config_units:
9191
controlnet_unit = ControlNetUnit(
92-
Annotator(config.processor_id),
92+
Annotator(config.processor_id, device=self.device),
9393
model_manager.get_model_with_model_path(config.model_path),
9494
config.scale
9595
)

0 commit comments

Comments
 (0)