Skip to content

Commit e044a12

Browse files
authored
Merge branch 'main' into hls4ml-optimization-api-part-1
2 parents 7c2d128 + d36e226 commit e044a12

File tree

17 files changed

+251
-99
lines changed

17 files changed

+251
-99
lines changed

.gitlab-ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ generator:
77
stage: generate
88
image: python:3.8-alpine
99
tags:
10-
- docker
10+
- k8s-default
1111
before_script:
1212
- pip install pyyaml
1313
script:

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ exclude: (^hls4ml\/templates\/(vivado|quartus)\/(ap_types|ac_types)\/|^test/pyte
22

33
repos:
44
- repo: https://github.com/psf/black
5-
rev: 23.7.0
5+
rev: 23.9.1
66
hooks:
77
- id: black
88
language_version: python3
@@ -30,13 +30,13 @@ repos:
3030
args: ["--profile", "black", --line-length=125]
3131

3232
- repo: https://github.com/asottile/pyupgrade
33-
rev: v3.10.1
33+
rev: v3.14.0
3434
hooks:
3535
- id: pyupgrade
3636
args: ["--py36-plus"]
3737

3838
- repo: https://github.com/asottile/setup-cfg-fmt
39-
rev: v2.4.0
39+
rev: v2.5.0
4040
hooks:
4141
- id: setup-cfg-fmt
4242

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,3 +135,8 @@ binary/ternary networks:
135135
year = "2021"
136136
}
137137
```
138+
139+
# Acknowledgments
140+
If you benefited from participating in our community, we ask that you please acknowledge the Fast Machine Learning collaboration, and particular individuals who helped you, in any publications.
141+
Please use the following text for this acknowledgment:
142+
> We acknowledge the Fast Machine Learning collective as an open community of multi-domain experts and collaborators. This community and \<names of individuals\>, in particular, were important for the development of this project.

docs/reference.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,13 @@ binary/ternary networks:
8686
year = "2021"
8787
}
8888
89+
Acknowledgments
90+
===============
91+
If you benefited from participating in our community, we ask that you please acknowledge the Fast Machine Learning collaboration, and particular individuals who helped you, in any publications.
92+
Please use the following text for this acknowledgment:
93+
We acknowledge the Fast Machine Learning collective as an open community of multi-domain experts and collaborators. This community and \<names of individuals\>, in particular, were important for the development of this project.
94+
95+
8996
Contributors
9097
============
9198

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,12 @@ def __init__(self):
244244
}};\n"""
245245

246246
sepconv1d_function_template = (
247-
'nnet::separable_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {d}, {p}, {z}, {b});'
247+
'nnet::separable_conv_1d_{data_format}<{input_t}, {dw_output_t}, {output_t}, {config}>('
248+
'{input}, {output}, {d}, {p}, {z}, {b});'
248249
)
249250
sepconv2d_function_template = (
250-
'nnet::separable_conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {d}, {p}, {z}, {b});'
251+
'nnet::separable_conv_2d_{data_format}<{input_t}, {dw_output_t}, {output_t}, {config}>('
252+
'{input}, {output}, {d}, {p}, {z}, {b});'
251253
)
252254

253255
sepconv1d_include_list = ['nnet_utils/nnet_conv1d.h', 'nnet_utils/nnet_sepconv1d_stream.h']
@@ -273,14 +275,17 @@ def format(self, node):
273275

274276
# Depthwise config
275277
params = self._default_config_params(node)
278+
# Override bias and bias_t since these are zeros in depthwise step of SepConv1D
279+
params['bias'] = params['zero_bias']
280+
params['bias_t'] = params['zero_bias_t']
276281
params['n_filt'] = params['n_chan'] # In depthwise step n_chan == n_filt
277282
params['dilation'] = node.get_attr('dilation', 1)
278283
params['nzeros'] = node.get_weights('depthwise').nzeros
279284
params['index'] = str(node.index) + '_depthwise'
280285
params['weight_t'] = node.get_weights('depthwise').type
281286
params['fill_fn'] = 'FillConv1DBuffer'
282287

283-
if node.get_attr("unscaled"):
288+
if node.get_attr('unscaled'):
284289
params['scale_index_type'] = 'scale_index_unscaled'
285290
else:
286291
params['scale_index_type'] = 'scale_index_regular'
@@ -301,14 +306,11 @@ def format(self, node):
301306
depthwise_mult_config = self.depthwise_mult_template.format(**mult_params)
302307

303308
# Pointwise config
304-
params = self._default_config_params()
305-
input_shape = self.get_input_variable().shape
306-
if self.get_attr('data_format') == 'channels_last':
307-
params['in_width'] = '*'.join([str(k) for k in input_shape[:-1]])
308-
params['n_chan'] = input_shape[-1]
309+
params = self._default_config_params(node)
310+
if node.get_attr('data_format') == 'channels_last':
311+
params['in_width'] = node.get_output_variable().shape[0]
309312
else:
310-
params['in_width'] = '*'.join([str(k) for k in input_shape[1:]])
311-
params['n_chan'] = input_shape[0]
313+
params['in_width'] = node.get_output_variable().shape[1]
312314

313315
params['filt_width'] = 1
314316
params['stride_width'] = 1
@@ -320,7 +322,7 @@ def format(self, node):
320322
params['instructions'] = '0'
321323
params['fill_fn'] = 'FillConv1DBuffer'
322324

323-
if node.get_attr("unscaled"):
325+
if node.get_attr('unscaled'):
324326
params['scale_index_type'] = 'scale_index_unscaled'
325327
else:
326328
params['scale_index_type'] = 'scale_index_regular'
@@ -360,6 +362,7 @@ def __init__(self):
360362

361363
def format(self, node):
362364
params = self._default_function_params(node)
365+
params['dw_output_t'] = node.get_attr('dw_output_t').name
363366
params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl'
364367
params['d'] = node.get_weights('depthwise').name
365368
params['p'] = node.get_weights('pointwise').name
@@ -398,12 +401,12 @@ def format(self, node):
398401
params['weight_t'] = node.get_weights('depthwise').type
399402
params['fill_fn'] = 'FillConv2DBuffer'
400403

401-
if node.get_attr("unscaled_h"):
404+
if node.get_attr('unscaled_h'):
402405
params['scale_index_height_type'] = 'scale_index_unscaled'
403406
else:
404407
params['scale_index_height_type'] = 'scale_index_regular'
405408

406-
if node.get_attr("unscaled_w"):
409+
if node.get_attr('unscaled_w'):
407410
params['scale_index_width_type'] = 'scale_index_unscaled'
408411
else:
409412
params['scale_index_width_type'] = 'scale_index_regular'
@@ -443,12 +446,12 @@ def format(self, node):
443446
params['instructions'] = '0'
444447
params['fill_fn'] = 'FillConv2DBuffer'
445448

446-
if node.get_attr("unscaled_h"):
449+
if node.get_attr('unscaled_h'):
447450
params['scale_index_height_type'] = 'scale_index_unscaled'
448451
else:
449452
params['scale_index_height_type'] = 'scale_index_regular'
450453

451-
if node.get_attr("unscaled_w"):
454+
if node.get_attr('unscaled_w'):
452455
params['scale_index_width_type'] = 'scale_index_unscaled'
453456
else:
454457
params['scale_index_width_type'] = 'scale_index_regular'
@@ -487,6 +490,7 @@ def __init__(self):
487490

488491
def format(self, node):
489492
params = self._default_function_params(node)
493+
params['dw_output_t'] = node.get_attr('dw_output_t').name
490494
params['data_format'] = 'cf' if node.get_attr('data_format') == 'channels_first' else 'cl'
491495
params['d'] = node.get_weights('depthwise').name
492496
params['p'] = node.get_weights('pointwise').name

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
Softmax,
2929
)
3030
from hls4ml.model.optimizer import get_backend_passes, layer_optimizer
31-
from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType
31+
from hls4ml.model.types import FixedPrecisionType, IntegerPrecisionType, NamedType, PackedType
3232
from hls4ml.report import parse_vivado_report
3333
from hls4ml.utils.fixed_point_utils import ceil_log2
3434

@@ -75,6 +75,12 @@ def _register_layer_attributes(self):
7575
attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer'))
7676
self.attribute_map[layer] = attrs
7777

78+
sep_conv_layers = [SeparableConv1D, SeparableConv2D]
79+
for layer in sep_conv_layers:
80+
attrs = self.attribute_map.get(layer, [])
81+
attrs.append(TypeAttribute('dw_output', default=FixedPrecisionType(18, 8)))
82+
self.attribute_map[layer] = attrs
83+
7884
def _register_flows(self):
7985
initializers = self._get_layer_initializers()
8086
init_flow = register_flow('init_layers', initializers, requires=['optimize'], backend=self.name)
@@ -288,6 +294,15 @@ def init_sepconv1d(self, layer):
288294
) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly
289295
layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())
290296

297+
# Set the output type of the depthwise phase
298+
dw_out_precision, _ = layer.model.config.get_precision(layer, 'dw_output')
299+
dw_out_name = layer.name + '_dw_out_t'
300+
if layer.model.config.get_config_value('IOType') == 'io_stream':
301+
dw_output_t = PackedType(dw_out_name, dw_out_precision, layer.get_attr('n_chan'), n_pack=1)
302+
else:
303+
dw_output_t = NamedType(dw_out_name, dw_out_precision)
304+
layer.set_attr('dw_output_t', dw_output_t)
305+
291306
@layer_optimizer(Conv2D)
292307
def init_conv2d(self, layer):
293308
if len(layer.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D
@@ -334,6 +349,15 @@ def init_sepconv2d(self, layer):
334349
) # TODO Once we have SeparableConv implementation for io_parallel this should be set properly
335350
layer.set_attr('implementation', layer.model.config.get_conv_implementation(layer).lower())
336351

352+
# Set the output type of the depthwise phase
353+
dw_out_precision, _ = layer.model.config.get_precision(layer, 'dw_output')
354+
dw_out_name = layer.name + '_dw_out_t'
355+
if layer.model.config.get_config_value('IOType') == 'io_stream':
356+
dw_output_t = PackedType(dw_out_name, dw_out_precision, layer.get_attr('n_chan'), n_pack=1)
357+
else:
358+
dw_output_t = NamedType(dw_out_name, dw_out_precision)
359+
layer.set_attr('dw_output_t', dw_output_t)
360+
337361
@layer_optimizer(DepthwiseConv2D)
338362
def init_depconv2d(self, layer):
339363
if layer.model.config.is_resource_strategy(layer):

hls4ml/converters/keras/qkeras.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,27 @@ def parse_qdepthwiseqconv_layer(keras_layer, input_names, input_shapes, data_rea
5454
layer, output_shape = parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader)
5555

5656
layer['depthwise_quantizer'] = get_quantizer_from_config(keras_layer, 'depthwise')
57+
58+
if keras_layer['config']['bias_quantizer'] is not None:
59+
layer['bias_quantizer'] = get_quantizer_from_config(keras_layer, 'bias')
60+
else:
61+
layer['bias_quantizer'] = None
62+
63+
return layer, output_shape
64+
65+
66+
@keras_handler('QSeparableConv1D', 'QSeparableConv2D')
67+
def parse_qsepconv_layer(keras_layer, input_names, input_shapes, data_reader):
68+
assert 'QSeparableConv' in keras_layer['class_name']
69+
70+
if '1D' in keras_layer['class_name']:
71+
layer, output_shape = parse_conv1d_layer(keras_layer, input_names, input_shapes, data_reader)
72+
elif '2D' in keras_layer['class_name']:
73+
layer, output_shape = parse_conv2d_layer(keras_layer, input_names, input_shapes, data_reader)
74+
75+
layer['depthwise_quantizer'] = get_quantizer_from_config(keras_layer, 'depthwise')
76+
layer['pointwise_quantizer'] = get_quantizer_from_config(keras_layer, 'pointwise')
77+
5778
if keras_layer['config']['bias_quantizer'] is not None:
5879
layer['bias_quantizer'] = get_quantizer_from_config(keras_layer, 'bias')
5980
else:

hls4ml/model/layers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,6 @@ class Layer:
5656
ConfigurableAttribute('trace', default=False),
5757
TypeAttribute('result'),
5858
]
59-
""""""
6059

6160
@classproperty
6261
def expected_attributes(cls):
@@ -1343,8 +1342,10 @@ def initialize(self):
13431342
'QConv2D': Conv2D,
13441343
'QConv2DBatchnorm': Conv2DBatchnorm,
13451344
'SeparableConv1D': SeparableConv1D,
1345+
'QSeparableConv1D': SeparableConv1D,
13461346
'DepthwiseConv1D': DepthwiseConv1D,
13471347
'SeparableConv2D': SeparableConv2D,
1348+
'QSeparableConv2D': SeparableConv2D,
13481349
'DepthwiseConv2D': DepthwiseConv2D,
13491350
'QDepthwiseConv2D': DepthwiseConv2D,
13501351
'BatchNormalization': BatchNormalization,

hls4ml/model/profiling.py

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -343,21 +343,22 @@ def activations_keras(model, X, fmt='longform', plot='boxplot'):
343343
# return summary statistics for matplotlib.axes.Axes.bxp
344344
# or histogram bin edges and heights
345345
data = []
346-
347-
for layer in model.layers:
348-
print(f" {layer.name}")
349-
if not isinstance(layer, keras.layers.InputLayer):
350-
y = _get_output(layer, X, model.input).flatten()
351-
y = abs(y[y != 0])
352-
if len(y) == 0:
353-
print(f'Activations for {layer.name} are only zeros, ignoring.')
354-
continue
355-
if fmt == 'longform':
356-
data['x'].extend(y.tolist())
357-
data['weight'].extend([layer.name for i in range(len(y))])
358-
elif fmt == 'summary':
359-
data.append(array_to_summary(y, fmt=plot))
360-
data[-1]['weight'] = layer.name
346+
outputs = _get_outputs(
347+
[layer for layer in model.layers if not isinstance(layer, keras.layers.InputLayer)], X, model.input
348+
)
349+
for layer_name, y in outputs.items():
350+
print(f" {layer_name}")
351+
y = y.flatten()
352+
y = abs(y[y != 0])
353+
if len(y) == 0:
354+
print(f'Activations for {layer_name} are only zeros, ignoring.')
355+
continue
356+
if fmt == 'longform':
357+
data['x'].extend(y.tolist())
358+
data['weight'].extend([layer_name for i in range(len(y))])
359+
elif fmt == 'summary':
360+
data.append(array_to_summary(y, fmt=plot))
361+
data[-1]['weight'] = layer_name
361362

362363
if fmt == 'longform':
363364
data = pandas.DataFrame(data)
@@ -544,10 +545,10 @@ def _is_ignored_layer(layer):
544545
return False
545546

546547

547-
def _get_output(layer, X, model_input):
548-
"""Get output of partial model"""
549-
partial_model = keras.models.Model(inputs=model_input, outputs=layer.output)
550-
y = partial_model.predict(X)
548+
def _get_outputs(layers, X, model_input):
549+
"""Get outputs of intermediate layers"""
550+
partial_models = keras.models.Model(inputs=model_input, outputs=[layer.output for layer in layers])
551+
y = partial_models.predict(X)
551552
return y
552553

553554

@@ -562,37 +563,30 @@ def get_ymodel_keras(keras_model, X):
562563
Returns:
563564
dict: A dictionary in the form {"layer_name": ouput array of layer}.
564565
"""
565-
566566
ymodel = {}
567-
567+
traced_layers = []
568+
layer_names = []
568569
for layer in keras_model.layers:
569-
print(f"Processing {layer.name} in Keras model...")
570-
if not _is_ignored_layer(layer):
571-
# If the layer has activation integrated then separate them
572-
# Note that if the layer is a standalone activation layer then skip this
573-
if hasattr(layer, 'activation') and not (
574-
isinstance(layer, keras.layers.Activation) or isinstance(layer, qkeras.qlayers.QActivation)
575-
):
576-
if layer.activation:
577-
if layer.activation.__class__.__name__ == "linear":
578-
ymodel[layer.name] = _get_output(layer, X, keras_model.input)
579-
580-
else:
581-
temp_activation = layer.activation
582-
layer.activation = None
583-
# Get output for layer without activation
584-
ymodel[layer.name] = _get_output(layer, X, keras_model.input)
585-
586-
# Add the activation back
587-
layer.activation = temp_activation
588-
# Get ouput for activation
589-
ymodel[layer.name + f"_{temp_activation.__class__.__name__}"] = _get_output(
590-
layer, X, keras_model.input
591-
)
592-
else:
593-
ymodel[layer.name] = _get_output(layer, X, keras_model.input)
594-
else:
595-
ymodel[layer.name] = _get_output(layer, X, keras_model.input)
570+
if _is_ignored_layer(layer):
571+
continue
572+
# If the layer has activation integrated then separate them
573+
# Note that if the layer is a standalone activation layer then skip this
574+
name = layer.name
575+
if (
576+
hasattr(layer, "activation")
577+
and layer.activation.__name__ != "linear"
578+
and not isinstance(layer, (keras.layers.Activation, qkeras.qlayers.QActivation))
579+
):
580+
tmp_activation = layer.activation
581+
layer.activation = None
582+
ymodel.update({layer.name: _get_outputs([layer], X, keras_model.input)})
583+
layer.activation = tmp_activation
584+
name = layer.name + f"_{tmp_activation.__name__}"
585+
traced_layers.append(layer)
586+
layer_names.append(name)
587+
outputs = _get_outputs(traced_layers, X, keras_model.input)
588+
for name, output in zip(layer_names, outputs):
589+
ymodel[name] = output
596590
print("Done taking outputs for Keras model.")
597591
return ymodel
598592

hls4ml/model/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def update_precision(self, new_precision):
564564
# to right of decimal point
565565
lsb = 2**-new_precision.fractional
566566
decimal_spaces = len(str(lsb).split('.')[1])
567-
self.precision_fmt = f'{{:{decimal_spaces}f}}'
567+
self.precision_fmt = f'{{:.{decimal_spaces}f}}'
568568
else:
569569
self.precision_fmt = '{:.0f}'
570570
else:

0 commit comments

Comments
 (0)