diff --git a/birefnet.py b/birefnet.py index 30ff9f5..dab209d 100644 --- a/birefnet.py +++ b/birefnet.py @@ -12,7 +12,7 @@ config = Config() -device = "cuda" if torch.cuda.is_available() else "cpu" +device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" folder_paths.folder_names_and_paths["BiRefNet"] = ([os.path.join(folder_paths.models_dir, "BiRefNet")], folder_paths.supported_pt_extensions) @@ -93,8 +93,7 @@ def remove_background(self, birefnetmodel, image): im_tensor = torch.unsqueeze(im_tensor,0) im_tensor = torch.divide(im_tensor,255.0) im_tensor = normalize(im_tensor,[0.5,0.5,0.5],[1.0,1.0,1.0]) - if torch.cuda.is_available(): - im_tensor=im_tensor.cuda() + im_tensor = im_tensor.to(device) result = birefnetmodel(im_tensor)[-1].sigmoid() #print(result.shape) diff --git a/models/backbones/build_backbone.py b/models/backbones/build_backbone.py index 983f04c..5b4ea96 100644 --- a/models/backbones/build_backbone.py +++ b/models/backbones/build_backbone.py @@ -8,6 +8,7 @@ config = Config() +device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" def build_backbone(bb_name, pretrained=True, params_settings=''): if bb_name == 'vgg16': @@ -26,7 +27,7 @@ def build_backbone(bb_name, pretrained=True, params_settings=''): return bb def load_weights(model, model_name): - save_model = torch.load(config.weights[model_name]) + save_model = torch.load(config.weights[model_name], map_location=device) model_dict = model.state_dict() state_dict = {k: v if v.size() == model_dict[k].size() else model_dict[k] for k, v in save_model.items() if k in model_dict.keys()} # to ignore the weights with mismatched size when I modify the backbone itself.