|
12 | 12 | ] |
13 | 13 |
|
14 | 14 | class Annotator: |
15 | | - def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None): |
| 15 | + def __init__(self, processor_id: Processor_id, model_path="models/Annotators", detect_resolution=None, device='cuda'): |
16 | 16 | if processor_id == "canny": |
17 | 17 | self.processor = CannyDetector() |
18 | 18 | elif processor_id == "depth": |
19 | | - self.processor = MidasDetector.from_pretrained(model_path).to("cuda") |
| 19 | + self.processor = MidasDetector.from_pretrained(model_path).to(device) |
20 | 20 | elif processor_id == "softedge": |
21 | | - self.processor = HEDdetector.from_pretrained(model_path).to("cuda") |
| 21 | + self.processor = HEDdetector.from_pretrained(model_path).to(device) |
22 | 22 | elif processor_id == "lineart": |
23 | | - self.processor = LineartDetector.from_pretrained(model_path).to("cuda") |
| 23 | + self.processor = LineartDetector.from_pretrained(model_path).to(device) |
24 | 24 | elif processor_id == "lineart_anime": |
25 | | - self.processor = LineartAnimeDetector.from_pretrained(model_path).to("cuda") |
| 25 | + self.processor = LineartAnimeDetector.from_pretrained(model_path).to(device) |
26 | 26 | elif processor_id == "openpose": |
27 | | - self.processor = OpenposeDetector.from_pretrained(model_path).to("cuda") |
| 27 | + self.processor = OpenposeDetector.from_pretrained(model_path).to(device) |
28 | 28 | elif processor_id == "tile": |
29 | 29 | self.processor = None |
30 | 30 | else: |
31 | 31 | raise ValueError(f"Unsupported processor_id: {processor_id}") |
32 | | - |
| 32 | + |
33 | 33 | self.processor_id = processor_id |
34 | 34 | self.detect_resolution = detect_resolution |
35 | 35 |
|
|
0 commit comments