Skip to content

Commit d56dc73

Browse files
committed
vladimir comments
1 parent 9e3fc8d commit d56dc73

File tree

8 files changed

+114
-110
lines changed

8 files changed

+114
-110
lines changed

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 15 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,8 @@
6060
typedef {config_t} mult_config;
6161
template<unsigned K, unsigned S, unsigned W>
6262
using scale_index = nnet::{scale_index_type}<K, S, W>;
63+
template<class data_T, class res_T, class CONFIG_T>
64+
using conv_kernel = nnet::{conv_fn}<data_T, res_T, CONFIG_T>;
6365
}};
6466
const ap_uint<config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""
6567

@@ -93,16 +95,24 @@ def format(self, node):
9395
else:
9496
params['fill_fn'] = 'FillConv1DBuffer'
9597

96-
if node.get_attr('filt_width') == 1 and node.model.config.get_config_value('IOType') == 'io_parallel':
97-
params['pointwise_fn'] = f'pointwise_conv_{node.index}'
98+
is_pointwise_parallel_latency = node.get_attr('filt_width') == 1 and node.get_attr('strategy').lower() == 'latency' and node.model.config.get_config_value('IOType') == 'io_parallel'
99+
if is_pointwise_parallel_latency:
100+
params['conv_fn'] = f'pointwise_conv_{node.index}'
98101
else:
99-
params['pointwise_fn'] = 'PointwiseConv1D'
102+
if node.get_attr('strategy').lower() == 'latency':
103+
params['conv_fn'] = 'Conv1DLatency'
104+
elif node.get_attr('strategy').lower() == 'resource':
105+
params['conv_fn'] = 'Conv1DResource'
100106

101107
conv_config = self.template.format(**params)
102108

103109
mult_params = self._default_config_params(node)
104-
mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
105-
mult_params['n_out'] = node.get_attr('n_filt')
110+
if is_pointwise_parallel_latency:
111+
mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') / mult_params['reuse']
112+
mult_params['n_out'] = node.get_attr('n_filt') / mult_params['reuse']
113+
else:
114+
mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
115+
mult_params['n_out'] = node.get_attr('n_filt')
106116
mult_params['nzeros'] = node.get_weights('weight').nzeros
107117
mult_params['product_type'] = get_backend('vivado').product_type(
108118
node.get_input_variable().type.precision, node.get_weights('weight').type.precision

hls4ml/backends/vivado/passes/pointwise.py

Lines changed: 2 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -4,45 +4,13 @@
44
Conv1DFunctionTemplate,
55
Conv2DConfigTemplate,
66
Conv2DFunctionTemplate,
7+
conv1d_config_template,
78
conv2d_config_template,
89
conv_mult_config_template,
910
)
1011
from hls4ml.model.layers import register_layer
1112
from hls4ml.model.optimizer import OptimizerPass
1213

13-
pointwise_conv1d_config_template = """struct config{index} : nnet::conv1d_config {{
14-
static const unsigned pad_left = {pad_left};
15-
static const unsigned pad_right = {pad_right};
16-
static const unsigned in_width = {in_width};
17-
static const unsigned n_chan = {n_chan};
18-
static const unsigned filt_width = {filt_width};
19-
static const unsigned kernel_size = filt_width;
20-
static const unsigned n_filt = {n_filt};
21-
static const unsigned stride_width = {stride_width};
22-
static const unsigned dilation = {dilation};
23-
static const unsigned out_width = {out_width};
24-
static const unsigned reuse_factor = {reuse};
25-
static const unsigned n_zeros = {nzeros};
26-
static const bool store_weights_in_bram = false;
27-
static const unsigned strategy = nnet::{strategy};
28-
static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
29-
static const unsigned min_width = {min_width};
30-
static const ap_uint<filt_width> pixels[min_width];
31-
static const unsigned n_partitions = {n_partitions};
32-
static const unsigned n_pixels = out_width / n_partitions;
33-
template<class data_T, class CONFIG_T>
34-
using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
35-
typedef {accum_t.name} accum_t;
36-
typedef {bias_t.name} bias_t;
37-
typedef {weight_t.name} weight_t;
38-
typedef {config_t} mult_config;
39-
template<unsigned K, unsigned S, unsigned W>
40-
using scale_index = nnet::{scale_index_type}<K, S, W>;
41-
template<class data_T, class res_T, class CONFIG_T>
42-
using pointwise_conv = nnet::{pointwise_fn}<data_T, res_T, CONFIG_T>;
43-
}};
44-
const ap_uint<config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""
45-
4614
pointwise_conv1d_function_template = (
4715
'nnet::pointwise_conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
4816
)
@@ -57,7 +25,7 @@
5725
class PointwiseConv1DConfigTemplate(Conv1DConfigTemplate):
5826
def __init__(self):
5927
super(Conv1DConfigTemplate, self).__init__(PointwiseConv1D)
60-
self.template = pointwise_conv1d_config_template
28+
self.template = conv1d_config_template
6129
self.mult_template = conv_mult_config_template
6230

6331

hls4ml/backends/vivado/passes/pointwise_codegen.py

Lines changed: 37 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -15,48 +15,48 @@ def generate_pointwise_conv1d_fn(layer_idx, reuse_factor=1):
1515
"""
1616

