
Commit 90d760a

Merge pull request #600 from vloncar/instruct_cnn
Unrolled CNN implementation
2 parents ee891c3 + cd915eb commit 90d760a


19 files changed: +658 -834 lines changed


hls4ml/backends/fpga/fpga_backend.py

Lines changed: 238 additions & 0 deletions
@@ -181,6 +181,27 @@ def set_target_reuse_factor(self, layer):

        layer.set_attr('reuse_factor', float(rf) / kernel_multiplies)

    def get_valid_conv_partition_splits(self, out_height, out_width):
        """Generate valid partition splits of a Conv1D/2D layer.

        Essentially a list of divisors of the number of pixels of the output image.

        Args:
            out_height (int): The height of the output image
            out_width (int): The width of the output image

        Returns:
            list: List of valid partition splits
        """
        n_pixels = out_height * out_width
        valid_n_partitions = []
        for i in range(1, int(n_pixels / 2) + 1):
            if n_pixels % i == 0:
                valid_n_partitions.append(i)
        valid_n_partitions.append(n_pixels)

        return valid_n_partitions

    @classmethod
    def convert_precision_string(cls, precision):
        if isinstance(precision, IntegerPrecisionType) or isinstance(precision, FixedPrecisionType):
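
# Illustration (not part of the diff): for a hypothetical 4x4 output image, the method above
# returns every divisor of n_pixels = 16, i.e. the candidate values for a layer's
# 'n_partitions' attribute. A minimal standalone re-derivation:
n_pixels = 4 * 4
splits = [i for i in range(1, n_pixels // 2 + 1) if n_pixels % i == 0] + [n_pixels]
print(splits)  # [1, 2, 4, 8, 16]
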
@@ -384,6 +405,223 @@ def compute_conv2d_instructions(self, in_H, in_W, in_C, kernel_size=3, stride=1,

        return (min_H, min_W, windows_int)

    def _compute_conv1d_im2col(self, input_shape, kernel=3, stride=1, pad=(0, 0), dilation=1):
        W, C = input_shape
        pad_l, pad_r = pad

        out_w = (W + pad_l + pad_r - (dilation * (kernel - 1) + 1)) // stride + 1

        input_img = np.arange(1, W * C + 1)
        im_matrix = np.zeros((kernel * C * out_w, ))

        index = 0
        for i_ow in range(out_w):
            for i_kw in range(kernel):
                for i_c in range(C):
                    input_col = -pad_l + i_kw * dilation + i_ow * stride
                    if (input_col >= 0 and input_col < W):
                        im_matrix[index] = input_img[input_col * C + i_c]
                    else:
                        im_matrix[index] = 0
                    index += 1

        im_matrix = im_matrix.reshape(out_w, -1)
        return im_matrix

    def generate_conv1d_line_buffer_fn(self, layer_idx, n_partitions, in_W, in_C, kernel=3, stride=1, pad=0, dilation=1):
        """Generate a C++ function that mimics the im2col algorithm. This function works for 1D convolution.

        The HLS compiler produces suboptimal designs for an im2col algorithm implementation, so the trick we use
        is to generate the result of the im2col transformation explicitly, instead of relying on loops. Since
        the result depends on the parameters of the convolution layer (the input size, the kernel size, stride, etc.),
        we need to do this for every convolution layer.

        Args:
            layer_idx (int): Index of layer ('index' attribute).
            n_partitions (int): Number of partitions to divide the input into. The pixels in each partition will be processed in parallel.
            in_W (int): Width of input.
            in_C (int): Number of channels.
            kernel (int, optional): Size of the kernel. Defaults to 3.
            stride (int, optional): Stride length. Defaults to 1.
            pad (int or Iterable, optional): Padding to apply. Specified as either a number or a list [left_pad, right_pad]. Defaults to 0.
            dilation (int, optional): Dilation rate. Defaults to 1.

        Returns:
            str: Generated C++ function
        """
        if isinstance(pad, Iterable):
            pad_left = pad[0]
            pad_right = pad[1]
        else:
            pad_left = pad
            pad_right = pad

        im2col_matrix = self._compute_conv1d_im2col(
            (in_W, in_C),
            kernel,
            stride,
            (pad_left, pad_right),
            dilation
        )

        generated_code = (
            "template<class data_T, typename CONFIG_T>\n"
            "class fill_buffer_{index} : public FillConv1DBuffer<data_T, CONFIG_T> {{\n"
            "    public:\n"
            "    static void fill_buffer(\n"
            "        data_T data[CONFIG_T::in_width * CONFIG_T::n_chan],\n"
            "        data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_width * CONFIG_T::n_chan],\n"
            "        const unsigned partition\n"
            "    ) {{\n"
        ).format(index=layer_idx)
        indent = '    '

        for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)):
            generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx)
            for pixel_idx, arr in enumerate(partition):
                buffer_stmts = []
                for j, v in enumerate(arr):
                    if v == 0:
                        val = '0'
                    else:
                        val = 'data[{}]'.format(int(v - 1))
                    buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val))
                generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n'
            generated_code += '\n' + indent * 2 + '}\n'

        generated_code += indent + '}\n'
        generated_code += '};\n'

        return generated_code

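# Illustration (not part of the diff): a quick way to inspect the unrolled buffer-filling code
# is to call the generator directly on a backend instance. The get_backend() lookup, the layer
# index and the sizes below are assumptions made purely for this sketch.
from hls4ml.backends import get_backend

backend = get_backend('Vivado')
code = backend.generate_conv1d_line_buffer_fn(layer_idx=1, n_partitions=2, in_W=8, in_C=3, kernel=3, stride=1, pad=1)
print(code)  # a fill_buffer_1 class with one 'if (partition == k)' branch per partition
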
    def _compute_conv2d_im2col(self, input_shape, kernel=(3, 3), stride=(1, 1), pad=(0, 0, 0, 0), dilation=(1, 1)):
        H, W, C = input_shape
        kernel_h, kernel_w = kernel
        stride_h, stride_w = stride
        pad_t, pad_b, pad_l, pad_r = pad
        dilation_h, dilation_w = dilation

        out_h = (H + pad_t + pad_b - (dilation_h * (kernel_h - 1) + 1)) // stride_h + 1
        out_w = (W + pad_l + pad_r - (dilation_w * (kernel_w - 1) + 1)) // stride_w + 1

        input_img = np.arange(1, H * W * C + 1)
        im_matrix = np.zeros((kernel_h * kernel_w * C * out_h * out_w, ))

        index = 0
        for i_oh in range(out_h):
            for i_ow in range(out_w):
                for i_kh in range(kernel_h):
                    input_row = -pad_t + i_kh * dilation_h + i_oh * stride_h
                    for i_kw in range(kernel_w):
                        for i_c in range(C):
                            if (input_row < 0 or input_row >= H):
                                im_matrix[index] = 0
                            else:
                                input_col = -pad_l + i_kw * dilation_w + i_ow * stride_w
                                if (input_col >= 0 and input_col < W):
                                    im_matrix[index] = input_img[input_row * W * C + input_col * C + i_c]
                                else:
                                    im_matrix[index] = 0
                            index += 1

        im_matrix = im_matrix.reshape(out_h * out_w, -1)
        return im_matrix

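# Illustration (not part of the diff): on a toy 3x3 single-channel input with a 3x3 kernel,
# stride 1 and padding 1 on every side, the helper above returns one row per output pixel;
# entries are 1-based input indices and 0 marks padded positions. Calling the private method
# through get_backend() is done here only to show its output, and is an assumption of this sketch.
from hls4ml.backends import get_backend

im = get_backend('Vivado')._compute_conv2d_im2col((3, 3, 1), (3, 3), (1, 1), (1, 1, 1, 1), (1, 1))
print(im.shape)  # (9, 9): 9 output pixels, 9 kernel taps each
print(im[0])     # [0. 0. 0. 0. 1. 2. 0. 4. 5.] -> the top-left pixel only sees the image corner
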
    def generate_conv2d_line_buffer_fn(self, layer_idx, n_partitions, in_H, in_W, in_C, kernel=(3, 3), stride=(1, 1), pad=(0, 0, 0, 0), dilation=(1, 1)):
        """Generate a C++ function that mimics the im2col algorithm. This function works for 2D convolution.

        The HLS compiler produces suboptimal designs for an im2col algorithm implementation, so the trick we use
        is to generate the result of the im2col transformation explicitly, instead of relying on loops. Since
        the result depends on the parameters of the convolution layer (the input size, the kernel size, stride, etc.),
        we need to do this for every convolution layer.

        Args:
            layer_idx (int): Index of layer ('index' attribute).
            n_partitions (int): Number of partitions to divide the input into. The pixels in each partition will be processed in parallel.
            in_H (int): Height of input.
            in_W (int): Width of input.
            in_C (int): Number of channels.
            kernel (int or Iterable, optional): Size of the kernel. Defaults to (3, 3).
            stride (int or Iterable, optional): Stride length. Defaults to (1, 1).
            pad (int or Iterable, optional): Padding to apply. Specified as either a number or a list [top_pad, bottom_pad, left_pad, right_pad]. Defaults to (0, 0, 0, 0).
            dilation (int or Iterable, optional): Dilation rate. Defaults to (1, 1).

        Returns:
            str: Generated C++ function
        """

        if isinstance(kernel, Iterable):
            kernel_height = kernel[0]
            kernel_width = kernel[1]
        else:
            kernel_height = kernel
            kernel_width = kernel

        if isinstance(stride, Iterable):
            stride_height = stride[0]
            stride_width = stride[1]
        else:
            stride_height = stride
            stride_width = stride

        if isinstance(pad, Iterable):
            pad_top = pad[0]
            pad_bottom = pad[1]
            pad_left = pad[2]
            pad_right = pad[3]
        else:
            pad_top = pad
            pad_bottom = pad
            pad_left = pad
            pad_right = pad

        if isinstance(dilation, Iterable):
            dilation_height = dilation[0]
            dilation_width = dilation[1]
        else:
            dilation_height = dilation
            dilation_width = dilation

        im2col_matrix = self._compute_conv2d_im2col(
            (in_H, in_W, in_C),
            (kernel_height, kernel_width),
            (stride_height, stride_width),
            (pad_top, pad_bottom, pad_left, pad_right),
            (dilation_height, dilation_width)
        )

        generated_code = (
            "template<class data_T, typename CONFIG_T>\n"
            "class fill_buffer_{index} : public FillConv2DBuffer<data_T, CONFIG_T> {{\n"
            "    public:\n"
            "    static void fill_buffer(\n"
            "        data_T data[CONFIG_T::in_height * CONFIG_T::in_width * CONFIG_T::n_chan],\n"
            "        data_T buffer[CONFIG_T::n_pixels][CONFIG_T::filt_height * CONFIG_T::filt_width * CONFIG_T::n_chan],\n"
            "        const unsigned partition\n"
            "    ) {{\n"
        ).format(index=layer_idx)
        indent = '    '

        for partition_idx, partition in enumerate(np.split(im2col_matrix, n_partitions)):
            generated_code += indent * 2 + 'if (partition == {:>3}) {{\n'.format(partition_idx)
            for pixel_idx, arr in enumerate(partition):
                buffer_stmts = []
                for j, v in enumerate(arr):
                    if v == 0:
                        val = '0'
                    else:
                        val = 'data[{}]'.format(int(v - 1))
                    buffer_stmts.append('buffer[{}][{}] = {:>10};'.format(pixel_idx, j, val))
                generated_code += indent * 3 + ' '.join(buffer_stmts) + '\n'
            generated_code += '\n' + indent * 2 + '}\n'

        generated_code += indent + '}\n'
        generated_code += '};\n'

        return generated_code

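# Illustration (not part of the diff): np.split above requires n_partitions to divide the
# number of output pixels evenly, which is exactly the constraint that
# get_valid_conv_partition_splits() enumerates. The sizes below are invented.
out_h, out_w, n_partitions = 4, 4, 8
assert (out_h * out_w) % n_partitions == 0  # otherwise np.split raises ValueError
n_pixels = out_h * out_w // n_partitions    # pixels handled per fill_buffer call, cf. CONFIG_T::n_pixels
print(n_pixels)  # 2
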
    @model_optimizer()
    def write_hls(self, model):
        self.writer.write_hls(model)
Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
from hls4ml.model.optimizer import OptimizerPass
from hls4ml.model.layers import Conv1D, Conv2D
from hls4ml.model.types import Source

class GenerateConvIm2col(OptimizerPass):
    ''' Generates code for the im2col step of 1D/2D convolution '''
    def match(self, node):
        return isinstance(node, (Conv1D, Conv2D)) and \
            node.model.config.get_config_value('IOType') == 'io_parallel'

    def transform(self, model, node):
        node_class = node.__class__.__name__
        if '1D' in node_class:
            self._generate_im2col_1d(node)
        elif '2D' in node_class:
            self._generate_im2col_2d(node)
        else:
            raise Exception('Cannot generate instructions for node {} ({})'.format(node.name, node_class))

    def _generate_im2col_1d(self, node):
        code_str = node.model.config.backend.generate_conv1d_line_buffer_fn(
            node.get_attr('index'),
            node.get_attr('n_partitions'),
            node.get_input_variable().shape[0],
            node.get_input_variable().shape[1],
            kernel=node.get_attr('filt_width'),
            stride=node.get_attr('stride_width'),
            pad=(node.get_attr('pad_left'), node.get_attr('pad_right'))
        )

        node.set_attr('line_buffer_codegen', Source(code_str))

    def _generate_im2col_2d(self, node):
        code_str = node.model.config.backend.generate_conv2d_line_buffer_fn(
            node.get_attr('index'),
            node.get_attr('n_partitions'),
            node.get_input_variable().shape[0],
            node.get_input_variable().shape[1],
            node.get_input_variable().shape[2],
            kernel=(node.get_attr('filt_height'), node.get_attr('filt_width')),
            stride=(node.get_attr('stride_height'), node.get_attr('stride_width')),
            pad=(node.get_attr('pad_top'), node.get_attr('pad_bottom'), node.get_attr('pad_left'), node.get_attr('pad_right'))
        )

        node.set_attr('line_buffer_codegen', Source(code_str))
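
# Illustration (not part of the diff): the pass above stores the generated C++ on the node as a
# Source object, which node.get_attr('line_buffer_codegen') retrieves later in the flow. A minimal
# standalone sketch of that wrapping, assuming Source simply wraps the string it is given and
# stringifies back to it:
from hls4ml.model.types import Source

src = Source('// generated fill_buffer_<index> class body goes here')
print(str(src))  # assumed to print the wrapped code string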

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 4 additions & 0 deletions
@@ -255,6 +255,8 @@ def init_conv1d(self, layer):
        # - Winograd - use Winograd, if possible
        layer.set_attr('implementation', layer.model.config.get_layer_config_value(layer, 'Implementation', 'combination'))

        layer.set_attr('n_partitions', 1)  # TODO Not used yet as there is no codegen implementation of CNNs for Quartus backend

    @layer_optimizer(Conv2D)
    def init_conv2d(self, layer):
        # This can happen if we assign weights of Dense layer to 1x1 Conv2D
@@ -281,6 +283,8 @@ def init_conv2d(self, layer):
        # - im2col - specifically use im2col
        # - Winograd - use Winograd, if possible
        layer.set_attr('implementation', layer.model.config.get_layer_config_value(layer, 'Implementation', 'combination'))

        layer.set_attr('n_partitions', 1)  # TODO Not used yet as there is no codegen implementation of CNNs for Quartus backend

    @layer_optimizer(LSTM)
    def init_lstm(self, layer):

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 20 additions & 0 deletions
@@ -37,6 +37,10 @@
    static const nnet::conv_implementation implementation = nnet::conv_implementation::{implementation};
    static const unsigned min_width = {min_width};
    static const ap_uint<filt_width> pixels[min_width];
    static const unsigned n_partitions = {n_partitions};
    static const unsigned n_pixels = out_width / n_partitions;
    template<class data_T, class CONFIG_T>
    using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
@@ -60,6 +64,10 @@ def format(self, node):
        params['nzeros'] = node.get_weights('weight').nzeros

        params['config_t'] = 'config{}_mult'.format(node.index)
        if node.model.config.get_config_value('IOType') == 'io_parallel':
            params['fill_fn'] = 'fill_buffer_{}'.format(node.index)
        else:
            params['fill_fn'] = 'FillConv1DBuffer'
        conv_config = self.template.format(**params)

        mult_params = self._default_config_params(node)
@@ -109,6 +117,10 @@ def format(self, node):
    static const unsigned min_height = {min_height};
    static const unsigned min_width = {min_width};
    static const ap_uint<filt_height * filt_width> pixels[min_height * min_width];
    static const unsigned n_partitions = {n_partitions};
    static const unsigned n_pixels = out_height * out_width / n_partitions;
    template<class data_T, class CONFIG_T>
    using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
@@ -133,6 +145,10 @@ def format(self, node):
        params['nzeros'] = node.get_weights('weight').nzeros

        params['config_t'] = 'config{}_mult'.format(node.index)
        if node.model.config.get_config_value('IOType') == 'io_parallel':
            params['fill_fn'] = 'fill_buffer_{}'.format(node.index)
        else:
            params['fill_fn'] = 'FillConv2DBuffer'
        conv_config = self.template.format(**params)

        mult_params = self._default_config_params(node)
@@ -198,6 +214,7 @@ def format(self, node):
        params['nzeros'] = node.get_weights('depthwise').nzeros
        params['index'] = str(node.index) + '_depthwise'
        params['weight_t'] = node.get_weights('depthwise').type
        params['fill_fn'] = 'FillConv1DBuffer'

        params['config_t'] = 'config{}_depthwise_mult'.format(node.index)
        depthwise_config = self.depthwise_template.format(**params)
@@ -229,6 +246,7 @@ def format(self, node):
        params['weight_t'] = node.get_weights('pointwise').type
        params['min_width'] = params['in_width']
        params['instructions'] = '0'
        params['fill_fn'] = 'FillConv1DBuffer'

        params['config_t'] = 'config{}_pointwise_mult'.format(node.index)
        pointwise_config = self.pointwise_template.format(**params)
@@ -283,6 +301,7 @@ def format(self, node):
        params['nzeros'] = node.get_weights('depthwise').nzeros
        params['index'] = str(node.index) + '_depthwise'
        params['weight_t'] = node.get_weights('depthwise').type
        params['fill_fn'] = 'FillConv2DBuffer'

        params['config_t'] = 'config{}_depthwise_mult'.format(node.index)
        depthwise_config = self.depthwise_template.format(**params)
@@ -314,6 +333,7 @@ def format(self, node):
        params['min_height'] = params['in_height']
        params['min_width'] = params['in_width']
        params['instructions'] = '0'
        params['fill_fn'] = 'FillConv2DBuffer'

        params['config_t'] = 'config{}_pointwise_mult'.format(node.index)
        pointwise_config = self.pointwise_template.format(**params)
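
# Illustration (not part of the diff): a toy substitution of just the added template lines,
# showing how the new placeholders resolve. The values below are invented; in the real flow
# they are filled in by the format() methods above.
fragment = (
    '    static const unsigned n_partitions = {n_partitions};\n'
    '    static const unsigned n_pixels = out_width / n_partitions;\n'
    '    template<class data_T, class CONFIG_T>\n'
    '    using fill_buffer = nnet::{fill_fn}<data_T, CONFIG_T>;\n'
)
print(fragment.format(n_partitions=4, fill_fn='fill_buffer_2'))     # io_parallel: generated class
print(fragment.format(n_partitions=1, fill_fn='FillConv1DBuffer'))  # io_stream: generic default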
