Commit e4a5988

Merge pull request #656 from bo3z/quartus-streaming-conv

There were no comments to not merge, so I'll go ahead and merge.

2 parents b180fe1 + 4db0002

26 files changed: +1143 -117 lines

hls4ml/backends/quartus/passes/convolution_templates.py

Lines changed: 2 additions & 2 deletions

@@ -59,7 +59,7 @@
 """

 conv1d_function_template = 'nnet::conv_1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
-conv1d_include_list = ['nnet_utils/nnet_conv1d.h']
+conv1d_include_list = ['nnet_utils/nnet_conv1d.h', 'nnet_utils/nnet_conv1d_stream.h']

 class Conv1DConfigTemplate(LayerConfigTemplate):
     def __init__(self):
@@ -134,7 +134,7 @@ def format(self, node):
 }};\n"""

 conv2d_function_template = 'nnet::conv_2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'
-conv2d_include_list = ['nnet_utils/nnet_conv2d.h']
+conv2d_include_list = ['nnet_utils/nnet_conv2d.h', 'nnet_utils/nnet_conv2d_stream.h']

 class Conv2DConfigTemplate(LayerConfigTemplate):
     def __init__(self):
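
For context, hls4ml fills these format strings in when it writes the generated C++ project. A sketch of how one rendered Conv1D call could look, where input_t, result_t, config2, layer1_out, layer2_out, w2 and b2 are all hypothetical generated names, not taken from this commit:

// With data_format = 'cl' (channels last), the function template above renders roughly as:
nnet::conv_1d_cl<input_t, result_t, config2>(layer1_out, layer2_out, w2, b2);
// The extended include list makes the streaming overload of nnet::conv_1d_cl
// (defined in nnet_conv1d_stream.h, added later in this diff) visible to this call.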

hls4ml/backends/quartus/passes/convolution_winograd.py

Lines changed: 5 additions & 3 deletions

@@ -15,7 +15,9 @@ def match(self, node):
         weights_transformed = node.get_attr('_weights_transposed', False) == True

         # User opted for Winograd
-        implementation_is_winograd = node.get_attr('implementation', 'combination') == 'combination' or node.get_attr('implementation', 'combination') == 'winograd'
+        implementation_is_winograd = node.get_attr('implementation', 'combination') == 'combination' or node.get_attr('implementation', 'combination') == 'winograd'
+
+        parallel_io_type = node.model.config.get_config_value('IOType') == 'io_parallel'

         # Winograd algorithm-specific conditions
         if isinstance(node, Conv1D):
@@ -29,7 +31,7 @@ def match(self, node):
             # HLS Compiler fails to pipeline the entire component if Winograd loop only executes once
             loop_itr_gt_one = node.get_attr('out_width') > 2

-            winograd_conditions = filter_size_matches and stride_is_one and loop_itr_gt_one
+            winograd_conditions = filter_size_matches and stride_is_one and loop_itr_gt_one and parallel_io_type

         elif isinstance(node, (Conv2D)):
             # Winograd only applies to specific kernel sizes
@@ -44,7 +46,7 @@ def match(self, node):

             padding_is_equal = node.get_attr('pad_top', 0) == node.get_attr('pad_bottom', 0) and node.get_attr('pad_left', 0) == node.get_attr('pad_right', 0)

-            winograd_conditions = filter_size_matches and stride_is_one and padding_is_equal and loop_itr_gt_one
+            winograd_conditions = filter_size_matches and stride_is_one and padding_is_equal and loop_itr_gt_one and parallel_io_type

         else:
             winograd_conditions = False

hls4ml/backends/quartus/passes/pointwise.py

Lines changed: 3 additions & 4 deletions

@@ -58,16 +58,15 @@ class OptimizePointwiseConv(OptimizerPass):
     def match(self, node):
         return node.class_name in ('Conv1D', 'Conv2D') and \
                node.get_attr('filt_height', 1) == 1 and \
-               node.get_attr('filt_width') == 1
+               node.get_attr('filt_width') == 1 and \
+               node.model.config.get_config_value('IOType') == 'io_parallel'

     def transform(self, model, node):
         dim = node.__class__.__name__[-2:] # '1D' or '2D'
         pw_node = model.make_node('PointwiseConv' + dim, node.name, copy(node.attributes), node.inputs.copy(), outputs=node.outputs.copy())
         if len(node.weights['weight'].data.shape) == 2: # This can happen if we assign weights of Dense layer to 1x1 Conv2D
             pw_node.weights['weight'].data = np.expand_dims(node.weights['weight'].data, axis=(0,1))
         pw_node.weights['bias'].data = node.weights['bias'].data
-        # pw_node.weights['bias'].data = node.weights['bias'].data
-        print("Here")
         model.replace_node(node, pw_node)

-        return True
+        return True

hls4ml/backends/quartus/passes/pooling_templates.py

Lines changed: 12 additions & 1 deletion

@@ -9,13 +9,18 @@

     static const unsigned n_in = {n_in};
     static const unsigned n_out = {n_out};
+    static const unsigned filt_width = {pool_width};

     static const unsigned n_filt = {n_filt};
+    static const unsigned n_chan = {n_filt};
+
+    static const unsigned in_width = {n_in};

     static const unsigned pad_left = {pad_left};
     static const unsigned pad_right = {pad_right};

     static const nnet::Pool_Op pool_op = nnet::{pool_op};
+    typedef {accum_t.name} accum_t;
 }};\n"""

 pooling2d_config_template = """struct config{index} : nnet::pooling2d_config {{
@@ -24,41 +29,47 @@

     static const unsigned pool_height = {pool_height};
     static const unsigned pool_width = {pool_width};
+    static const unsigned filt_height = {pool_height};
+    static const unsigned filt_width = {pool_width};

     static const unsigned in_height = {in_height};
     static const unsigned in_width = {in_width};
     static const unsigned out_height = {out_height};
     static const unsigned out_width = {out_width};

     static const unsigned n_filt = {n_filt};
+    static const unsigned n_chan = {n_filt};

     static const unsigned pad_top = {pad_top};
     static const unsigned pad_bottom = {pad_bottom};
     static const unsigned pad_left = {pad_left};
     static const unsigned pad_right = {pad_right};

     static const nnet::Pool_Op pool_op = nnet::{pool_op};
+    typedef {accum_t.name} accum_t;
 }};\n"""

 global_pooling1d_config_template = """struct config{index} : nnet::pooling1d_config {{
     static const unsigned n_in = {n_in};
     static const unsigned n_filt = {n_filt};
     static const nnet::Pool_Op pool_op = nnet::{pool_op};
+    typedef {accum_t.name} accum_t;
 }};\n"""

 global_pooling2d_config_template = """struct config{index} : nnet::pooling2d_config {{
     static const unsigned in_height = {in_height};
     static const unsigned in_width = {in_width};
     static const unsigned n_filt = {n_filt};
     static const nnet::Pool_Op pool_op = nnet::{pool_op};
+    typedef {accum_t.name} accum_t;
 }};\n"""

 pooling1d_function_template = 'nnet::pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
 pooling2d_function_template = 'nnet::pooling2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
 global_pooling1d_function_template = 'nnet::global_pooling1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
 global_pooling2d_function_template = 'nnet::global_pooling2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'

-pooling_include_list = ['nnet_utils/nnet_pooling.h']
+pooling_include_list = ['nnet_utils/nnet_pooling.h', 'nnet_utils/nnet_pooling_stream.h']

 class PoolingConfigTemplate(LayerConfigTemplate):
     def __init__(self):
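
To see what these templates produce, here is a sketch of one rendered 1D pooling config. The layer index, sizes and accum_t shown are hypothetical, not from this commit; the filt_width/n_chan/in_width aliases added above let the streaming pooling implementation reuse the same line-buffer machinery as the streaming convolutions:

// Hypothetical rendering of pooling1d_config_template for a layer with index 4
// (assumes the Intel AC datatypes used throughout these headers)
struct config4 : nnet::pooling1d_config {
    static const unsigned n_in = 64;
    static const unsigned n_out = 32;
    static const unsigned filt_width = 2;  // alias of pool_width for the stream implementation

    static const unsigned n_filt = 16;
    static const unsigned n_chan = 16;     // alias of n_filt for the stream implementation

    static const unsigned in_width = 64;   // alias of n_in for the stream implementation

    static const unsigned pad_left = 0;
    static const unsigned pad_right = 0;

    static const nnet::Pool_Op pool_op = nnet::Max;
    typedef ac_fixed<18, 8, true> accum_t; // accumulator type now exposed to the kernel
};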

hls4ml/backends/quartus/passes/reshaping_templates.py

Lines changed: 3 additions & 3 deletions

@@ -28,7 +28,7 @@
 zeropad1d_function_template = 'nnet::zeropad1d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'
 zeropad2d_function_template = 'nnet::zeropad2d_{data_format}<{input_t}, {output_t}, {config}>({input}, {output});'

-padding_include_list = ['nnet_utils/nnet_padding.h']
+padding_include_list = ['nnet_utils/nnet_padding.h', 'nnet_utils/nnet_padding_stream.h']

 class ZeroPaddingConfigTemplate(LayerConfigTemplate):
     def __init__(self):
@@ -72,7 +72,7 @@ def format(self, node):
 }};\n"""

 resize_function_template = 'nnet::resize_{algorithm}<{input_t}, {config}>({input}, {output});'
-resize_include_list = ['nnet_utils/nnet_resize.h']
+resize_include_list = ['nnet_utils/nnet_resize.h', 'nnet_utils/nnet_resize_stream.h']

 class ResizeConfigTemplate(LayerConfigTemplate):
     def __init__(self):
@@ -108,7 +108,7 @@ def format(self, node):
 }};\n"""

 transpose_function_template = 'nnet::transpose_{dim}<{input_t}, {output_t}, {config}>({input}, {output});'
-transpose_include_list = ['nnet_utils/nnet_transpose.h']
+transpose_include_list = ['nnet_utils/nnet_transpose.h', 'nnet_utils/nnet_transpose_stream.h']

 class TransposeConfigTemplate(LayerConfigTemplate):
     def __init__(self):

hls4ml/templates/quartus/firmware/defines.h

Lines changed: 1 addition & 0 deletions

@@ -50,5 +50,6 @@ using stream_out = ihc::stream_out<T>;

 #define DIV_ROUNDUP(n,d) ((n + d - 1) / d)
 #define MIN(n,d) (n > d ? d : n)
+#define MAX(n,d) (n < d ? d : n)

 #endif
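
These are plain function-like macros, so they expand textually at compile time. A quick sketch of how sizing expressions in the generated headers can use them (the values and names here are illustrative only):

// Illustrative only: compile-time sizing with the macros from defines.h.
// DIV_ROUNDUP(10, 4) expands to ((10 + 4 - 1) / 4) == 3
// MIN(3, 7)          expands to (3 > 7 ? 7 : 3)    == 3
// MAX(3, 7)          expands to (3 < 7 ? 7 : 3)    == 7
static constexpr int n_chunks = DIV_ROUNDUP(10, 4);
static constexpr int depth    = MAX(MIN(3, 7), 1);
// Note: as textual macros they evaluate their arguments more than once,
// so they are intended for constants, not expressions with side effects.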

hls4ml/templates/quartus/firmware/nnet_utils/nnet_activation.h

Lines changed: 12 additions & 3 deletions

@@ -131,10 +131,12 @@ enum class softmax_implementation {latency=0, legacy=1, stable=2, argmax=3};
 template<class data_T, typename CONFIG_T>
 inline unsigned softmax_stable_idx_from_real_val(const data_T x){
     // Number of address bits for table
-    static constexpr int N = ceillog2(CONFIG_T::table_size);
+    static constexpr int N = ceillog2(CONFIG_T::table_size);

     // Slice the top N bits of the input
-    hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
+    hls_register ac_int<N, false> y = x.template slc<N>(x.width-N-1);
+    // If x is the most negative value, the slice will be 0, so we need to set the 0-th bit to ensure correctness
+    if (x != 0 && y == 0) y[0] = 1;
     return y.to_uint();
 }

@@ -158,11 +160,18 @@ void softmax_stable(data_T data[CONFIG_T::n_in], res_T res[CONFIG_T::n_in]){
     Op_max<data_T> op_max;
     hls_register data_T x_max = reduce<data_T, CONFIG_T::n_in, Op_max<data_T>>(data, op_max);

+    // For the diffs, use the same type as the input but force rounding and saturation
+    hls_register ac_fixed<data_T::width, data_T::i_width, true, AC_RND, AC_SAT> d_xi_xmax[CONFIG_T::n_in];
+    #pragma unroll
+    for(unsigned i = 0; i < CONFIG_T::n_in; i++){
+        d_xi_xmax[i] = data[i] - x_max;
+    }
+
     // Calculate all the e^x's
     hls_register typename CONFIG_T::exp_table_t exp_res[CONFIG_T::n_in];
     #pragma unroll
     for(unsigned i = 0; i < CONFIG_T::n_in; i++) {
-        exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(data[i] - x_max)];
+        exp_res[i] = exp_table[softmax_stable_idx_from_real_val<data_T, CONFIG_T>(d_xi_xmax[i])];
    }

     // Explicitly sum previously calculated exponentials with an adder tree
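
The lookup above indexes a table_size-entry exp table with the N bits just below the sign bit of the fixed-point difference. A minimal standalone sketch of that indexing using plain integers (a 16-bit input and a 1024-entry table are assumed for illustration; the real code uses Intel's ac_int/ac_fixed types):

#include <cstdint>
#include <cassert>

// Emulates softmax_stable_idx_from_real_val for width = 16, table_size = 1024 (N = 10).
// The slice x.slc<10>(16 - 10 - 1) takes bits [14:5], i.e. the 10 bits below the sign bit.
unsigned idx_from_fixed16(int16_t x) {
    constexpr int W = 16, N = 10;
    uint16_t bits = static_cast<uint16_t>(x);
    unsigned y = (bits >> (W - N - 1)) & ((1u << N) - 1u);
    // The most negative value (0x8000) slices to all zeros; nudge any nonzero
    // input whose slice is zero so it does not alias the entry for x == 0.
    if (x != 0 && y == 0) y |= 1u;
    return y;
}

int main() {
    assert(idx_from_fixed16(0) == 0);
    assert(idx_from_fixed16(INT16_MIN) == 1);  // corrected, no longer aliases index 0
    return 0;
}
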
hls4ml/templates/quartus/firmware/nnet_utils/nnet_conv1d_stream.h

Lines changed: 178 additions & 0 deletions

@@ -0,0 +1,178 @@
+#ifndef NNET_CONV1D_STREAM_H_
+#define NNET_CONV1D_STREAM_H_
+
+#include "nnet_types.h"
+#include "nnet_dense.h"
+
+namespace nnet {
+
+/*
+ * void kernel_shift(shift_buffer, kernel_window)
+ *
+ * Args:
+ *   shift_buffer - array of elements popped from the line buffer during the shift line buffer operation
+ *   kernel_window - array of values from the input currently being convolved with the kernel
+ *
+ * Values from shift_buffer are inserted into kernel_window, updating the values to be convolved
+ */
+template <class data_T, typename CONFIG_T>
+void kernel_shift_1d(
+    typename data_T::value_type shift_buffer[CONFIG_T::n_chan],
+    typename data_T::value_type kernel_window[CONFIG_T::filt_width * CONFIG_T::n_chan]
+) {
+    /*
+     * Manually shift kernel_window by one step to the left
+     * Not possible to use nnet::shift_reg<T, N> as the kernel window is convolved with the kernel weights using dense matrix multiplication
+     * Dense matrix multiplication is only implemented for arrays
+     * However, provided certain timing constraints are met, Intel HLS automatically infers a shift operation and implements kernel_window as a shift register
+     * To verify, see synthesis report in report.html > Area Analysis of System
+     */
+    KernelShiftWidth:
+    #pragma unroll
+    for (int col = 0; col < CONFIG_T::filt_width - 1; col++) {
+        KernelShiftChannel:
+        #pragma unroll
+        for (int channel = 0; channel < CONFIG_T::n_chan; channel++) {
+            kernel_window[col * CONFIG_T::n_chan + channel] = kernel_window[(col + 1) * CONFIG_T::n_chan + channel];
+        }
+    }
+
+    // Insert shift_buffer values into the last column of the kernel window
+    KernelPushChannel:
+    #pragma unroll
+    for (int channel = 0; channel < CONFIG_T::n_chan; channel++) {
+        kernel_window[(CONFIG_T::filt_width - 1) * CONFIG_T::n_chan + channel] = shift_buffer[channel];
+    }
+}
+
+/*
+ * void shift_line_buffer(in_element, line_buffer, shift_buffer)
+ *
+ * Args:
+ *   in_element - current elements from input image, data_T type is usually nnet::array, size of array corresponds to number of channels
+ *   line_buffer - chained array of shift registers, one for each row of the kernel and channel
+ *   shift_buffer - array of elements popped from the line buffer during the shift operation
+ *
+ * In the 1D case the kernel spans a single row, so the incoming element is copied straight into shift_buffer;
+ * these elements are later used to update the kernel window, during the kernel_shift operation
+ */
+template <class data_T, typename CONFIG_T>
+void shift_line_buffer_1d(
+    const data_T &in_elem,
+    nnet::shift_reg<typename data_T::value_type, CONFIG_T::pad_left + CONFIG_T::in_width + CONFIG_T::pad_right> line_buffer[CONFIG_T::n_chan],
+    typename data_T::value_type shift_buffer[CONFIG_T::n_chan]
+) {
+    // For every channel, insert the incoming pixel at end of the shift buffer
+    UpdateBuffer:
+    #pragma unroll
+    for (int channel = 0; channel < CONFIG_T::n_chan; channel++) {
+        shift_buffer[channel] = in_elem[channel];
+    }
+}
+
+/*
+ * void compute_output_buffer(in_element, res_stream, line_buffer, kernel_window, weights, biases)
+ *
+ * Args:
+ *   in_element - current elements from input image, data_T type is usually nnet::array, size of array corresponds to number of channels
+ *   res_stream - output stream, passed by reference to allow direct writing
+ *   line_buffer - chained array of shift registers, one for each row of the kernel and channel
+ *   kernel_window - array of values from the input currently convolved with the kernel
+ *   weights - Conv1D layer weights
+ *   biases - Conv1D layer biases
+ *
+ * Function executes 4 steps:
+ *   (1) Shift line buffer - updates the contents of the chained shift registers, inserting the new inputs and removing last elements
+ *   (2) Kernel shift - updates the elements of the kernel window, by storing the new inputs and popped elements from the line buffer
+ *   (3) Matrix multiplication - performs dense matrix multiplication between the current input window and kernel weights
+ *   (4) Counter housekeeping - keeps track of current pixel and stride
+ */
+template<class data_T, class res_T, typename CONFIG_T>
+void compute_output_buffer_1d(
+    const data_T &in_elem,
+    stream<res_T> &res_stream,
+    nnet::shift_reg<typename data_T::value_type, CONFIG_T::pad_left + CONFIG_T::in_width + CONFIG_T::pad_right> line_buffer[CONFIG_T::n_chan],
+    typename data_T::value_type kernel_window[CONFIG_T::filt_width * CONFIG_T::n_chan],
+    const typename CONFIG_T::weight_t weights[CONFIG_T::kernel_size * CONFIG_T::n_chan * CONFIG_T::n_filt],
+    const typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+) {
+    // Thresholds
+    static constexpr int lShiftX = CONFIG_T::filt_width - 1;
+
+    // X position pixel
+    static int pX = 0;
+
+    // X strides
+    static int sX = 0;
+
+    // Step 1 - Shift line buffer
+    hls_register typename data_T::value_type shift_buffer[CONFIG_T::n_chan];
+    nnet::shift_line_buffer_1d<data_T, CONFIG_T>(in_elem, line_buffer, shift_buffer);
+
+    // Step 2 - Kernel shift
+    nnet::kernel_shift_1d<data_T, CONFIG_T>(shift_buffer, kernel_window);
+
+    // Check to see if we have a full kernel
+    if ((sX - lShiftX) == 0 && pX > (lShiftX - 1)) {
+        // Step 3 - Dense matrix multiplication
+        hls_register typename res_T::value_type res_out[CONFIG_T::n_filt];
+        dense_resource<typename data_T::value_type, typename res_T::value_type, typename CONFIG_T::mult_config>(kernel_window, res_out, weights, biases);
+
+        // Write result to output stream
+        hls_register res_T res_pack;
+        CastLoop:
+        #pragma unroll
+        for (int channel = 0; channel < CONFIG_T::n_filt; channel++) {
+            res_pack[channel] = res_out[channel];
+        }
+        res_stream.write(res_pack);
+    }
+
+    // Reached end of image
+    if ((pX + 1) == (CONFIG_T::in_width + CONFIG_T::pad_left + CONFIG_T::pad_right)) {
+        pX = 0;
+        sX = 0;
+    // Move to the right
+    } else {
+        pX++;
+        sX = ((sX - lShiftX) == 0) ? (sX - CONFIG_T::stride_width + 1) : (sX + 1);
+    }
+}
+
+
+template <class data_T, class res_T, typename CONFIG_T>
+void conv_1d_cl(
+    stream<data_T> &data,
+    stream<res_T> &res,
+    const typename CONFIG_T::weight_t weights[CONFIG_T::filt_width * CONFIG_T::n_chan * CONFIG_T::n_filt],
+    const typename CONFIG_T::bias_t biases[CONFIG_T::n_filt]
+) {
+    // Line buffer and kernel window
+    hls_register static nnet::shift_reg<typename data_T::value_type, CONFIG_T::pad_left + CONFIG_T::in_width + CONFIG_T::pad_right> line_buffer[CONFIG_T::n_chan];
+    hls_register static typename data_T::value_type kernel_window[CONFIG_T::filt_width * CONFIG_T::n_chan];
+
+    // An array of length CONFIG_T::n_chan, with elements set to zero (padding for each channel)
+    static const data_T padds(0);
+
+    // Input image left-side padding
+    PaddingLeftWidth:
+    for (int col = 0; col < CONFIG_T::pad_left; col++) {
+        compute_output_buffer_1d<data_T, res_T, CONFIG_T>(padds, res, line_buffer, kernel_window, weights, biases);
+    }
+
+    // Read input image
+    ReadInputWidth:
+    for (int col = 0; col < CONFIG_T::in_width; col++) {
+        compute_output_buffer_1d<data_T, res_T, CONFIG_T>(data.read(), res, line_buffer, kernel_window, weights, biases);
+    }
+
+    // Input image right-side padding
+    PaddingRightWidth:
+    for (int col = 0; col < CONFIG_T::pad_right; col++) {
+        compute_output_buffer_1d<data_T, res_T, CONFIG_T>(padds, res, line_buffer, kernel_window, weights, biases);
+    }
+}
+
+}
+
+#endif
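
In a generated project, conv_1d_cl is reached through the Conv1D function template shown earlier, with a config struct emitted per layer. A sketch of what such a config could look like; every name and value below is hypothetical (hls4ml generates the real ones), the AC fixed-point types are assumed available as in the headers above, and mult_config for the dense_resource call is only outlined:

// Hypothetical layer config, shaped like what the hls4ml Quartus backend emits:
struct config2 {
    static const unsigned in_width = 32;    // input length
    static const unsigned n_chan = 3;       // input channels
    static const unsigned filt_width = 3;   // kernel width
    static const unsigned kernel_size = 3;  // == filt_width in the 1D case
    static const unsigned n_filt = 8;       // output channels
    static const unsigned pad_left = 1;
    static const unsigned pad_right = 1;
    static const unsigned stride_width = 1;
    typedef ac_fixed<16, 6, true> weight_t;
    typedef ac_fixed<16, 6, true> bias_t;
    struct mult_config { /* dense_resource config for the filt_width * n_chan GEMV */ };
};

// With stride_width = 1 and filt_width = 3 (lShiftX = 2), compute_output_buffer_1d
// first fires when pX reaches 2, i.e. once the first full window [x0 x1 x2] has
// been shifted in; every later pixel then yields one output, since sX stays at 2.
// The generated top level would call it roughly as:
//   nnet::conv_1d_cl<input_t, result_t, config2>(input_stream, output_stream, w2, b2);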