1717
generated_code = (
18-
"template<class data_T, class res_T, typename CONFIG_T>\n"
19-
"class pointwise_conv_{index} : public PointwiseConv1D<data_T, res_T, CONFIG_T> {{\n"
20-
" public:\n"
21-
" static void pointwise_conv(\n"
22-
" data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n"
23-
" res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],\n"
24-
" typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],\n"
25-
" typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {{\n"
26-
" data_T data_tmp[CONFIG_T::reuse_factor][CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor];\n" # noqa: E501
27-
" #pragma HLS ARRAY_PARTITION variable=data_tmp complete dim=0\n"
28-
" res_T res_tmp[CONFIG_T::reuse_factor][CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor];\n" # noqa: E501
29-
" #pragma HLS ARRAY_PARTITION variable=res_tmp complete dim=0\n\n"
30-
" RFInputLoop:\n"
31-
" for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {{\n"
32-
" #pragma HLS UNROLL\n"
33-
" InnerInputLoop:\n"
34-
" for (int ii = 0; ii < CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor; ii++) {{\n"
35-
" #pragma HLS UNROLL\n"
36-
" data_tmp[jj][ii] = data[jj * CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor + ii];\n" # noqa: E501
37-
" }}\n"
38-
" }}\n\n"
18+
'template<class data_T, class res_T, typename CONFIG_T>\n'
19+
'class pointwise_conv_{index} : public Conv1DKernel<data_T, res_T, CONFIG_T> {{\n'
20+
' public:\n'
21+
' static void conv(\n'
22+
' data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n'
23+
' res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],\n'
24+
' typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],\n'
25+
' typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {{\n'
26+
' data_T data_tmp[CONFIG_T::reuse_factor][CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor];\n' # noqa: E501
27+
' #pragma HLS ARRAY_PARTITION variable=data_tmp complete dim=0\n'
28+
' res_T res_tmp[CONFIG_T::reuse_factor][CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor];\n' # noqa: E501
29+
' #pragma HLS ARRAY_PARTITION variable=res_tmp complete dim=0\n\n'
30+
' RFInputLoop:\n'
31+
' for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {{\n'
32+
' #pragma HLS UNROLL\n'
33+
' InnerInputLoop:\n'
34+
' for (int ii = 0; ii < CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor; ii++) {{\n'
35+
' #pragma HLS UNROLL\n'
36+
' data_tmp[jj][ii] = data[jj * CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor + ii];\n' # noqa: E501
37+
' }}\n'
38+
' }}\n\n'
3939
).format(index=layer_idx)
40-
indent = " "
40+
indent = ' '
4141
for i in range(reuse_factor):
4242
generated_code += indent
4343
generated_code += (
44-
f"pointwise_conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data_tmp[{i}], res_tmp[{i}], weights, biases);\n"
44+
f'pointwise_conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data_tmp[{i}], res_tmp[{i}], weights, biases);\n'
4545
)
4646

4747
generated_code += (
48-
"\n"
49-
" RFOutputLoop:\n"
50-
" for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {\n"
51-
" #pragma HLS UNROLL\n"
52-
" InnerOutputLoop:\n"
53-
" for (int ii = 0; ii < CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor; ii++) {\n"
54-
" #pragma HLS UNROLL\n"
55-
" res[jj * CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor + ii] = res_tmp[jj][ii];\n" # noqa: E501
56-
" }\n"
57-
" }\n"
58-
" }\n"
59-
"};\n"
48+
'\n'
49+
' RFOutputLoop:\n'
50+
' for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {\n'
51+
' #pragma HLS UNROLL\n'
52+
' InnerOutputLoop:\n'
53+
' for (int ii = 0; ii < CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor; ii++) {\n'
54+
' #pragma HLS UNROLL\n'
55+
' res[jj * CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor + ii] = res_tmp[jj][ii];\n' # noqa: E501
56+
' }\n'
57+
' }\n'
58+
' }\n'
59+
'};\n'
6060
)
6161

