diff --git a/contrib/kl_layer/README.md b/contrib/kl_layer/README.md
new file mode 100644
index 0000000000..5d306ae69a
--- /dev/null
+++ b/contrib/kl_layer/README.md
@@ -0,0 +1,18 @@
+This folder contains the implementation of a custom KL divergence layer.
+It is a custom implementation, not a built-in layer of any deep learning framework.
+It was developed specifically for the [AD@L1 CMS paper](https://www.nature.com/articles/s42256-022-00441-3).
+
+# Files
+
+* `kl_layer.py`: contains the standalone implementation of the custom KL divergence layer
+* `kl_layer.h`: contains the HLS implementation of the KL layer
+
+
+# Usage
+
+`kl_layer.py` contains an example of how to use the KL layer.
+To run it, do
+
+```
+python kl_layer.py
+```
diff --git a/contrib/kl_layer/kl_layer.h b/contrib/kl_layer/kl_layer.h
new file mode 100644
index 0000000000..0435b9a22e
--- /dev/null
+++ b/contrib/kl_layer/kl_layer.h
@@ -0,0 +1,87 @@
+#ifndef KL_LAYER_H_
+#define KL_LAYER_H_
+
+#include "nnet_activation.h"
+#include "nnet_common.h"
+#include <cmath>
+#include <math.h>
+
+namespace nnet {
+
+struct distance_config {
+    // IO size
+    static const unsigned n_in = 10;
+    static const unsigned n_out = 1;
+
+    // Internal data type definitions
+    typedef float accum_t;
+    typedef float sum_t;
+    typedef ap_fixed<18, 8> exp_table_t;
+
+    // Internal info
+    static const unsigned table_size = 1024;
+    static constexpr unsigned exp_range = 8;
+};
+
+template <class CONFIG_T, int N_TABLE> void init_klloss_exp_table(typename CONFIG_T::exp_table_t table_out[N_TABLE]) {
+    for (int ii = 0; ii < N_TABLE; ii++) {
+        // First, convert from table index to X-value (range -exp_range to +exp_range)
+        float in_val = 2 * CONFIG_T::exp_range * (ii - float(N_TABLE) / 2.0) / float(N_TABLE);
+        // Next, compute lookup table function
+        typename CONFIG_T::exp_table_t real_val = exp_fcn_float(in_val);
+        // std::cout << "Lookup table In Value: " << in_val << " Result: " << real_val << " Index: " << ii << std::endl;
+        table_out[ii] = real_val;
+    }
+}
+
+template <class data1_T, class data2_T, class res_T, typename CONFIG_T>
+void klloss(data1_T mean[CONFIG_T::n_in], data2_T log_var[CONFIG_T::n_in], res_T res[CONFIG_T::n_out]) {
+    #pragma HLS PIPELINE
+    // Initialize the lookup table
+#ifdef __HLS_SYN__
+    bool initialized = false;
+    typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size];
+#else
+    static bool initialized = false;
+    static typename CONFIG_T::exp_table_t exp_table[CONFIG_T::table_size];
+#endif
+    if (!initialized) {
+        init_klloss_exp_table<CONFIG_T, CONFIG_T::table_size>(exp_table);
+        initialized = true;
+    }
+    typename CONFIG_T::accum_t kl[CONFIG_T::n_in];
+    #pragma HLS ARRAY_PARTITION variable=kl complete
+    typename CONFIG_T::accum_t mean_sq[CONFIG_T::n_in];
+    #pragma HLS ARRAY_PARTITION variable=mean_sq complete
+    typename CONFIG_T::accum_t kl_sum(0);
+    for (unsigned i = 0; i < CONFIG_T::n_in; i++) {
+        #pragma HLS UNROLL
+        mean_sq[i] = mean[i] * mean[i];
+        kl[i] = data2_T(1.) + log_var[i];
+        // std::cout << "Log var: " << log_var[i] << " Result: " << kl[i] << std::endl;
+    }
+    constexpr unsigned table_scale = (unsigned)(CONFIG_T::table_size / (2 * CONFIG_T::exp_range));
+    constexpr unsigned index_scale = (unsigned)(CONFIG_T::exp_range * table_scale);
+    for (unsigned i = 0; i < CONFIG_T::n_in; i++) {
+        #pragma HLS UNROLL
+        auto data_round = log_var[i] * table_scale;
+        auto index = data_round + index_scale;
+        if (index < 0)
+            index = 0;
+        if (index > CONFIG_T::table_size - 1)
+            index = CONFIG_T::table_size - 1;
+        kl[i] -= exp_table[index];
+        // std::cout << "Exp var: " << exp_table[index] << " Result: " << kl[i] << " Index: " << index << std::endl;
+    }
+    for (unsigned i = 0; i < CONFIG_T::n_in; i++) {
+        #pragma HLS UNROLL
+        kl[i] -= mean_sq[i];
+    }
+    Op_add<typename CONFIG_T::accum_t> op_add;
+    kl_sum = reduce<typename CONFIG_T::accum_t, CONFIG_T::n_in, Op_add<typename CONFIG_T::accum_t>>(kl, op_add);
+    // std::cout << "KL sum: " << kl_sum << std::endl;
+    kl_sum *= typename CONFIG_T::accum_t(1. / CONFIG_T::n_in);
+    res[0] = res_T(-0.5) * kl_sum;
+}
+} // namespace nnet
+
+#endif
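For reference, the quantity `klloss` computes is the standard closed-form KL divergence between a diagonal Gaussian parametrized by `mean` and `log_var` and a standard normal, averaged over the `n_in` latent dimensions, with `exp(log_var)` replaced by the lookup table. A summary of the arithmetic above, in the code's notation (no new functionality, just the formulas the loops implement):

```
\mathrm{kl}_i = 1 + \log\sigma_i^2 - \mu_i^2 - e^{\log\sigma_i^2}, \qquad
\mathrm{res} = -\frac{1}{2}\,\frac{1}{n_{\mathrm{in}}}\sum_{i=1}^{n_{\mathrm{in}}} \mathrm{kl}_i, \qquad
\mathrm{index}_i = \operatorname{clamp}\!\left(\log\sigma_i^2 \cdot \frac{\mathrm{table\_size}}{2\,\mathrm{exp\_range}} + \frac{\mathrm{table\_size}}{2},\; 0,\; \mathrm{table\_size}-1\right)
```

Note that `index_scale = exp_range * table_scale` simplifies to `table_size / 2`, i.e. the table is centered on `log_var = 0` and covers `[-exp_range, +exp_range]`.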
diff --git a/contrib/kl_layer/kl_layer.py b/contrib/kl_layer/kl_layer.py
new file mode 100644
index 0000000000..ec2af1b797
--- /dev/null
+++ b/contrib/kl_layer/kl_layer.py
@@ -0,0 +1,185 @@
+"""
+    Usage example for a custom KL loss layer.
+    Takes as input two arrays, z_mean and z_log_var,
+    and computes the KL "distance" between the standard normal distribution
+    and a Gaussian with mu=z_mean and log-variance z_log_var.
+
+    The HLS part is in contrib/kl_layer/kl_layer.h
+"""
+from pathlib import Path
+
+import numpy as np
+import tensorflow as tf
+
+try:
+    from keras.layers.merge import _Merge as Merge
+except Exception:
+    from keras.layers.merging.base_merge import _Merge as Merge
+
+from tensorflow.python.keras.utils import tf_utils
+from tensorflow.python.ops import math_ops
+
+import hls4ml
+from hls4ml.converters.keras_to_hls import parse_default_keras_layer
+from hls4ml.model.attributes import ConfigurableAttribute, TypeAttribute
+from hls4ml.model.types import FixedPrecisionType, RoundingMode, SaturationMode
+
+
+# Keras implementation of a KL layer
+class KLLoss(Merge):
+    '''Keras implementation of a KL loss custom layer'''
+
+    @tf_utils.shape_type_conversion
+    def build(self, input_shape):
+        super().build(input_shape)
+
+    def _merge_function(self, inputs):
+        mean = inputs[0]
+        log_var = inputs[1]
+
+        kl = 1.0 + log_var - math_ops.square(mean) - math_ops.exp(log_var)
+        kl = -0.5 * math_ops.reduce_mean(kl, axis=-1, keepdims=True)
+
+        return kl
+
+
+# hls4ml implementation
+class HKLLoss(hls4ml.model.layers.Layer):
+    '''hls4ml implementation of a KL loss custom layer'''
+
+    _expected_attributes = [
+        ConfigurableAttribute('table_size', default=1024),
+        ConfigurableAttribute('exp_range', default=8),
+        TypeAttribute('accum'),
+        TypeAttribute(
+            'sum',
+            default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
+        ),
+        TypeAttribute(
+            'exp_table',
+            default=FixedPrecisionType(18, 8, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT),
+        ),
+    ]
+
+    def initialize(self):
+        self.add_output_variable(shape=[1], dim_names=[f'KL_LOSS_{self.index}'])
+
+
+# Templates
+distance_config_template = """struct config{index} : nnet::distance_config {{
+    static const unsigned n_in = {n_in};
+    static const unsigned n_out = 1;
+    typedef {accum_t.name} accum_t;
+    typedef {sum_t.name} sum_t;
+    typedef {exp_table_t.name} exp_table_t;
+    static const unsigned table_size = {table_size};
+    static constexpr float exp_range = {exp_range};
+}};\n"""
+distance_function_template = 'nnet::klloss<{input1_t}, {input2_t}, {output_t}, {config}>({input1}, {input2}, {output});'
+distance_include_list = ['nnet_utils/kl_layer.h']
+
+
+class HKLLossConfigTemplate(hls4ml.backends.template.LayerConfigTemplate):
+    def __init__(self):
+        super().__init__(HKLLoss)
+        self.template = distance_config_template
+
+    def format(self, node):
+        params = self._default_config_params(node)
+        params['n_in'] = node.get_input_variable(node.inputs[0]).shape[0]
+        params['n_out'] = 1
+        return self.template.format(**params)
+
+
+class HKLLossFunctionTemplate(hls4ml.backends.template.FunctionCallTemplate):
+    def __init__(self):
+        super().__init__(HKLLoss, include_header=distance_include_list)
+        self.template = distance_function_template
+
+    def format(self, node):
+        params = {}
+        params['config'] = f'config{node.index}'
+        params['input1_t'] = node.get_input_variable(node.inputs[0]).type.name
+        params['input2_t'] = node.get_input_variable(node.inputs[1]).type.name
+        params['output_t'] = node.get_output_variable().type.name
+        params['input1'] = node.get_input_variable(node.inputs[0]).name
+        params['input2'] = node.get_input_variable(node.inputs[1]).name
+        params['output'] = node.get_output_variable().name
+
+        return self.template.format(**params)
+
+
+# Parser for the converter
+def parse_klloss_layer(keras_layer, input_names, input_shapes, data_reader):
+    assert 'KLLoss' in keras_layer['class_name']
+
+    layer = parse_default_keras_layer(keras_layer, input_names)
+
+    output_shape = [input_shapes[0][0], 1]
+
+    return layer, output_shape
+
+
+def main():
+    # Register the converter for the custom Keras layer
+    hls4ml.converters.register_keras_layer_handler('KLLoss', parse_klloss_layer)
+
+    # Register the layer in hls4ml's internal representation
+    hls4ml.model.layers.register_layer('KLLoss', HKLLoss)
+
+    # Get the backend and register the template passes for it
+    backend = hls4ml.backends.get_backend('Vivado')
+    backend.register_template(HKLLossConfigTemplate)
+    backend.register_template(HKLLossFunctionTemplate)
+
+    # Register the HLS implementation
+    p = Path(__file__).parent / 'kl_layer.h'
+    backend.register_source(p)
+
+    # Test that it works:
+    # create a dummy Keras model with the KL loss layer
+    inp = tf.keras.layers.Input(shape=(19, 3, 1))
+    z_mean = tf.keras.layers.Dense(10)(inp)
+    z_log_var = tf.keras.layers.Dense(10)(inp)
+    custom_output = KLLoss()([z_mean, z_log_var])
+    # create the new model
+    kmodel = tf.keras.models.Model(inputs=inp, outputs=custom_output)
+    kmodel.summary()
+
+    # test on random inputs
+    x = np.random.randint(-5, 5, (1, 19, 3, 1), dtype='int32')
+    kres = kmodel(x)
+
+    # Create a dummy config
+    config = {}
+    config['Model'] = {
+        'Precision': 'ap_fixed<16,6>',
+        'ReuseFactor': 1,
+        'ParallelizationFactor': 1,
+        'Strategy': 'Resource',
+    }
+    hmodel = hls4ml.converters.convert_from_keras_model(
+        kmodel,
+        output_dir='hls4mlprj_kl_layer',
+        backend='Vivado',
+        io_type='io_parallel',
+        part='xcvu9p-flga2577-2-e',
+        hls_config=config,
+    )
+
+    hmodel.compile()
+    hres = hmodel.predict(x.astype('float32'))
+
+    print('Comparing the hls4ml model prediction to the Keras one')
+    print(kres - hres)
+
+    print('Building model')
+    report = hmodel.build(reset=True, csim=False, cosim=True, synth=True, vsynth=True)
+    print(report)
+
+
+if __name__ == '__main__':
+    main()
diff --git a/hls4ml/backends/vivado/passes/core_templates.py b/hls4ml/backends/vivado/passes/core_templates.py
index f63c0f454d..74e556ddf4 100644
--- a/hls4ml/backends/vivado/passes/core_templates.py
+++ b/hls4ml/backends/vivado/passes/core_templates.py
@@ -1,6 +1,6 @@
 from hls4ml.backends.backend import get_backend
-from hls4ml.model.layers import Activation, BatchNormalization, Dense, Embedding, PReLU, ParametrizedActivation, Softmax
+from hls4ml.model.layers import Activation, BatchNormalization, Dense, DenseBatchnorm, Embedding, PReLU, ParametrizedActivation, Softmax
 from hls4ml.backends.template import LayerConfigTemplate, FunctionCallTemplate
 
 # Dense templates
 
@@ -28,7 +28,7 @@ class DenseConfigTemplate(LayerConfigTemplate):
     def __init__(self):
-        super().__init__(Dense)
+        super().__init__((Dense, DenseBatchnorm))
         self.template = dense_config_template
 
     def format(self, node):
@@ -41,7 +41,7 @@ def format(self, node):
 
 class DenseFunctionTemplate(FunctionCallTemplate):
     def __init__(self):
-        super().__init__(Dense, include_header=dense_include_list)
+        super().__init__((Dense, DenseBatchnorm), include_header=dense_include_list)
         self.template = dense_function_template
 
     def format(self, node):
diff --git a/hls4ml/backends/vivado/vivado_backend.py b/hls4ml/backends/vivado/vivado_backend.py
index b1b586f6c4..793a1d24be 100644
--- a/hls4ml/backends/vivado/vivado_backend.py
+++ b/hls4ml/backends/vivado/vivado_backend.py
@@ -189,13 +189,13 @@ def build(
         curr_dir = os.getcwd()
         os.chdir(model.config.get_output_dir())
         vivado_cmd = (
-            f'vivado_hls -f build_prj.tcl "reset={reset}'
-            f'csim={csim}'
-            f'synth={synth}'
-            f'cosim={cosim}'
-            f'validation={validation}'
-            f'export={export}'
-            f'vsynth={vsynth}'
+            f'vivado_hls -f build_prj.tcl "reset={reset} '
+            f'csim={csim} '
+            f'synth={synth} '
+            f'cosim={cosim} '
+            f'validation={validation} '
+            f'export={export} '
+            f'vsynth={vsynth} '
             f'fifo_opt={fifo_opt}"'
         )
         os.system(vivado_cmd)
diff --git a/hls4ml/converters/keras/core.py b/hls4ml/converters/keras/core.py
index 4411ae4c53..9f51a6a67d 100644
--- a/hls4ml/converters/keras/core.py
+++ b/hls4ml/converters/keras/core.py
@@ -113,7 +113,7 @@ def parse_activation_layer(keras_layer, input_names, input_shapes, data_reader):
 
 @keras_handler('BatchNormalization')
 def parse_batchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
-    assert 'BatchNormalization' in keras_layer['class_name'] or 'QConv2DBatchnorm' in keras_layer['class_name']
+    assert ('BatchNormalization' in keras_layer['class_name'] or 'QConv2DBatchnorm' in keras_layer['class_name'] or 'QDenseBatchnorm' in keras_layer['class_name'])
 
     layer = parse_default_keras_layer(keras_layer, input_names)
diff --git a/hls4ml/converters/keras/qkeras_layers.py b/hls4ml/converters/keras/qkeras_layers.py
index 5839ca542a..eba44c43ea 100644
--- a/hls4ml/converters/keras/qkeras_layers.py
+++ b/hls4ml/converters/keras/qkeras_layers.py
@@ -114,3 +114,13 @@ def parse_qconv2dbatchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
     temp_shape = intermediate_shape
     batch_layer, out_shape = parse_batchnorm_layer(keras_layer, input_names, temp_shape, data_reader)
     return {**conv_layer, **batch_layer}, out_shape
+
+@keras_handler('QDenseBatchnorm')
+def parse_qdensebatchnorm_layer(keras_layer, input_names, input_shapes, data_reader):
+    intermediate_shape = list()
+    dense_layer, shape_qdense = parse_qdense_layer(keras_layer, input_names, input_shapes, data_reader)
+    intermediate_shape.append(shape_qdense)
+    temp_shape = intermediate_shape
+    batch_layer, out_shape = parse_batchnorm_layer(keras_layer, input_names, temp_shape, data_reader)
+    batch_layer.pop('n_in')
+    return {**dense_layer, **batch_layer}, out_shape
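For context on how the new `QDenseBatchnorm` handler is exercised end to end, here is a minimal conversion sketch. It assumes a QKeras version that exports `QDenseBatchnorm` (the fused layer this handler parses); the layer sizes, quantizer settings, output directory, and FPGA part are illustrative placeholders rather than values from this PR.

```
import numpy as np
import tensorflow as tf
from qkeras import QDenseBatchnorm, quantized_bits  # assumes QDenseBatchnorm is available in your QKeras install

import hls4ml

# Toy model with a fused quantized dense + batchnorm layer
inp = tf.keras.layers.Input(shape=(16,))
out = QDenseBatchnorm(
    8,
    kernel_quantizer=quantized_bits(8, 0, alpha=1),
    bias_quantizer=quantized_bits(8, 0, alpha=1),
)(inp)
kmodel = tf.keras.models.Model(inputs=inp, outputs=out)

# Convert with hls4ml; the handler above folds the batchnorm statistics into the dense weights
config = hls4ml.utils.config_from_keras_model(kmodel, granularity='name')
hmodel = hls4ml.converters.convert_from_keras_model(
    kmodel,
    hls_config=config,
    output_dir='hls4mlprj_qdensebatchnorm',
    backend='Vivado',
    io_type='io_parallel',
)
hmodel.compile()

# Compare Keras and hls4ml predictions on random inputs
x = np.random.rand(10, 16).astype('float32')
print(np.max(np.abs(kmodel.predict(x) - hmodel.predict(x))))
```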
diff --git a/hls4ml/model/layers.py b/hls4ml/model/layers.py
index b8a3a1a4d9..1d3a40014c 100644
--- a/hls4ml/model/layers.py
+++ b/hls4ml/model/layers.py
@@ -394,6 +394,51 @@ def initialize(self):
         self.add_bias(quantizer=self.get_attr('bias_quantizer'))
 
 
+class DenseBatchnorm(Dense):
+    def _get_folded_weights(self):
+        """
+        Function to get the batchnorm-folded weights.
+        This function converts the weights by folding the batchnorm parameters into
+        the weights of QDense. The high-level equations are:
+        W_fold = gamma * W / sqrt(variance + epsilon)
+        bias_fold = gamma * (bias - moving_mean) / sqrt(variance + epsilon) + beta
+        """
+        kernel = self.model.get_weights_data(self.name, 'kernel')
+        bias = self.model.get_weights_data(self.name, 'bias')
+        if bias is None:
+            bias = 0
+
+        # get batchnorm weights and moving stats
+        gamma = self.model.get_weights_data(self.name, 'gamma')
+        beta = self.model.get_weights_data(self.name, 'beta')
+        moving_mean = self.model.get_weights_data(self.name, 'moving_mean')
+        moving_variance = self.model.get_weights_data(self.name, 'moving_variance')
+        # get the inversion factor so that we replace division by multiplication
+        inv = np.reciprocal(np.sqrt(moving_variance + self.get_attr('epsilon')))
+        if gamma is not None:
+            inv *= gamma
+
+        # fold the dense kernel and bias with the batchnorm parameters
+        folded_kernel = inv * kernel
+        folded_bias = inv * (bias - moving_mean) + beta
+
+        return [folded_kernel, folded_bias]
+
+    def initialize(self):
+        super(DenseBatchnorm, self).initialize()
+        folded_weights, folded_bias = self._get_folded_weights()
+        if self.model.config.is_resource_strategy(self) and self.model.config.backend.name in ['Vivado', 'VivadoAccelerator']:
+            self.weights['weight'].data_unquantized = np.transpose(folded_weights)
+            self.weights['weight'].data = self.get_attr('weight_quantizer')(self.weights['weight'].data_unquantized)
+        else:
+            self.weights['weight'].data_unquantized = folded_weights
+            self.weights['weight'].data = self.get_attr('weight_quantizer')(folded_weights)
+        self.weights['bias'].data_unquantized = folded_bias
+        bias_q = self.get_attr('bias_quantizer')
+        if bias_q is not None:
+            self.weights['bias'].data = bias_q(folded_bias)
+
+
 class Conv1D(Layer):
     _expected_attributes = [
         Attribute('in_width'),
@@ -1269,6 +1314,7 @@ def _initialize_transforms(self):
     'BinaryDense': Dense,
     'TernaryDense': Dense,
     'QDense': Dense,
+    'QDenseBatchnorm': DenseBatchnorm,
     'Conv1D': Conv1D,
     'QConv1D': Conv1D,
     'Conv2D': Conv2D,
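As a sanity check on the folding equations used by `DenseBatchnorm._get_folded_weights`, the short numpy sketch below (random stand-in weights and batchnorm statistics, hypothetical sizes) confirms that a dense layer followed by inference-mode batch normalization is equivalent to a single dense layer with the folded kernel and bias.

```
import numpy as np

rng = np.random.default_rng(0)
n_in, n_out, eps = 16, 8, 1e-3

# Dense parameters and batchnorm statistics (random stand-ins)
kernel = rng.normal(size=(n_in, n_out))
bias = rng.normal(size=n_out)
gamma = rng.uniform(0.5, 1.5, size=n_out)
beta = rng.normal(size=n_out)
moving_mean = rng.normal(size=n_out)
moving_variance = rng.uniform(0.1, 2.0, size=n_out)

# Folding, as in DenseBatchnorm._get_folded_weights
inv = np.reciprocal(np.sqrt(moving_variance + eps)) * gamma
folded_kernel = inv * kernel  # broadcasts over the output dimension
folded_bias = inv * (bias - moving_mean) + beta

# Reference: dense followed by inference-mode batchnorm
x = rng.normal(size=(4, n_in))
dense_out = x @ kernel + bias
bn_out = gamma * (dense_out - moving_mean) / np.sqrt(moving_variance + eps) + beta

# Folded: a single dense layer
folded_out = x @ folded_kernel + folded_bias
print(np.max(np.abs(bn_out - folded_out)))  # should be at machine precision
```

This equivalence is what lets the hls4ml layer keep only the folded kernel and bias and quantize them directly, with no separate batch normalization step at inference time.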