@@ -142,7 +142,7 @@ def __init__(self, manager, architecture):
142
142
self .validation_batch_size = 100
143
143
self .print_misclassified_test_images = False
144
144
self .bottleneck_dir = "/tmp/bottleneck"
145
- self .model_dir = "/tmp/imagenet "
145
+ self .model_dir = "./cnn_models/cache "
146
146
self .final_tensor_name = "final_result"
147
147
self .write_logs = False
148
148
@@ -160,7 +160,7 @@ def __init__(self, manager, architecture):
160
160
raise Exception ("Did not recognize architecture flag'" )
161
161
162
162
# Set up the pre-trained graph.
163
- self .maybe_download_and_extract (self .model_info ['data_url' ])
163
+ self .maybe_download_and_extract (self .model_info ['data_url' ], self . model_info [ 'model_dir_name' ] )
164
164
self .graph , self .bottleneck_tensor , self .resized_image_tensor = (
165
165
self .create_model_graph (self .model_info ))
166
166
@@ -517,7 +517,7 @@ def run_bottleneck_on_image(self, sess, image_data, image_data_tensor,
517
517
return bottleneck_values
518
518
519
519
520
- def maybe_download_and_extract (self , data_url ):
520
+ def maybe_download_and_extract (self , data_url , model_dir_name ):
521
521
"""Download and extract model tar file.
522
522
523
523
If the pretrained model we're using doesn't already exist, this function
@@ -526,7 +526,7 @@ def maybe_download_and_extract(self, data_url):
526
526
Args:
527
527
data_url: Web location of the tar file containing the pretrained model.
528
528
"""
529
- dest_directory = self .model_dir
529
+ dest_directory = os . path . join ( self .model_dir , model_dir_name )
530
530
if not os .path .exists (dest_directory ):
531
531
os .makedirs (dest_directory )
532
532
filename = data_url .split ('/' )[- 1 ]
@@ -538,11 +538,10 @@ def _progress(count, block_size, total_size):
538
538
(filename ,
539
539
float (count * block_size ) / float (total_size ) * 100.0 ))
540
540
sys .stdout .flush ()
541
-
542
541
filepath , _ = urllib .request .urlretrieve (data_url , filepath , _progress )
543
542
print ()
544
543
statinfo = os .stat (filepath )
545
- tf .logging .info ('Successfully downloaded' , filename , statinfo .st_size ,
544
+ tf .logging .info ('Successfully downloaded %s %d ' , filename , statinfo .st_size ,
546
545
'bytes.' )
547
546
tarfile .open (filepath , 'r:gz' ).extractall (dest_directory )
548
547
@@ -1084,47 +1083,64 @@ def create_model_info(self, architecture):
1084
1083
tf .logging .error ("Couldn't understand architecture name '%s'" ,
1085
1084
architecture )
1086
1085
return None
1087
- version_string = parts [1 ]
1086
+ v_string = parts [1 ]
1087
+ version_string = parts [2 ]
1088
1088
if (version_string != '1.0' and version_string != '0.75' and
1089
- version_string != '0.50' and version_string != '0.25' ):
1089
+ version_string != '0.50' and version_string != '0.5' and
1090
+ version_string != '0.35' and version_string != '0.25' ):
1090
1091
tf .logging .error (
1091
- """"The Mobilenet version should be '1.0', '0.75', '0.50', or '0.25',
1092
+ """"The Mobilenet version should be '1.0', '0.75', '0.50', '0.35', or '0.25',
1092
1093
but found '%s' for architecture '%s'""" ,
1093
1094
version_string , architecture )
1094
1095
return None
1095
- size_string = parts [2 ]
1096
+ size_string = parts [3 ]
1096
1097
if (size_string != '224' and size_string != '192' and
1097
1098
size_string != '160' and size_string != '128' ):
1098
1099
tf .logging .error (
1099
1100
"""The Mobilenet input size should be '224', '192', '160', or '128',
1100
1101
but found '%s' for architecture '%s'""" ,
1101
1102
size_string , architecture )
1102
1103
return None
1103
- if len (parts ) == 3 :
1104
+ if len (parts ) == 4 :
1104
1105
is_quantized = False
1105
1106
else :
1106
- if parts [3 ] != 'quantized' :
1107
+ if parts [4 ] != 'quantized' :
1107
1108
tf .logging .error (
1108
1109
"Couldn't understand architecture suffix '%s' for '%s'" , parts [3 ],
1109
1110
architecture )
1110
1111
return None
1111
1112
is_quantized = True
1112
- data_url = 'http://download.tensorflow.org/models/mobilenet_v1_'
1113
- data_url += version_string + '_' + size_string + '_frozen.tgz'
1114
- bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
1113
+ data_url = 'http://'
1114
+ model_file_name = None
1115
+ bottleneck_tensor_name = None
1116
+ if architecture .startswith ('mobilenet_v1' ):
1117
+ data_url += 'download.tensorflow.org/models/mobilenet_v1_'
1118
+ data_url += version_string + '_' + size_string + '_frozen.tgz'
1119
+ bottleneck_tensor_name = 'MobilenetV1/Predictions/Reshape:0'
1120
+ if is_quantized :
1121
+ model_base_name = 'quantized_graph.pb'
1122
+ else :
1123
+ model_base_name = 'frozen_graph.pb'
1124
+ model_dir_name = 'mobilenet_v1_'
1125
+ model_dir_name += version_string + '_' + size_string
1126
+ model_file_name = os .path .join (model_dir_name , model_base_name )
1127
+ model_dir_name = ''
1128
+ else :
1129
+ data_url += 'storage.googleapis.com/mobilenet_v2/checkpoints/mobilenet_v2_'
1130
+ data_url += version_string + '_' + size_string + '.tgz'
1131
+ bottleneck_tensor_name = 'MobilenetV2/Predictions/Reshape:0'
1132
+ model_dir_name = 'mobilenet_v2_'
1133
+ model_dir_name += version_string + '_' + size_string
1134
+ model_base_name = model_dir_name + '_frozen.pb'
1135
+ model_file_name = os .path .join (model_dir_name , model_base_name )
1115
1136
bottleneck_tensor_size = 1001
1116
1137
input_width = int (size_string )
1117
1138
input_height = int (size_string )
1118
1139
input_depth = 3
1119
1140
resized_input_tensor_name = 'input:0'
1120
- if is_quantized :
1121
- model_base_name = 'quantized_graph.pb'
1122
- else :
1123
- model_base_name = 'frozen_graph.pb'
1124
- model_dir_name = 'mobilenet_v1_' + version_string + '_' + size_string
1125
- model_file_name = os .path .join (model_dir_name , model_base_name )
1126
1141
input_mean = 127.5
1127
1142
input_std = 127.5
1143
+ print (data_url )
1128
1144
else :
1129
1145
tf .logging .error ("Couldn't understand architecture name '%s'" , architecture )
1130
1146
raise ValueError ('Unknown architecture' , architecture )
@@ -1138,6 +1154,7 @@ def create_model_info(self, architecture):
1138
1154
'input_depth' : input_depth ,
1139
1155
'resized_input_tensor_name' : resized_input_tensor_name ,
1140
1156
'model_file_name' : model_file_name ,
1157
+ 'model_dir_name' : model_dir_name ,
1141
1158
'input_mean' : input_mean ,
1142
1159
'input_std' : input_std ,
1143
1160
}
@@ -1170,11 +1187,3 @@ def add_jpeg_decoding(self, input_width, input_height, input_depth, input_mean,
1170
1187
mul_image = tf .multiply (offset_image , 1.0 / input_std )
1171
1188
return jpeg_data , mul_image
1172
1189
1173
-
1174
- if __name__ == '__main__' :
1175
- cnn_trainer = CNNTrainer ("mobilenet_0.50_128" )
1176
- cnn_trainer .retrain (
1177
- image_dir = "/home/pi/tensorflow/data/applekiwi" ,
1178
- output_graph = "./cnn_models/applewiki_0_5_128.pb" ,
1179
- training_steps = 10 ,
1180
- learning_rate = 0.1 )
0 commit comments