6262
return generated_code
@@ -66,14 +66,10 @@ class GeneratePointwiseConv1D(OptimizerPass):
6666
'''Generates code for pointwise 1D convolution'''
6767

6868
def match(self, node):
69-
return isinstance(node, Conv1D) and node.model.config.get_config_value('IOType') == 'io_parallel'
69+
return isinstance(node, Conv1D) and node.model.config.get_config_value('IOType') == 'io_parallel' and node.get_attr('filt_width') == 1
7070

7171
def transform(self, model, node):
72-
node_class = node.__class__.__name__
73-
if '1D' in node_class:
74-
self._generate_pointwise_conv1d(node)
75-
else:
76-
raise Exception(f'Cannot generate instructions for node {node.name} ({node_class})')
72+
self._generate_pointwise_conv1d(node)
7773

7874
def _generate_pointwise_conv1d(self, node):
7975
code_str = generate_pointwise_conv1d_fn(

hls4ml/templates/vitis/nnet_utils/nnet_conv1d.h

Lines changed: 22 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#include "nnet_common.h"
55
#include "nnet_conv1d_latency.h"
66
#include "nnet_conv1d_resource.h"
7+
#include "nnet_function_stubs.h"
78
#include <cstdlib>
89

910
namespace nnet {
@@ -38,11 +39,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO
3839
// Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully.
3940
//#pragma HLS INLINE recursive
4041

41-
if (CONFIG_T::strategy == nnet::latency) {
42-
conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
43-
} else {
44-
conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
45-
}
42+
CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
4643
}
4744

4845
template <class data_T, class res_T, typename CONFIG_T>
@@ -55,13 +52,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
5552
// Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully.
5653
//#pragma HLS INLINE recursive
5754

58-
if (CONFIG_T::strategy == nnet::latency) {
59-
// Use pointwise unrolled implementation
60-
CONFIG_T::template pointwise_conv<data_T, res_T, CONFIG_T>::pointwise_conv(data, res, weights, biases);
61-
} else {
55+
CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
56+
}
57+
58+
// Conv1DKernel subclass for the latency strategy (Vitis templates).
// Its static conv() forwards data, res, weights and biases unchanged to conv_1d_latency_cl.
// conv_1d_cl / pointwise_conv_1d_cl reach it through the CONFIG_T::conv_kernel alias, which the
// Python config template fills in from the 'conv_fn' parameter ('Conv1DLatency' in this case).
template <class data_T, class res_T, typename CONFIG_T> class Conv1DLatency : public Conv1DKernel<data_T, res_T, CONFIG_T> {
59+
public:
60+
static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
61+
typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
62+
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
63+
//#pragma HLS INLINE region
64+
conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
65+
}
66+
};
67+
68+
// Conv1DKernel subclass for the resource strategy (Vitis templates).
// Its static conv() forwards data, res, weights and biases unchanged to conv_1d_resource_cl.
// It replaces the former `else` branch of the strategy check and is selected through the
// CONFIG_T::conv_kernel alias when the Python side sets 'conv_fn' to 'Conv1DResource'.
template <class data_T, class res_T, typename CONFIG_T> class Conv1DResource : public Conv1DKernel<data_T, res_T, CONFIG_T> {
69+
public:
70+
static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
71+
typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
72+
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
73+
//#pragma HLS INLINE region
6274
conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
6375
}
64-
}
76+
};
6577

6678
} // namespace nnet
6779

hls4ml/templates/vitis/nnet_utils/nnet_conv1d_latency.h

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -107,9 +107,7 @@ void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_c
107107
#pragma HLS ARRAY_PARTITION variable=biases complete dim=0
108108

109109
// Limit multipliers to control parallelization
110-
constexpr unsigned multiplier_limit = DIV_ROUNDUP(
111-
CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan / CONFIG_T::reuse_factor, CONFIG_T::reuse_factor);
112-
#pragma HLS ALLOCATION operation instances=mul limit=multiplier_limit
110+
#pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit
113111

114112
// Convolve, saving all multiplication results to accumulate later
115113
ConvOut:
@@ -159,8 +157,8 @@ void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_c
159157
// Cast to "res_t" type
160158
for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
161159
for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
162-
#pragma HLS UNROLL
163-
res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]);
160+
#pragma HLS UNROLL
161+
res[ii * CONFIG_T::n_filt + ff] = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[ii][ff]);
164162
}
165163
}
166164
}

hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h

Lines changed: 22 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44
#include "nnet_common.h"
55
#include "nnet_conv1d_latency.h"
66
#include "nnet_conv1d_resource.h"
7+
#include "nnet_function_stubs.h"
78
#include <cstdlib>
89

910
namespace nnet {
@@ -37,11 +38,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO
3738
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
3839
#pragma HLS INLINE region
3940

40-
if (CONFIG_T::strategy == nnet::latency) {
41-
conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
42-
} else {
43-
conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
44-
}
41+
CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
4542
}
4643

4744
template <class data_T, class res_T, typename CONFIG_T>
@@ -53,13 +50,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
5350

5451
#pragma HLS INLINE region
5552

56-
if (CONFIG_T::strategy == nnet::latency) {
57-
// Use pointwise unrolled implementation
58-
CONFIG_T::template pointwise_conv<data_T, res_T, CONFIG_T>::pointwise_conv(data, res, weights, biases);
59-
} else {
53+
CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
54+
}
55+
56+
// Conv1DKernel subclass for the latency strategy (Vivado templates).
// Its static conv() forwards data, res, weights and biases unchanged to conv_1d_latency_cl,
// with the INLINE region pragma applied inside the wrapper.
// conv_1d_cl / pointwise_conv_1d_cl reach it through the CONFIG_T::conv_kernel alias, which the
// Python config template fills in from the 'conv_fn' parameter ('Conv1DLatency' in this case).
template <class data_T, class res_T, typename CONFIG_T> class Conv1DLatency : public Conv1DKernel<data_T, res_T, CONFIG_T> {
57+
public:
58+
static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
59+
typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
60+
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
61+
#pragma HLS INLINE region
62+
conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
63+
}
64+
};
65+
66+
// Conv1DKernel subclass for the resource strategy (Vivado templates).
// Its static conv() forwards data, res, weights and biases unchanged to conv_1d_resource_cl,
// with the INLINE region pragma applied inside the wrapper.
// It replaces the former `else` branch of the strategy check and is selected through the
// CONFIG_T::conv_kernel alias when the Python side sets 'conv_fn' to 'Conv1DResource'.
template <class data_T, class res_T, typename CONFIG_T> class Conv1DResource : public Conv1DKernel<data_T, res_T, CONFIG_T> {
67+
public:
68+
static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
69+
typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
70+
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
71+
#pragma HLS INLINE region
6072
conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
6173
}
62-
}
74+
};
6375

6476
} // namespace nnet
6577

hls4ml/templates/vivado/nnet_utils/nnet_conv1d_latency.h

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -106,9 +106,7 @@ void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_c
106106
#pragma HLS ARRAY_PARTITION variable=biases complete dim=0
107107

108108
// Limit multipliers to control parallelization
109-
constexpr unsigned multiplier_limit = DIV_ROUNDUP(
110-
CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan / CONFIG_T::reuse_factor, CONFIG_T::reuse_factor);
111-
#pragma HLS ALLOCATION operation instances=mul limit=multiplier_limit
109+
#pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit
112110

113111
// Convolve, saving all multiplication results to accumulate later
114112
ConvOut:
@@ -158,8 +156,8 @@ void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_c
158156
// Cast to "res_t" type
159157
for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
160158
for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
161-
#pragma HLS UNROLL
162-
res[ii * CONFIG_T::n_filt + ff] = (res_T)(acc[ii][ff]);
159+
#pragma HLS UNROLL
160+
res[ii * CONFIG_T::n_filt + ff] = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[ii][ff]);
163161
}
164162
}
165163
}

hls4ml/templates/vivado/nnet_utils/nnet_function_stubs.h

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,16 @@ template <class data_T, class res_T, typename CONFIG_T> class DenseKernel {
3737
}
3838
};
3939

40+
// Stub base class for 1D convolution kernels: declares the static conv() entry point with an
// empty body. Conv1DLatency, Conv1DResource and the generated pointwise_conv_{index} classes
// derive from it and provide their own static conv(), which callers invoke via
// CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(...).
// NOTE(review): weights is sized n_chan * n_filt here, whereas Conv1DLatency / Conv1DResource
// declare filt_width * n_chan * n_filt -- confirm whether the stub should include filt_width.
template <class data_T, class res_T, typename CONFIG_T> class Conv1DKernel {
41+
public:
42+
static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
43+
res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
44+
typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
45+
typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
46+
// To be implemented in subclasses
47+
}
48+
};
49+
4050
} // namespace nnet
4151

4252
#endif

0 commit comments

Comments
 (0)