Skip to content

Commit c744447

Browse files
fix: Use safetensor models for PBR maps instead of pickles.
1 parent 5d691ea commit c744447

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

invokeai/backend/image_util/pbr_maps/pbr_maps.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,35 @@
22
# Adopted and optimized for Invoke AI
33

44
import pathlib
5-
from typing import Any, Literal, OrderedDict
5+
from typing import Any, Literal
66

77
import cv2
88
import numpy as np
99
import numpy.typing as npt
1010
import torch
1111
from PIL import Image
12-
from torch.serialization import add_safe_globals
12+
from safetensors.torch import load_file
1313

1414
from invokeai.backend.image_util.pbr_maps.architecture.pbr_rrdb_net import PBR_RRDB_Net
1515
from invokeai.backend.image_util.pbr_maps.utils.image_ops import crop_seamless, esrgan_launcher_split_merge
1616

17-
NORMAL_MAP_MODEL = "https://github.com/joeyballentine/Material-Map-Generator/blob/master/utils/models/1x_NormalMapGenerator-CX-Lite_200000_G.pth"
18-
OTHER_MAP_MODEL = "https://github.com/joeyballentine/Material-Map-Generator/blob/master/utils/models/1x_FrankenMapGenerator-CX-Lite_215000_G.pth"
17+
NORMAL_MAP_MODEL = (
18+
"https://huggingface.co/InvokeAI/pbr-material-maps/resolve/main/normal_map_generator.safetensors?download=true"
19+
)
20+
OTHER_MAP_MODEL = (
21+
"https://huggingface.co/InvokeAI/pbr-material-maps/resolve/main/franken_map_generator.safetensors?download=true"
22+
)
1923

2024

2125
class PBRMapsGenerator:
2226
def __init__(self, normal_map_model: PBR_RRDB_Net, other_map_model: PBR_RRDB_Net, device: torch.device) -> None:
2327
self.normal_map_model = normal_map_model
2428
self.other_map_model = other_map_model
2529
self.device = device
26-
add_safe_globals([PBR_RRDB_Net, OrderedDict])
2730

2831
@staticmethod
2932
def load_model(model_path: pathlib.Path, device: torch.device) -> PBR_RRDB_Net:
30-
state_dict = torch.load(model_path.as_posix(), map_location="cpu")
33+
state_dict = load_file(model_path.as_posix(), device=device.type)
3134

3235
model = PBR_RRDB_Net(
3336
3,

0 commit comments

Comments
 (0)