Skip to content

Commit d6957bd

Browse files
committed
add general transpose for vivado/vitis
1 parent 2c17f66 commit d6957bd

File tree

6 files changed

+158
-92
lines changed

6 files changed

+158
-92
lines changed

hls4ml/backends/vivado/passes/reshaping_templates.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
from math import prod
2+
3+
import numpy as np
4+
15
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
26
from hls4ml.model.layers import Resize, Transpose, ZeroPadding1D, ZeroPadding2D
37

@@ -97,16 +101,45 @@ def format(self, node):
97101

98102
# Transpose templates
99103

100-
transpose_config_template = """struct config{index} : nnet::transpose_config {{
101-
static const unsigned depth = {depth};
102-
static const unsigned height = {height};
103-
static const unsigned width = {width};
104-
static constexpr unsigned perm[3] = {{{perm_str}}};
105-
}};\n"""
106104

107-
transpose_function_template = 'nnet::transpose_{dim}<{input_t}, {output_t}, {config}>({input}, {output});'
105+
transpose_include_list = ['nnet_utils/nnet_transpose.h', 'nnet_utils/nnet_transpose_stream.h']
106+
107+
transpose_config_template = """struct {config_name} {{
108+
static const unsigned dims = {dims};
109+
static const unsigned N = {N};
110+
static const unsigned* const from_shape;
111+
static const unsigned* const to_shape;
112+
static const unsigned* const perm;
113+
static const unsigned* const perm_strides;
114+
}};
115+
116+
unsigned {config_name}_from_shape[{dims}] = {{{from_shape}}};
117+
unsigned {config_name}_to_shape[{dims}] = {{{to_shape}}};
118+
unsigned {config_name}_perm[{dims}] = {{{perm}}};
119+
unsigned {config_name}_perm_strides[{dims}] = {{{perm_strides}}};
120+
121+
const unsigned* const {config_name}::from_shape = {config_name}_from_shape;
122+
const unsigned* const {config_name}::to_shape = {config_name}_to_shape;
123+
const unsigned* const {config_name}::perm = {config_name}_perm;
124+
const unsigned* const {config_name}::perm_strides = {config_name}_perm_strides;
125+
"""
126+
127+
transpose_function_template = 'nnet::transpose<{input_t}, {output_t}, {config_name}>({input}, {output});'
108128

109-
transpose_include_list = ['nnet_utils/nnet_array.h', 'nnet_utils/nnet_stream.h']
129+
130+
def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]):
131+
new_shape = tuple(shape[i] for i in perm)
132+
strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1]
133+
perm_strides = tuple(int(strides[i]) for i in perm)
134+
return transpose_config_template.format(
135+
dims=len(shape),
136+
N=prod(shape),
137+
from_shape=', '.join(str(x) for x in shape),
138+
perm=', '.join(str(x) for x in perm),
139+
perm_strides=', '.join(str(x) for x in perm_strides),
140+
to_shape=', '.join(str(x) for x in new_shape),
141+
config_name=name,
142+
)
110143

111144

112145
class TransposeConfigTemplate(LayerConfigTemplate):
@@ -115,18 +148,18 @@ def __init__(self):
115148
self.template = transpose_config_template
116149

117150
def format(self, node):
118-
params = self._default_config_params(node)
119-
120-
return self.template.format(**params)
151+
shape = tuple(node.get_input_variable().shape)
152+
perm = tuple(node.get_attr('perm'))
153+
name = f'config{node.index}'
154+
return permute_config_gen(name, shape, perm)
121155

122156

123157
class TransposeFunctionTemplate(FunctionCallTemplate):
124158
def __init__(self):
125-
super().__init__(Transpose, include_header=transpose_include_list)
126159
self.template = transpose_function_template
160+
super().__init__(Transpose, include_header=transpose_include_list)
127161

128162
def format(self, node):
129163
params = self._default_function_params(node)
130-
params['dim'] = node.get_attr('dim')
131-
164+
params['config_name'] = f'config{node.index}'
132165
return self.template.format(**params)

