File tree Expand file tree Collapse file tree 1 file changed +13
-3
lines changed
models/public/vitstr-small-patch16-224 Expand file tree Collapse file tree 1 file changed +13
-3
lines changed Original file line number Diff line number Diff line change 12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
14
15
- from torch import load
15
+ import torch
16
16
from modules .vitstr import vitstr_small_patch16_224
17
17
18
18
19
+ class Model (torch .nn .Module ):
20
+ def __init__ (self , model ):
21
+ super (Model , self ).__init__ ()
22
+ self .model = model
23
+
24
+ def forward (self , x ):
25
+ return self .model (x )
26
+
27
+
28
+
19
29
def create_model (weights ):
20
30
model = vitstr_small_patch16_224 (num_classes = 96 )
21
31
22
- checkpoint = load (weights , map_location = 'cpu' )
32
+ checkpoint = torch . load (weights , map_location = 'cpu' )
23
33
ckpt = {k .replace ('module.vitstr.' , '' ): v for k , v in checkpoint .items ()}
24
34
25
35
model .load_state_dict (ckpt )
26
36
27
- return model
37
+ return Model ( model )
You can’t perform that action at this time.
0 commit comments