-
-
Notifications
You must be signed in to change notification settings - Fork 352
Closed
Labels
bug: Something isn't working
Description
Environment Information
- samgeo version: 0.9.1
- Python version: 3.9.16
- Operating System: Windows
Description
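I ran the default example "Generate object masks from input prompts with HQ-SAM". Constructing `SamGeo` raised a `RuntimeError`: the checkpoint was saved on a CUDA device, but `torch.cuda.is_available()` is False on this machine (full traceback below).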
What I Did
import os
import leafmap
from samgeo.hq_sam import SamGeo, tms_to_geotiff
image = "C:/geosam/image.tif"
sam = SamGeo(
model_type="vit_h", # can be vit_h, vit_b, vit_l, vit_tiny
automatic=False,
sam_kwargs=None,
)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[51], line 1
----> 1 sam = SamGeo(
2 model_type="vit_h", # can be vit_h, vit_b, vit_l, vit_tiny
3 automatic=False,
4 sam_kwargs=None,
5 )
File ~\miniconda3\envs\geo\lib\site-packages\samgeo\hq_sam.py:96, in SamGeo.__init__(self, model_type, automatic, device, checkpoint_dir, hq, sam_kwargs, **kwargs)
93 self.logits = None
95 # Build the SAM model
---> 96 self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
97 self.sam.to(device=self.device)
98 # Use optional arguments for fine-tuning the SAM model
File ~\miniconda3\envs\geo\lib\site-packages\segment_anything_hq\build_sam.py:15, in build_sam_vit_h(checkpoint)
14 def build_sam_vit_h(checkpoint=None):
---> 15 return _build_sam(
16 encoder_embed_dim=1280,
17 encoder_depth=32,
18 encoder_num_heads=16,
19 encoder_global_attn_indexes=[7, 15, 23, 31],
20 checkpoint=checkpoint,
21 )
File ~\miniconda3\envs\geo\lib\site-packages\segment_anything_hq\build_sam.py:160, in _build_sam(encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint)
158 if checkpoint is not None:
159 with open(checkpoint, "rb") as f:
--> 160 state_dict = torch.load(f)
161 info = sam.load_state_dict(state_dict, strict=False)
162 print(info)
File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:712, in load(f, map_location, pickle_module, **pickle_load_args)
710 opened_file.seek(orig_position)
711 return torch.jit.load(opened_file)
--> 712 return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:1049, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
1047 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
1048 unpickler.persistent_load = persistent_load
-> 1049 result = unpickler.load()
1051 torch._utils._validate_loaded_sparse_tensors()
1053 return result
File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:1019, in _load.<locals>.persistent_load(saved_id)
1017 if key not in loaded_storages:
1018 nbytes = numel * torch._utils._element_size(dtype)
-> 1019 load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
1021 return loaded_storages[key]
File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:1001, in _load.<locals>.load_tensor(dtype, numel, key, location)
997 storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()
998 # TODO: Once we decide to break serialization FC, we can
999 # stop wrapping with _TypedStorage
1000 loaded_storages[key] = torch.storage._TypedStorage(
-> 1001 wrap_storage=restore_location(storage, location),
1002 dtype=dtype)
File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:175, in default_restore_location(storage, location)
173 def default_restore_location(storage, location):
174 for _, _, fn in _package_registry:
--> 175 result = fn(storage, location)
176 if result is not None:
177 return result
File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:152, in _cuda_deserialize(obj, location)
150 def _cuda_deserialize(obj, location):
151 if location.startswith('cuda'):
--> 152 device = validate_cuda_device(location)
153 if getattr(obj, "_torch_load_uninitialized", False):
154 with torch.cuda.device(device):
File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:136, in validate_cuda_device(location)
133 device = torch.cuda._utils._get_device_index(location, True)
135 if not torch.cuda.is_available():
--> 136 raise RuntimeError('Attempting to deserialize object on a CUDA '
137 'device but torch.cuda.is_available() is False. '
138 'If you are running on a CPU-only machine, '
139 'please use torch.load with map_location=torch.device(\'cpu\') '
140 'to map your storages to the CPU.')
141 device_count = torch.cuda.device_count()
142 if device >= device_count:
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
Metadata
Assignees
Labels
bug: Something isn't working