9999]
100100
101101
102- def find_bn_fusing_layer_pair (model ):
102+ def find_bn_fusing_layer_pair (model , custom_objects = {} ):
103103 """Finds layers that can be fused with the following batchnorm layers.
104104
105105 Args:
106106 model: input model
107+ custom_objects: Dict of model specific objects needed for cloning.
107108
108109 Returns:
109110 Dict that marks all the layer pairs that need to be fused.
110111
111112 Note: supports sequential and non-sequential model
112113 """
113114
114- fold_model = clone_model (model )
115+ fold_model = clone_model (model , custom_objects )
115116 (graph , _ ) = qgraph .GenerateGraphFromModel (
116117 fold_model , "quantized_bits(8, 0, 1)" , "quantized_bits(8, 0, 1)" )
117118
@@ -219,7 +220,7 @@ def apply_quantizer(quantizer, input_weight):
219220
220221
221222# Model utilities: before saving the weights, we want to apply the quantizers
222- def model_save_quantized_weights (model , filename = None ):
223+ def model_save_quantized_weights (model , filename = None , custom_objects = {} ):
223224 """Quantizes model for inference and save it.
224225
225226 Takes a model with weights, apply quantization function to weights and
@@ -241,17 +242,19 @@ def model_save_quantized_weights(model, filename=None):
241242 model: model with weights to be quantized.
242243 filename: if specified, we will save the hdf5 containing the quantized
243244 weights so that we can use them for inference later on.
245+ custom_objects: Dict of model specific objects needed to load/store.
244246
245247 Returns:
246248 dictionary containing layer name and quantized weights that can be used
247249 by a hardware generator.
248-
249250 """
250251
251252 saved_weights = {}
252253
253254 # Find the conv/dense layers followed by Batchnorm layers
254- (fusing_layer_pair_dict , bn_layers_to_skip ) = find_bn_fusing_layer_pair (model )
255+ (fusing_layer_pair_dict , bn_layers_to_skip ) = find_bn_fusing_layer_pair (
256+ model , custom_objects
257+ )
255258
256259 print ("... quantizing model" )
257260 for layer in model .layers :
0 commit comments