hls4ml/model/layers.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1167,8 +1167,7 @@ def initialize(self):
11671167
perm = self.get_attr('perm')
11681168
self.set_attr('dim', f'{len(inp.shape)}d')
11691169

1170-
if len(perm) > 3:
1171-
raise Exception('ERROR: Transpose of tensors with rank > 3 is not yet supported.')
1170+
# TODO: dim>3 is only supported for vivado/vitis backend
11721171

11731172
# ONNX double transpose specific, sometimes ONNX injects
11741173
# useless double transpose layers when converting
@@ -1188,11 +1187,14 @@ def initialize(self):
11881187
self.set_attr('depth', 1)
11891188
self.set_attr('height', inp.shape[0])
11901189
self.set_attr('width', inp.shape[1])
1191-
elif len(shape) > 2:
1190+
elif len(shape) == 3:
11921191
dims = [f'OUT_DEPTH_{self.index}', f'OUT_HEIGHT_{self.index}', f'OUT_WIDTH_{self.index}']
11931192
self.set_attr('depth', inp.shape[0])
11941193
self.set_attr('height', inp.shape[1])
11951194
self.set_attr('width', inp.shape[2])
1195+
elif len(shape) > 3:
1196+
# Differentiate between 2/3/3+ dim does not really appear to be needed. To be removed?
1197+
dims = [f'OUT_DIM_{i}_{self.index}' for i in range(1, len(shape) + 1)]
11961198
self.add_output_variable(shape, dims, precision=inp.type.precision)
11971199

11981200

hls4ml/templates/vivado/nnet_utils/nnet_array.h

Lines changed: 0 additions & 52 deletions
This file was deleted.

hls4ml/templates/vivado/nnet_utils/nnet_stream.h

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -179,29 +179,6 @@ void broadcast_stream(hls::stream<data_T> &data, hls::stream<res_T> &res) {
179179
}
180180
}
181181

182-
template <class data_T, class res_T, typename CONFIG_T>
183-
void transpose_2d(hls::stream<data_T> &data, hls::stream<res_T> &res) {
184-
typename data_T::value_type data_array[CONFIG_T::height * CONFIG_T::width];
185-
#pragma HLS ARRAY_PARTITION variable=data_array complete
186-
187-
for (int i = 0; i < CONFIG_T::height * CONFIG_T::width / data_T::size; i++) {
188-
#pragma HLS PIPELINE
189-
data_T in_data = data.read();
190-
for (int j = 0; j < data_T::size; j++) {
191-
data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]);
192-
}
193-
}
194-
195-
for (int i = 0; i < CONFIG_T::height * CONFIG_T::width / res_T::size; i++) {
196-
#pragma HLS PIPELINE
197-
res_T out_data;
198-
PRAGMA_DATA_PACK(out_data)
199-
for (int j = 0; j < res_T::size; j++) {
200-
out_data[j] = typename res_T::value_type(data_array[j * data_T::size + i]);
201-
}
202-
res.write(out_data);
203-
}
204-
}
205182
} // namespace nnet
206183

207184
#endif
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
#ifndef NNET_PERMUTE_H_
2+
#define NNET_PERMUTE_H_
3+
4+
namespace nnet {
5+
6+
struct transpose_config {
7+
static const unsigned dims;
8+
static const unsigned N;
9+
// vivado/vitis hls can't index constexpr array for some reason
10+
// and vivado hls don't like template recursion either (vitis is fine)
11+
// thus this appears to be the only workaround (or overkill it with codegen)
12+
static const unsigned *const from_shape;
13+
static const unsigned *const to_shape;
14+
static const unsigned *const perm;
15+
static const unsigned *const perm_strides;
16+
};
17+
18+
template <typename CONFIG_T> unsigned transfer_idx(int index) {
19+
// Given output idx in c-order flat array, return input idx
20+
int idx = 0;
21+
for (int i = CONFIG_T::dims - 1; i >= 0; i--) {
22+
idx += (index % CONFIG_T::to_shape[i]) * CONFIG_T::perm_strides[i];
23+
index /= CONFIG_T::to_shape[i];
24+
}
25+
return idx;
26+
}
27+
28+
template <typename data_T, typename res_T, typename CONFIG_T>
29+
void transpose(const data_T data[CONFIG_T::N], res_T res[CONFIG_T::N]) {
30+
for (int i = 0; i < CONFIG_T::N; i++) {
31+
#pragma HLS UNROLL
32+
int idx = transfer_idx<CONFIG_T>(i);
33+
res[i] = data[idx];
34+
}
35+
}
36+
37+
} // namespace nnet
38+
39+
#endif
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#ifndef NNET_TRANSPOSE_STREAM_H
2+
#define NNET_TRANSPOSE_STREAM_H
3+
4+
#include "hls_stream.h"
5+
#include "nnet_transpose.h"
6+
#include <type_traits>
7+
8+
namespace nnet {
9+
10+
template <typename data_T, typename res_T, typename CONFIG_T>
11+
typename std::enable_if<CONFIG_T::dims == 2, void>::type transpose(hls::stream<data_T> &data, hls::stream<res_T> &res) {
12+
// #pragma HLS INLINE RECURSIVE
13+
typename data_T::value_type data_array[CONFIG_T::N];
14+
#pragma HLS ARRAY_PARTITION variable=data_array complete
15+
16+
for (int i = 0; i < CONFIG_T::N / data_T::size; i++) {
17+
#pragma HLS PIPELINE
18+
data_T in_data = data.read();
19+
for (int j = 0; j < data_T::size; j++) {
20+
#pragma HLS UNROLL
21+
data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]);
22+
}
23+
}
24+
25+
for (int i = 0; i < CONFIG_T::N / res_T::size; i++) {
26+
#pragma HLS PIPELINE
27+
res_T out_data;
28+
PRAGMA_DATA_PACK(out_data)
29+
for (int j = 0; j < res_T::size; j++) {
30+
#pragma HLS UNROLL
31+
out_data[j] = typename res_T::value_type(data_array[j * CONFIG_T::from_shape[1] + i]);
32+
}
33+
res.write(out_data);
34+
}
35+
}
36+
37+
// This sfinae is for vivado_hls, which has some overhead using the transfer_idx in io_stream.
38+
// In vitis both performs exactly the same, thus this is not removed out of convenience.
39+
template <typename data_T, typename res_T, typename CONFIG_T>
40+
typename std::enable_if<CONFIG_T::dims != 2, void>::type transpose(hls::stream<data_T> &data, hls::stream<res_T> &res) {
41+
// #pragma HLS INLINE RECURSIVE
42+
typename data_T::value_type data_array[CONFIG_T::N];
43+
#pragma HLS ARRAY_PARTITION variable=data_array complete
44+
45+
for (int i = 0; i < CONFIG_T::N / data_T::size; i++) {
46+
#pragma HLS PIPELINE
47+
data_T in_data = data.read();
48+
for (int j = 0; j < data_T::size; j++) {
49+
#pragma HLS UNROLL
50+
data_array[i * data_T::size + j] = typename data_T::value_type(in_data[j]);
51+
}
52+
}
53+
54+
for (int i = 0; i < CONFIG_T::N / res_T::size; i++) {
55+
#pragma HLS PIPELINE
56+
res_T out_data;
57+
PRAGMA_DATA_PACK(out_data)
58+
for (int j = 0; j < res_T::size; j++) {
59+
#pragma HLS UNROLL
60+
out_data[j] = typename res_T::value_type(data_array[transfer_idx<CONFIG_T>(i * res_T::size + j)]);
61+
}
62+
res.write(out_data);
63+
}
64+
}
65+
66+
} // namespace nnet
67+
#endif

0 commit comments

Comments
 (0)