66from collections import OrderedDict
77
88from modules import scripts , shared
9- from modules . devices import device , torch_gc , cpu
9+ from modules import devices
1010import local_groundingdino
1111
1212
@@ -96,15 +96,15 @@ def show_boxes(image_np, boxes, color=(255, 0, 0, 255), thickness=2, show_index=
9696def clear_dino_cache ():
9797 dino_model_cache .clear ()
9898 gc .collect ()
99- torch_gc ()
99+ devices . torch_gc ()
100100
101101
102102def load_dino_model (dino_checkpoint , dino_install_success ):
103103 print (f"Initializing GroundingDINO { dino_checkpoint } " )
104104 if dino_checkpoint in dino_model_cache :
105105 dino = dino_model_cache [dino_checkpoint ]
106106 if shared .cmd_opts .lowvram :
107- dino .to (device = device )
107+ dino .to (device = devices . device )
108108 else :
109109 clear_dino_cache ()
110110 if dino_install_success :
@@ -121,7 +121,7 @@ def load_dino_model(dino_checkpoint, dino_install_success):
121121 dino_model_info [dino_checkpoint ]["url" ], dino_model_dir )
122122 dino .load_state_dict (clean_state_dict (
123123 checkpoint ['model' ]), strict = False )
124- dino .to (device = device )
124+ dino .to (device = devices . device )
125125 dino_model_cache [dino_checkpoint ] = dino
126126 dino .eval ()
127127 return dino
@@ -148,11 +148,11 @@ def get_grounding_output(model, image, caption, box_threshold):
148148 caption = caption .strip ()
149149 if not caption .endswith ("." ):
150150 caption = caption + "."
151- image = image .to (device )
151+ image = image .to (devices . device )
152152 with torch .no_grad ():
153153 outputs = model (image [None ], captions = [caption ])
154154 if shared .cmd_opts .lowvram :
155- model .to (cpu )
155+ model .to (devices . cpu )
156156 logits = outputs ["pred_logits" ].sigmoid ()[0 ] # (nq, 256)
157157 boxes = outputs ["pred_boxes" ][0 ] # (nq, 4)
158158
@@ -183,5 +183,5 @@ def dino_predict_internal(input_image, dino_model_name, text_prompt, box_thresho
183183 boxes_filt [i ][:2 ] -= boxes_filt [i ][2 :] / 2
184184 boxes_filt [i ][2 :] += boxes_filt [i ][:2 ]
185185 gc .collect ()
186- torch_gc ()
186+ devices . torch_gc ()
187187 return boxes_filt , install_success
0 commit comments