|
1 | 1 | from dataclasses import dataclass |
2 | 2 |
|
3 | | -from torchaudio._internal import load_state_dict_from_url |
| 3 | +import torch |
| 4 | +import torchaudio |
4 | 5 |
|
5 | 6 | from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective |
6 | 7 |
|
@@ -42,26 +43,16 @@ class SquimObjectiveBundle: |
42 | 43 | _path: str |
43 | 44 | _sample_rate: float |
44 | 45 |
|
45 | | - def _get_state_dict(self, dl_kwargs): |
46 | | - url = f"https://download.pytorch.org/torchaudio/models/{self._path}" |
47 | | - dl_kwargs = {} if dl_kwargs is None else dl_kwargs |
48 | | - state_dict = load_state_dict_from_url(url, **dl_kwargs) |
49 | | - return state_dict |
50 | | - |
51 | | - def get_model(self, *, dl_kwargs=None) -> SquimObjective: |
| 46 | + def get_model(self) -> SquimObjective: |
52 | 47 | """Construct the SquimObjective model, and load the pretrained weight. |
53 | 48 |
|
54 | | - The weight file is downloaded from the internet and cached with |
55 | | - :func:`torch.hub.load_state_dict_from_url` |
56 | | -
|
57 | | - Args: |
58 | | - dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. |
59 | | -
|
60 | 49 | Returns: |
61 | 50 | Variation of :py:class:`~torchaudio.models.SquimObjective`. |
62 | 51 | """ |
63 | 52 | model = squim_objective_base() |
64 | | - model.load_state_dict(self._get_state_dict(dl_kwargs)) |
| 53 | + path = torchaudio.utils.download_asset(f"models/{self._path}") |
| 54 | + state_dict = torch.load(path, weights_only=True) |
| 55 | + model.load_state_dict(state_dict) |
65 | 56 | model.eval() |
66 | 57 | return model |
67 | 58 |
|
@@ -128,26 +119,15 @@ class SquimSubjectiveBundle: |
128 | 119 | _path: str |
129 | 120 | _sample_rate: float |
130 | 121 |
|
131 | | - def _get_state_dict(self, dl_kwargs): |
132 | | - url = f"https://download.pytorch.org/torchaudio/models/{self._path}" |
133 | | - dl_kwargs = {} if dl_kwargs is None else dl_kwargs |
134 | | - state_dict = load_state_dict_from_url(url, **dl_kwargs) |
135 | | - return state_dict |
136 | | - |
137 | | - def get_model(self, *, dl_kwargs=None) -> SquimSubjective: |
| 122 | + def get_model(self) -> SquimSubjective: |
138 | 123 | """Construct the SquimSubjective model, and load the pretrained weight. |
139 | | -
|
140 | | - The weight file is downloaded from the internet and cached with |
141 | | - :func:`torch.hub.load_state_dict_from_url` |
142 | | -
|
143 | | - Args: |
144 | | - dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`. |
145 | | -
|
146 | 124 | Returns: |
147 | 125 | Variation of :py:class:`~torchaudio.models.SquimObjective`. |
148 | 126 | """ |
149 | 127 | model = squim_subjective_base() |
150 | | - model.load_state_dict(self._get_state_dict(dl_kwargs)) |
| 128 | + path = torchaudio.utils.download_asset(f"models/{self._path}") |
| 129 | + state_dict = torch.load(path, weights_only=True) |
| 130 | + model.load_state_dict(state_dict) |
151 | 131 | model.eval() |
152 | 132 | return model |
153 | 133 |
|
|
0 commit comments