@@ -313,35 +313,36 @@ def _setup_transforms(self, n_aggregators, n_filters, n_propagate):
313313 name = ('Fout%d' % it ),
314314 )
315315
316- # Check for correctness. This commented out because pre-commit showed it was unused.
317-
318- # if self._output_activation is None or self._output_activation == "linear":
319- # output_activation_transform = (QActivation("quantized_bits(%i, %i)"
320- # % (self._total_bits, self._int_bits)))
321- # else:
322- # output_activation_transform = QActivation(
323- # "quantized_%s(%i, %i)" % (self._output_activation, self._total_bits, self._int_bits)
324- # )
316+ if self ._output_activation is None or self ._output_activation == "linear" :
317+ output_activation_transform = QActivation ("quantized_bits(%i, %i)" % (self ._total_bits , self ._int_bits ))
318+ else :
319+ output_activation_transform = QActivation (
320+ "quantized_%s(%i, %i)" % (self ._output_activation , self ._total_bits , self ._int_bits )
321+ )
325322 else :
326323 input_feature_transform = NamedDense (p , name = ('FLR%d' % it ))
327324 output_feature_transform = NamedDense (f , name = ('Fout%d' % it ))
328- # output_activation_transform = keras.layers.Activation(self._output_activation)
325+ output_activation_transform = keras .layers .Activation (self ._output_activation )
329326
330327 aggregator_distance = NamedDense (a , name = ('S%d' % it ))
331328
332- self ._transform_layers .append ((input_feature_transform , aggregator_distance , output_feature_transform ))
329+ self ._transform_layers .append (
330+ (input_feature_transform , aggregator_distance , output_feature_transform , output_activation_transform )
331+ )
333332
334333 self ._sublayers = sum ((list (layers ) for layers in self ._transform_layers ), [])
335334
336335 def _build_transforms (self , data_shape ):
337- for in_transform , d_compute , out_transform in self ._transform_layers :
336+ for in_transform , d_compute , out_transform , act_transform in self ._transform_layers :
338337 in_transform .build (data_shape )
339338 d_compute .build (data_shape )
340339 if self ._simplified :
341- out_transform .build (data_shape [:2 ] + (d_compute .units * in_transform .units ,))
340+ act_transform . build ( out_transform .build (data_shape [:2 ] + (d_compute .units * in_transform .units ,) ))
342341 else :
343- out_transform .build (
344- data_shape [:2 ] + (data_shape [2 ] + d_compute .units * in_transform .units + d_compute .units ,)
342+ act_transform .build (
343+ out_transform .build (
344+ data_shape [:2 ] + (data_shape [2 ] + d_compute .units * in_transform .units + d_compute .units ,)
345+ )
345346 )
346347
347348 data_shape = data_shape [:2 ] + (out_transform .units ,)
0 commit comments