diff --git a/demo_colmap.py b/demo_colmap.py index 836af172..d2d4403d 100644 --- a/demo_colmap.py +++ b/demo_colmap.py @@ -110,12 +110,12 @@ def demo_fn(args): print(f"Using dtype: {dtype}") # Run VGGT for camera and depth estimation - model = VGGT() + model = VGGT().to(device) _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" - model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) + state_dict = torch.hub.load_state_dict_from_url(_URL, map_location=device) + model.load_state_dict(state_dict) model.eval() - model = model.to(device) - print(f"Model loaded") + print(f"Model loaded to {device}") # Get image paths and preprocess them image_dir = os.path.join(args.scene_dir, "images") diff --git a/demo_gradio.py b/demo_gradio.py index 466a5ff4..a622fd14 100644 --- a/demo_gradio.py +++ b/demo_gradio.py @@ -29,13 +29,12 @@ print("Initializing and loading VGGT model...") # model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model -model = VGGT() +model = VGGT().to(device) _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" -model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) - - +state_dict = torch.hub.load_state_dict_from_url(_URL, map_location=device) +model.load_state_dict(state_dict) model.eval() -model = model.to(device) +print(f"Model loaded to {device}") # ------------------------------------------------------------------------- diff --git a/demo_viser.py b/demo_viser.py index e0211dac..ef9366b5 100644 --- a/demo_viser.py +++ b/demo_viser.py @@ -344,12 +344,12 @@ def main(): print("Initializing and loading VGGT model...") # model = VGGT.from_pretrained("facebook/VGGT-1B") - model = VGGT() + model = VGGT().to(device) _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt" - model.load_state_dict(torch.hub.load_state_dict_from_url(_URL)) - + state_dict = torch.hub.load_state_dict_from_url(_URL, map_location=device) + model.load_state_dict(state_dict) model.eval() - model = model.to(device) + print(f"Model loaded to {device}") # Use the provided image folder path print(f"Loading images from {args.image_folder}...")