1
1
"""
2
2
Copyright (c) 2018-2021 Intel Corporation
3
-
4
3
Licensed under the Apache License, Version 2.0 (the "License");
5
4
you may not use this file except in compliance with the License.
6
5
You may obtain a copy of the License at
7
-
8
6
http://www.apache.org/licenses/LICENSE-2.0
9
-
10
7
Unless required by applicable law or agreed to in writing, software
11
8
distributed under the License is distributed on an "AS IS" BASIS,
12
9
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
10
See the License for the specific language governing permissions and
14
11
limitations under the License.
15
12
"""
16
13
import math
14
+ import tempfile
15
+ from pathlib import Path
17
16
18
17
import cv2
19
18
import numpy as np
@@ -407,8 +406,8 @@ def configure(self):
407
406
self .color_scale = 255 if not self .normalized_images else 1
408
407
if isinstance (lpips , UnsupportedPackage ):
409
408
lpips .raise_error (self .__provider__ )
410
- self .loss = lpips .LPIPS (net = self .get_value_from_config ('net' ))
411
409
self .dist_threshold = self .get_value_from_config ('distance_threshold' )
410
+ self .loss = self ._create_loss ()
412
411
413
412
def lpips_differ (self , annotation_image , prediction_image ):
414
413
if self .color_order == 'BGR' :
@@ -425,3 +424,37 @@ def evaluate(self, annotations, predictions):
425
424
self .meta ['names' ].append ('ratio_greater_{}' .format (self .dist_threshold ))
426
425
results += (invalid_ratio , )
427
426
return results
427
+
428
+ def _create_loss (self ):
429
+ import torch # pylint: disable=C0415
430
+ import torchvision # pylint: disable=C0415
431
+ net = self .get_value_from_config ('net' )
432
+ model_weights = {
433
+ 'alex' : ('https://download.pytorch.org/models/alexnet-owt-7be5be79.pth' ,
434
+ 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth'
435
+ ),
436
+ 'squeeze' : 'https://download.pytorch.org/models/squeezenet1_1-b8a52dc0.pth' ,
437
+ 'vgg' : 'https://download.pytorch.org/models/vgg16-397923af.pth'
438
+ }
439
+ model_classes = {
440
+ 'alex' : torchvision .models .alexnet ,
441
+ 'squeeze' : torchvision .models .squeezenet1_1 ,
442
+ 'vgg' : torchvision .models .vgg16
443
+ }
444
+ with tempfile .TemporaryDirectory (prefix = 'lpips_model' , dir = Path .cwd ()) as model_dir :
445
+ weights = model_weights [net ]
446
+ if isinstance (weights , tuple ):
447
+ weights = weights [1 ] if torch .__version__ <= '1.6.0' else weights [0 ]
448
+ preloaded_weights = torch .utils .model_zoo .load_url (
449
+ weights , model_dir = model_dir , progress = False , map_location = 'cpu'
450
+ )
451
+ model = model_classes [net ](pretrained = False )
452
+ model .load_state_dict (preloaded_weights )
453
+ feats = model .features
454
+ loss = lpips .LPIPS (pnet_rand = True )
455
+ for slice_id in range (1 , loss .net .N_slices + 1 ):
456
+ sl = getattr (loss .net , 'slice{}' .format (slice_id ))
457
+ for module_id in sl ._modules : # pylint: disable=W0212
458
+ sl ._modules [module_id ] = feats [int (module_id )] # pylint: disable=W0212
459
+ setattr (loss .net , 'slice{}' .format (slice_id ), sl )
460
+ return loss
0 commit comments