
Commit 2fc8941

Merge pull request #881 from jmduarte/split_pointwise_conv_by_rf_codegen
Pointwise Conv1D with code generation for "Latency" strategy (update of #811)
2 parents 22878ce + 4a1c25a commit 2fc8941

14 files changed: +354 -47 lines changed

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 23 additions & 2 deletions
@@ -60,6 +60,8 @@
     typedef {config_t} mult_config;
     template<unsigned K, unsigned S, unsigned W>
     using scale_index = nnet::{scale_index_type}<K, S, W>;
+    template<class data_T, class res_T, class CONFIG_T>
+    using conv_kernel = nnet::{conv_fn}<data_T, res_T, CONFIG_T>;
 }};
 const ap_uint<config{index}::filt_width> config{index}::pixels[] = {{{instructions}}};\n"""

@@ -93,11 +95,30 @@ def format(self, node):
         else:
             params['fill_fn'] = 'FillConv1DBuffer'

+        is_pointwise_parallel_latency = (
+            node.get_attr('filt_width') == 1
+            and node.get_attr('strategy').lower() == 'latency'
+            and node.model.config.get_config_value('IOType') == 'io_parallel'
+        )
+        if is_pointwise_parallel_latency:
+            params['conv_fn'] = f'pointwise_conv_{node.index}'
+        else:
+            if node.get_attr('strategy').lower() == 'latency':
+                params['conv_fn'] = 'Conv1DLatency'
+            else:
+                params['conv_fn'] = 'Conv1DResource'
+
         conv_config = self.template.format(**params)

         mult_params = self._default_config_params(node)
-        mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
-        mult_params['n_out'] = node.get_attr('n_filt')
+        if is_pointwise_parallel_latency:
+            mult_params['n_in'] = int(
+                node.get_attr('in_width') * node.get_attr('n_chan') * node.get_attr('filt_width') / mult_params['reuse']
+            )
+            mult_params['n_out'] = int(node.get_attr('in_width') * node.get_attr('n_filt') / mult_params['reuse'])
+        else:
+            mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
+            mult_params['n_out'] = node.get_attr('n_filt')
         mult_params['nzeros'] = node.get_weights('weight').nzeros
         mult_params['product_type'] = get_backend('vivado').product_type(
             node.get_input_variable().type.precision, node.get_weights('weight').type.precision
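
The multiplier-config change above is easier to see with numbers. Below is a minimal Python sketch, with a hypothetical layer (dimensions made up, not from this commit), of how the dense-multiply shape differs between the new pointwise io_parallel/Latency path and the default per-kernel path:

# Hypothetical layer: 100-wide input, 4 channels, 8 filters, filt_width=1, reuse_factor=5
in_width, n_chan, n_filt, filt_width, reuse = 100, 4, 8, 1, 5

# Pointwise parallel/Latency path: the whole feature map is folded into the
# multiply and split across `reuse` partitions, one kernel call per partition.
n_in = int(in_width * n_chan * filt_width / reuse)   # 80 inputs per partition
n_out = int(in_width * n_filt / reuse)               # 160 outputs per partition

# Default path: dimensions of a single kernel application.
n_in_default = n_chan * filt_width                   # 4
n_out_default = n_filt                               # 8

print(n_in, n_out, n_in_default, n_out_default)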
Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+from hls4ml.model.layers import Conv1D
+from hls4ml.model.optimizer import OptimizerPass
+from hls4ml.model.types import Source
+
+
+def generate_pointwise_conv1d_fn(layer_idx, reuse_factor=1):
+    """Generate a C++ function for a pointwise convolution layer.
+
+    Args:
+        layer_idx (int): Index of layer ('index' attribute).
+        reuse_factor (int): Number of partitions to divide the input into.
+
+    Returns:
+        str: Generated C++ function
+    """
+
+    generated_code = (
+        'template<class data_T, class res_T, typename CONFIG_T>\n'
+        'class pointwise_conv_{index} : public Conv1DKernel<data_T, res_T, CONFIG_T> {{\n'
+        '  public:\n'
+        '    static void conv(\n'
+        '        data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n'
+        '        res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],\n'
+        '        typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],\n'
+        '        typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {{\n'
+        '        data_T data_tmp[CONFIG_T::reuse_factor][CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor];\n'  # noqa: E501
+        '        #pragma HLS ARRAY_PARTITION variable=data_tmp complete dim=0\n'
+        '        res_T res_tmp[CONFIG_T::reuse_factor][CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor];\n'  # noqa: E501
+        '        #pragma HLS ARRAY_PARTITION variable=res_tmp complete dim=0\n\n'
+        '    RFInputLoop:\n'
+        '        for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {{\n'
+        '            #pragma HLS UNROLL\n'
+        '        InnerInputLoop:\n'
+        '            for (int ii = 0; ii < CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor; ii++) {{\n'
+        '                #pragma HLS UNROLL\n'
+        '                data_tmp[jj][ii] = data[jj * CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor + ii];\n'  # noqa: E501
+        '            }}\n'
+        '        }}\n\n'
+    ).format(index=layer_idx)
+    indent = '    '
+    for i in range(reuse_factor):
+        generated_code += indent
+        generated_code += (
+            f'pointwise_conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data_tmp[{i}], res_tmp[{i}], weights, biases);\n'
+        )
+
+    generated_code += (
+        '\n'
+        '    RFOutputLoop:\n'
+        '        for (int jj = 0; jj < CONFIG_T::reuse_factor; jj++) {\n'
+        '            #pragma HLS UNROLL\n'
+        '        InnerOutputLoop:\n'
+        '            for (int ii = 0; ii < CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor; ii++) {\n'
+        '                #pragma HLS UNROLL\n'
+        '                res[jj * CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor + ii] = res_tmp[jj][ii];\n'  # noqa: E501
+        '            }\n'
+        '        }\n'
+        '    }\n'
+        '};\n'
+    )
+
+    return generated_code
+
+
+class GeneratePointwiseConv1D(OptimizerPass):
+    '''Generates code for pointwise 1D convolution'''
+
+    def match(self, node):
+        return (
+            isinstance(node, Conv1D)
+            and node.model.config.get_config_value('IOType') == 'io_parallel'
+            and node.get_attr('filt_width') == 1
+        )
+
+    def transform(self, model, node):
+        self._generate_pointwise_conv1d(node)
+
+    def _generate_pointwise_conv1d(self, node):
+        code_str = generate_pointwise_conv1d_fn(
+            node.get_attr('index'),
+            node.get_attr('reuse_factor'),
+        )
+
+        node.set_attr('pointwise_conv1d_codegen', Source(code_str))
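
To see what this pass produces, here is a sketch of driving the generator by hand; the import path is commented out because it is an assumption for illustration — in the normal flow, GeneratePointwiseConv1D calls the generator and attaches the result to the node as a Source:

# Assumed import path, for illustration only:
# from hls4ml.backends.vivado.passes.pointwise_codegen import generate_pointwise_conv1d_fn

code = generate_pointwise_conv1d_fn(layer_idx=2, reuse_factor=2)

# The emitted C++ defines `class pointwise_conv_2 : public Conv1DKernel<...>`
# whose conv() scatters the input into 2 partitions, calls
# pointwise_conv_1d_latency_cl once per partition, and gathers the results.
assert 'pointwise_conv_2' in code
assert code.count('pointwise_conv_1d_latency_cl') == 2
print(code)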

hls4ml/backends/vivado/vivado_backend.py

Lines changed: 1 addition & 1 deletion
@@ -70,7 +70,6 @@ def _register_layer_attributes(self):
         cnn_layers = [Conv1D, Conv2D, SeparableConv1D, SeparableConv2D, DepthwiseConv2D, Pooling1D, Pooling2D]
         for layer in cnn_layers:
             attrs = self.attribute_map.get(layer, [])
-            # attrs.append(ConfigurableAttribute('conv_implementation', value_type=str, default='LineBuffer'))
             attrs.append(ChoiceAttribute('conv_implementation', choices=['LineBuffer', 'Encoded'], default='LineBuffer'))
             self.attribute_map[layer] = attrs

@@ -114,6 +113,7 @@ def _register_flows(self):
             'vivado:generate_conv_streaming_instructions',
             'vivado:apply_resource_strategy',
             'vivado:generate_conv_im2col',
+            'vivado:generate_pointwise_conv1_d',
             'vivado:generate_unrolled_dense_resource',
             'vivado:set_pipeline_style',
         ]
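
The flow name 'vivado:generate_pointwise_conv1_d' (note the split 'conv1_d') follows from the pass class name GeneratePointwiseConv1D: hls4ml derives optimizer names from class names by a camel-to-snake conversion along these lines (a sketch assuming a regex-based implementation; the library's exact code may differ):

import re

def camel_to_snake(name):
    # Insert '_' before every interior uppercase letter, then lowercase.
    return re.sub(r'(?<!^)(?=[A-Z])', '_', name).lower()

print(camel_to_snake('GeneratePointwiseConv1D'))  # -> generate_pointwise_conv1_d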

hls4ml/templates/vitis/nnet_utils/nnet_conv1d.h

Lines changed: 21 additions & 9 deletions
@@ -4,6 +4,7 @@
 #include "nnet_common.h"
 #include "nnet_conv1d_latency.h"
 #include "nnet_conv1d_resource.h"
+#include "nnet_function_stubs.h"
 #include <cstdlib>

 namespace nnet {

@@ -38,11 +39,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO
     // Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully.
     //#pragma HLS INLINE recursive

-    if (CONFIG_T::strategy == nnet::latency) {
-        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
 }

 template <class data_T, class res_T, typename CONFIG_T>

@@ -55,13 +52,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
     // Inlining helps reduce latency, but may also cause timing issues in some cases, use carefully.
     //#pragma HLS INLINE recursive

-    // Nothing special to be done for io_parallel implementation
-    if (CONFIG_T::strategy == nnet::latency) {
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
+}
+
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DLatency : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        //#pragma HLS INLINE region
         conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
+    }
+};
+
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DResource : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        //#pragma HLS INLINE region
         conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
     }
-}
+};

 } // namespace nnet
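
Both conv1d headers make the same move: the runtime strategy branch is replaced by a compile-time binding, so conv_1d_cl always calls whatever kernel class the backend wrote into CONFIG_T::conv_kernel — including a per-layer generated pointwise_conv_<index>. A rough Python analogy of the before/after (illustrative only, not hls4ml API):

def conv_1d_latency(data):   # stand-in for nnet::conv_1d_latency_cl
    return data

def conv_1d_resource(data):  # stand-in for nnet::conv_1d_resource_cl
    return data

# Before: every call branches on the strategy at run time.
def conv_1d_cl_old(cfg, data):
    if cfg['strategy'] == 'latency':
        return conv_1d_latency(data)
    return conv_1d_resource(data)

# After: the config carries the kernel itself, so new kernels (including
# generated ones) are plugged in by the backend, not by editing this function.
class Config:
    conv_kernel = staticmethod(conv_1d_latency)

def conv_1d_cl_new(cfg, data):
    return cfg.conv_kernel(data)

assert conv_1d_cl_new(Config, [1, 2]) == conv_1d_cl_old({'strategy': 'latency'}, [1, 2])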

hls4ml/templates/vitis/nnet_utils/nnet_conv1d_latency.h

Lines changed: 78 additions & 0 deletions
@@ -85,5 +85,83 @@ void conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
     }
 }

+template <class data_T, class res_T, typename CONFIG_T>
+void pointwise_conv_1d_latency_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan / CONFIG_T::reuse_factor],
+                                  res_T res[CONFIG_T::out_width * CONFIG_T::n_filt / CONFIG_T::reuse_factor],
+                                  typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
+                                  typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+    assert(CONFIG_T::filt_width == 1);
+
+    typename CONFIG_T::accum_t mult[CONFIG_T::out_width * CONFIG_T::n_filt * CONFIG_T::n_chan / CONFIG_T::reuse_factor];
+    typename CONFIG_T::accum_t acc[CONFIG_T::out_width / CONFIG_T::reuse_factor][CONFIG_T::n_filt];
+
+    #pragma HLS ARRAY_PARTITION variable=mult complete dim=0
+    #pragma HLS ARRAY_PARTITION variable=acc complete dim=0
+
+    // Use a function_instantiate in case it helps to explicitly optimize unchanging weights/biases
+    #pragma HLS function_instantiate variable=weights,biases
+
+    // Parallel mode
+    #pragma HLS PIPELINE II=CONFIG_T::reuse_factor
+    #pragma HLS ARRAY_PARTITION variable=weights complete dim=0
+    #pragma HLS ARRAY_PARTITION variable=biases complete dim=0
+
+    // Limit multipliers to control parallelization
+    #pragma HLS ALLOCATION operation instances=mul limit=CONFIG_T::mult_config::multiplier_limit
+
+    // Convolve, saving all multiplication results to accumulate later
+ConvOut:
+    for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
+    ConvFilt:
+        for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
+        ConvChan:
+            for (int cc = 0; cc < CONFIG_T::n_chan; cc++) {
+                #pragma HLS UNROLL
+                int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc;
+                int index_weight = cc * CONFIG_T::n_filt + ff;
+                int index_data = (ii * CONFIG_T::stride_width - CONFIG_T::pad_left) * CONFIG_T::n_chan + cc;
+
+                if ((ii * CONFIG_T::stride_width) < CONFIG_T::pad_left ||
+                    (ii * CONFIG_T::stride_width) >= (CONFIG_T::pad_left + CONFIG_T::in_width)) {
+                    mult[index_mult] = 0;
+                } else {
+                    mult[index_mult] = CONFIG_T::mult_config::template product<data_T, typename CONFIG_T::weight_t>::product(
+                        data[index_data], weights[index_weight]);
+                }
+            } // end channel loop
+        } // end filter loop
+    } // end output loop
+
+    // Initialize accumulator with input biases
+    for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
+        for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
+            #pragma HLS UNROLL
+            acc[ii][ff] = biases[ff];
+        }
+    }
+
+    // Accumulate multiplication result
+AccumOut:
+    for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
+    AccumFilt:
+        for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
+            // Do "dot product" sum within filter and sum over channels
+        AccumChan:
+            for (int cc = 0; cc < CONFIG_T::n_chan; cc++) {
+                int index_mult = ii * CONFIG_T::n_filt * CONFIG_T::n_chan + ff * CONFIG_T::n_chan + cc;
+                acc[ii][ff] += mult[index_mult];
+            } // end channel loop
+        } // end filter loop
+    } // end output loop
+
+    // Cast to "res_t" type
+    for (int ii = 0; ii < CONFIG_T::out_width / CONFIG_T::reuse_factor; ii++) {
+        for (int ff = 0; ff < CONFIG_T::n_filt; ff++) {
+            #pragma HLS UNROLL
+            res[ii * CONFIG_T::n_filt + ff] = cast<data_T, res_T, typename CONFIG_T::mult_config>(acc[ii][ff]);
+        }
+    }
+}
+
 } // namespace nnet
 #endif
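
As a sanity check on the reuse-factor partitioning, a small Python sketch (hypothetical sizes, not from this commit) confirming that the generated wrapper's scatter/gather offsets and this kernel's output indexing tile the arrays exactly once:

# Hypothetical: in_width=8, n_chan=3, n_filt=2, reuse_factor=4, stride 1, no padding.
in_width, n_chan, n_filt, reuse = 8, 3, 2, 4
out_width = in_width  # pointwise conv, stride 1, no padding

chunk_in = in_width * n_chan // reuse    # per-partition data length (6)
chunk_out = out_width * n_filt // reuse  # per-partition result length (4)

# Scatter (RFInputLoop) and gather (RFOutputLoop) index patterns:
scatter = [jj * chunk_in + ii for jj in range(reuse) for ii in range(chunk_in)]
gather = [jj * chunk_out + ii for jj in range(reuse) for ii in range(chunk_out)]
assert scatter == list(range(in_width * n_chan))  # covers data[] once, in order
assert gather == list(range(out_width * n_filt))  # covers res[] once, in order

# Inside one partition, res[ii * n_filt + ff] touches every slot exactly once:
per_part = sorted(ii * n_filt + ff for ii in range(out_width // reuse) for ff in range(n_filt))
assert per_part == list(range(chunk_out))
print('index patterns tile both arrays exactly once')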

hls4ml/templates/vivado/build_prj.tcl

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ if {$opt(reset)} {
 } else {
     open_solution "solution1"
 }
-catch {config_array_partition -maximum_size 4096}
+catch {config_array_partition -maximum_size $maximum_size}
 config_compile -name_max_length 80
 set_part $part
 config_schedule -enable_dsp_full_reg=false

hls4ml/templates/vivado/nnet_utils/nnet_code_gen.h

Lines changed: 11 additions & 0 deletions
@@ -1,6 +1,7 @@
 #ifndef NNET_INSTR_GEN_H_
 #define NNET_INSTR_GEN_H_

+#include "nnet_conv1d_latency.h"
 #include "nnet_helpers.h"

 #include "hls_stream.h"

@@ -10,6 +11,16 @@

 namespace nnet {

+template <class data_T, class res_T, typename CONFIG_T> class PointwiseConv1D {
+  public:
+    static void pointwise_conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],
+                               res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                               typename CONFIG_T::weight_t weights[CONFIG_T::n_chan * CONFIG_T::n_filt],
+                               typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        // To be implemented in subclasses
+    }
+};
+
 // hls4ml insert code

 } // namespace nnet

hls4ml/templates/vivado/nnet_utils/nnet_common.h

Lines changed: 1 addition & 0 deletions
@@ -2,6 +2,7 @@
 #define NNET_COMMON_H_

 #include "ap_fixed.h"
+#include "nnet_helpers.h"

 // This is a substitute for "ceil(n/(float)d)".
 #define DIV_ROUNDUP(n, d) ((n + d - 1) / d)

hls4ml/templates/vivado/nnet_utils/nnet_conv1d.h

Lines changed: 21 additions & 9 deletions
@@ -4,6 +4,7 @@
 #include "nnet_common.h"
 #include "nnet_conv1d_latency.h"
 #include "nnet_conv1d_resource.h"
+#include "nnet_function_stubs.h"
 #include <cstdlib>

 namespace nnet {

@@ -37,11 +38,7 @@ void conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CO
                 typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
     #pragma HLS INLINE region

-    if (CONFIG_T::strategy == nnet::latency) {
-        conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
-        conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    }
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
 }

 template <class data_T, class res_T, typename CONFIG_T>

@@ -53,13 +50,28 @@ void pointwise_conv_1d_cl(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],

     #pragma HLS INLINE region

-    // Nothing special to be done for io_parallel implementation
-    if (CONFIG_T::strategy == nnet::latency) {
+    CONFIG_T::template conv_kernel<data_T, res_T, CONFIG_T>::conv(data, res, weights, biases);
+}
+
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DLatency : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        #pragma HLS INLINE region
         conv_1d_latency_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
-    } else {
+    }
+};
+
+template <class data_T, class res_T, typename CONFIG_T> class Conv1DResource : public Conv1DKernel<data_T, res_T, CONFIG_T> {
+  public:
+    static void conv(data_T data[CONFIG_T::in_width * CONFIG_T::n_chan], res_T res[CONFIG_T::out_width * CONFIG_T::n_filt],
+                     typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+                     typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]) {
+        #pragma HLS INLINE region
         conv_1d_resource_cl<data_T, res_T, CONFIG_T>(data, res, weights, biases);
     }
-}
+};

 } // namespace nnet
6577
