Skip to content

Commit 5f69b95

Browse files
Use device by reference for dino (#189)
1 parent d80220e commit 5f69b95

File tree

1 file changed

+7
-7
lines changed

1 file changed

+7
-7
lines changed

scripts/dino.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from collections import OrderedDict
77

88
from modules import scripts, shared
9-
from modules.devices import device, torch_gc, cpu
9+
from modules import devices
1010
import local_groundingdino
1111

1212

@@ -96,15 +96,15 @@ def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=
9696
def clear_dino_cache():
9797
dino_model_cache.clear()
9898
gc.collect()
99-
torch_gc()
99+
devices.torch_gc()
100100

101101

102102
def load_dino_model(dino_checkpoint, dino_install_success):
103103
print(f"Initializing GroundingDINO {dino_checkpoint}")
104104
if dino_checkpoint in dino_model_cache:
105105
dino = dino_model_cache[dino_checkpoint]
106106
if shared.cmd_opts.lowvram:
107-
dino.to(device=device)
107+
dino.to(device=devices.device)
108108
else:
109109
clear_dino_cache()
110110
if dino_install_success:
@@ -121,7 +121,7 @@ def load_dino_model(dino_checkpoint, dino_install_success):
121121
dino_model_info[dino_checkpoint]["url"], dino_model_dir)
122122
dino.load_state_dict(clean_state_dict(
123123
checkpoint['model']), strict=False)
124-
dino.to(device=device)
124+
dino.to(device=devices.device)
125125
dino_model_cache[dino_checkpoint] = dino
126126
dino.eval()
127127
return dino
@@ -148,11 +148,11 @@ def get_grounding_output(model, image, caption, box_threshold):
148148
caption = caption.strip()
149149
if not caption.endswith("."):
150150
caption = caption + "."
151-
image = image.to(device)
151+
image = image.to(devices.device)
152152
with torch.no_grad():
153153
outputs = model(image[None], captions=[caption])
154154
if shared.cmd_opts.lowvram:
155-
model.to(cpu)
155+
model.to(devices.cpu)
156156
logits = outputs["pred_logits"].sigmoid()[0] # (nq, 256)
157157
boxes = outputs["pred_boxes"][0] # (nq, 4)
158158

@@ -183,5 +183,5 @@ def dino_predict_internal(input_image, dino_model_name, text_prompt, box_thresho
183183
boxes_filt[i][:2] -= boxes_filt[i][2:] / 2
184184
boxes_filt[i][2:] += boxes_filt[i][:2]
185185
gc.collect()
186-
torch_gc()
186+
devices.torch_gc()
187187
return boxes_filt, install_success

0 commit comments

Comments
 (0)