Skip to content

samgeo.hq_sam: SamGeo requires a CUDA device and fails on CPU-only machines #173

@oscarbau

Description

@oscarbau

Environment Information

  • samgeo version: 0.9.1
  • Python version: 3.9.16
  • Operating System: Windows

Description

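Running the default example "Generate object masks from input prompts with HQ-SAM" fails on a machine without CUDA: constructing `SamGeo` raises a `RuntimeError` while the model checkpoint is being deserialized.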

What I Did

# Minimal reproduction (samgeo 0.9.1, Python 3.9.16, Windows): constructing an
# HQ-SAM SamGeo where torch.cuda.is_available() is False raises the
# RuntimeError shown in the traceback below.
# NOTE: os, leafmap, tms_to_geotiff and `image` are not used by the failing call.
import os
import leafmap
from samgeo.hq_sam import SamGeo, tms_to_geotiff
image = "C:/geosam/image.tif"
# The error is raised inside segment_anything_hq's _build_sam, which calls
# torch.load(f) without map_location, so checkpoint tensors recorded with a
# CUDA location cannot be restored on a CPU-only machine. SamGeo's `device`
# argument is applied only after loading (self.sam.to(device=self.device)),
# so passing it does not avoid this error.
sam = SamGeo(
    model_type="vit_h",  # can be vit_h, vit_b, vit_l, vit_tiny
    automatic=False,
    sam_kwargs=None,
)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[51], line 1
----> 1 sam = SamGeo(
      2     model_type="vit_h",  # can be vit_h, vit_b, vit_l, vit_tiny
      3     automatic=False,
      4     sam_kwargs=None,
      5 )

File ~\miniconda3\envs\geo\lib\site-packages\samgeo\hq_sam.py:96, in SamGeo.__init__(self, model_type, automatic, device, checkpoint_dir, hq, sam_kwargs, **kwargs)
     93 self.logits = None
     95 # Build the SAM model
---> 96 self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
     97 self.sam.to(device=self.device)
     98 # Use optional arguments for fine-tuning the SAM model

File ~\miniconda3\envs\geo\lib\site-packages\segment_anything_hq\build_sam.py:15, in build_sam_vit_h(checkpoint)
     14 def build_sam_vit_h(checkpoint=None):
---> 15     return _build_sam(
     16         encoder_embed_dim=1280,
     17         encoder_depth=32,
     18         encoder_num_heads=16,
     19         encoder_global_attn_indexes=[7, 15, 23, 31],
     20         checkpoint=checkpoint,
     21     )

File ~\miniconda3\envs\geo\lib\site-packages\segment_anything_hq\build_sam.py:160, in _build_sam(encoder_embed_dim, encoder_depth, encoder_num_heads, encoder_global_attn_indexes, checkpoint)
    158 if checkpoint is not None:
    159     with open(checkpoint, "rb") as f:
--> 160         state_dict = torch.load(f)
    161     info = sam.load_state_dict(state_dict, strict=False)
    162     print(info)

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:712, in load(f, map_location, pickle_module, **pickle_load_args)
    710             opened_file.seek(orig_position)
    711             return torch.jit.load(opened_file)
--> 712         return _load(opened_zipfile, map_location, pickle_module, **pickle_load_args)
    713 return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:1049, in _load(zip_file, map_location, pickle_module, pickle_file, **pickle_load_args)
   1047 unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
   1048 unpickler.persistent_load = persistent_load
-> 1049 result = unpickler.load()
   1051 torch._utils._validate_loaded_sparse_tensors()
   1053 return result

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:1019, in _load.<locals>.persistent_load(saved_id)
   1017 if key not in loaded_storages:
   1018     nbytes = numel * torch._utils._element_size(dtype)
-> 1019     load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
   1021 return loaded_storages[key]

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:1001, in _load.<locals>.load_tensor(dtype, numel, key, location)
    997 storage = zip_file.get_storage_from_record(name, numel, torch._UntypedStorage).storage()._untyped()
    998 # TODO: Once we decide to break serialization FC, we can
    999 # stop wrapping with _TypedStorage
   1000 loaded_storages[key] = torch.storage._TypedStorage(
-> 1001     wrap_storage=restore_location(storage, location),
   1002     dtype=dtype)

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:175, in default_restore_location(storage, location)
    173 def default_restore_location(storage, location):
    174     for _, _, fn in _package_registry:
--> 175         result = fn(storage, location)
    176         if result is not None:
    177             return result

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:152, in _cuda_deserialize(obj, location)
    150 def _cuda_deserialize(obj, location):
    151     if location.startswith('cuda'):
--> 152         device = validate_cuda_device(location)
    153         if getattr(obj, "_torch_load_uninitialized", False):
    154             with torch.cuda.device(device):

File ~\miniconda3\envs\geo\lib\site-packages\torch\serialization.py:136, in validate_cuda_device(location)
    133 device = torch.cuda._utils._get_device_index(location, True)
    135 if not torch.cuda.is_available():
--> 136     raise RuntimeError('Attempting to deserialize object on a CUDA '
    137                        'device but torch.cuda.is_available() is False. '
    138                        'If you are running on a CPU-only machine, '
    139                        'please use torch.load with map_location=torch.device(\'cpu\') '
    140                        'to map your storages to the CPU.')
    141 device_count = torch.cuda.device_count()
    142 if device >= device_count:

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions