
Commit 158fa4f

general einsum support for io_parallel and latency
1 parent 0bc8f03 commit 158fa4f

File tree

7 files changed (+579, -3 lines)
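For context, a minimal usage sketch of what this commit enables: a Keras `EinsumDense` layer converted with `io_parallel` and the (default) Latency strategy. This is not taken from the commit; the model, shapes and equation are made up, and the calls follow the usual hls4ml conversion flow, which may differ slightly between versions.

import hls4ml
import keras

# Toy model with one EinsumDense layer: contracts the last axis (d=16 -> e=4)
# independently for each of the t=8 positions.
model = keras.Sequential([
    keras.Input(shape=(8, 16)),
    keras.layers.EinsumDense('btd,de->bte', output_shape=(8, 4), bias_axes='e'),
])

config = hls4ml.utils.config_from_keras_model(model, granularity='name')
hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, io_type='io_parallel', backend='Vivado'
)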
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
from hls4ml.backends.backend import get_backend
from hls4ml.backends.template import FunctionCallTemplate, LayerConfigTemplate
from hls4ml.model.layers import EinsumDense

from .reshaping_templates import transpose_config_gen

# Shared Dense template

conv_dense_config_template = """struct config{index}_dense : nnet::dense_config {{
    static const unsigned n_in = {n_in};
    static const unsigned n_out = {n_out};
    static const unsigned reuse_factor = {reuse};
    static const unsigned strategy = nnet::{strategy};
    static const unsigned n_zeros = {nzeros};
    static const unsigned multiplier_limit = DIV_ROUNDUP(n_in * n_out, reuse_factor) - n_zeros / reuse_factor;
    typedef {accum_t.name} accum_t;
    typedef {bias_t.name} bias_t;
    typedef {weight_t.name} weight_t;
    template<class data_T, class res_T, class CONFIG_T>
    using kernel = nnet::{dense_function}<data_T, res_T, CONFIG_T>;
    template<class x_T, class y_T>
    using product = nnet::product::{product_type}<x_T, y_T>;
}};\n"""

# EinsumDense template

einsum_dense_config_template = '''
struct config{index} {{
    typedef config{index}_tpose_inp tpose_inp_conf;
    typedef config{index}_tpose_out tpose_out_conf;
    typedef config{index}_dense dense_conf;

    // Layer Sizes
    static const unsigned n_free_data = {n_free_data};
    static const unsigned n_free_kernel = {n_free_kernel};
    static const unsigned n_contract = {n_contract};
    static const unsigned n_inplace = {n_inplace};

    // Resource reuse info
    static const unsigned io_type = nnet::{iotype};
    static const unsigned strategy = nnet::{strategy};
    static const unsigned reuse_factor = {reuse_factor};
    static const unsigned parallelization_factor = {parallelization_factor}; // Only useful when n_inplace > 1
    static const bool store_weights_in_bram = false; // NOT USED
}};
'''

einsum_dense_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'

einsum_dense_include_list = ['nnet_utils/nnet_einsum_dense.h', 'nnet_utils/nnet_dense.h']


class EinsumDenseConfigTemplate(LayerConfigTemplate):
    def __init__(self):
        super().__init__(EinsumDense)
        self.template = einsum_dense_config_template
        self.dense_template = conv_dense_config_template

    def format(self, node: EinsumDense):
        default_params = self._default_config_params(node)

        strategy = node.model.config.get_strategy(node)
        io_type = node.model.config.get_config_value('IOType')

        assert io_type == 'io_parallel', 'EinsumDense layer only supports io_parallel for now'
        assert strategy.lower() == 'latency', 'EinsumDense layer only supports Latency strategy for now'

        # EinsumDense config
        params = default_params.copy()
        params['strategy'] = strategy
        params['n_free_data'] = node.attributes.attributes['n_free_data']
        params['n_free_kernel'] = node.attributes.attributes['n_free_kernel']
        params['n_contract'] = node.attributes.attributes['n_contract']
        params['n_inplace'] = node.attributes.attributes['n_inplace']
        params['parallelization_factor'] = node.attributes.attributes['parallelization_factor']

        einsum_conf = self.template.format(**params)

        # inp/out transpose config
        inp_shape = node.attributes.attributes['inp_shape']
        out_interpert_shape = node.attributes.attributes['out_interpert_shape']
        inp_tpose_idxs = node.attributes.attributes['inp_tpose_idxs']
        out_tpose_idxs = node.attributes.attributes['out_tpose_idxs']
        tpose_inp_conf_name = f'config{node.index}_tpose_inp'
        tpose_out_conf_name = f'config{node.index}_tpose_out'

        inp_tpose_conf = transpose_config_gen(tpose_inp_conf_name, inp_shape, inp_tpose_idxs)
        out_tpose_conf = transpose_config_gen(tpose_out_conf_name, out_interpert_shape, out_tpose_idxs)

        # Dense config
        dense_params = default_params.copy()
        dense_params['strategy'] = strategy
        dense_params['n_in'] = node.attributes.attributes['n_contract']
        dense_params['n_out'] = node.attributes.attributes['n_free_kernel']
        if node.attributes.attributes['n_inplace'] == 1:
            dense_params['nzeros'] = node.get_weights('weight').nzeros  # type: ignore
        else:
            dense_params['nzeros'] = '-1; // Not making sense when kernels are switching'
        dense_params['product_type'] = get_backend('vivado').product_type(
            node.get_input_variable().type.precision, node.get_weights('weight').type.precision  # type: ignore
        )

        dense_params['dense_function'] = 'DenseLatency'  # Latency only for now

        dense_config = self.dense_template.format(**dense_params)

        return '\n\n'.join((inp_tpose_conf, out_tpose_conf, dense_config, einsum_conf))


class EinsumDenseFunctionTemplate(FunctionCallTemplate):
    def __init__(self):
        super().__init__(EinsumDense, include_header=einsum_dense_include_list)
        self.template = einsum_dense_function_template

    def format(self, node):
        params = self._default_function_params(node)
        params['w'] = node.get_weights('weight').name
        params['b'] = node.get_weights('bias').name

        return self.template.format(**params)
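As a hypothetical illustration of what `EinsumDenseFunctionTemplate` emits, formatting `einsum_dense_function_template` with example parameter values (the variable and config names below are made up, not taken from the commit) yields the HLS function call placed in the generated firmware:

einsum_dense_function_template = 'nnet::einsum_dense<{input_t}, {output_t}, {config}>({input}, {output}, {w}, {b});'

params = {
    'input_t': 'input_t', 'output_t': 'result_t', 'config': 'config4',
    'input': 'layer3_out', 'output': 'layer4_out', 'w': 'w4', 'b': 'b4',
}
print(einsum_dense_function_template.format(**params))
# nnet::einsum_dense<input_t, result_t, config4>(layer3_out, layer4_out, w4, b4);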

hls4ml/backends/vivado/passes/reshaping_templates.py

Lines changed: 2 additions & 2 deletions
@@ -127,7 +127,7 @@ def format(self, node):
 transpose_function_template = 'nnet::transpose<{input_t}, {output_t}, {config_name}>({input}, {output});'
 
 
-def permute_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]):
+def transpose_config_gen(name: str, shape: tuple[int, ...], perm: tuple[int, ...]):
     new_shape = tuple(shape[i] for i in perm)
     strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1]
     perm_strides = tuple(int(strides[i]) for i in perm)
@@ -151,7 +151,7 @@ def format(self, node):
         shape = tuple(node.get_input_variable().shape)
         perm = tuple(node.get_attr('perm'))
         name = f'config{node.index}'
-        return permute_config_gen(name, shape, perm)
+        return transpose_config_gen(name, shape, perm)
 
 
 class TransposeFunctionTemplate(FunctionCallTemplate):
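A quick NumPy check (not part of the commit) of what the renamed `transpose_config_gen` computes in the hunk above: `strides` are the row-major strides of the input shape, and `perm_strides` gathers them in output order, so a flat output index maps back to a flat input index via a dot product. The shape and permutation are arbitrary examples:

import numpy as np

shape, perm = (2, 3, 4), (2, 0, 1)
strides = np.cumprod((shape[1:] + (1,))[::-1])[::-1]  # row-major input strides: (12, 4, 1)
perm_strides = tuple(int(strides[i]) for i in perm)   # gathered in output order: (1, 12, 4)

x = np.arange(np.prod(shape)).reshape(shape)
y = x.transpose(perm)
for idx in np.ndindex(y.shape):
    assert y[idx] == x.flat[np.dot(idx, perm_strides)]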
Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 from . import conv  # noqa: F401
 from . import core  # noqa: F401
+from . import einsum_dense  # noqa: F401
 from ._base import registry as layer_handlers
 
 __all__ = ['layer_handlers']
Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
import typing
from typing import Sequence

from ._base import KerasV3LayerHandler, register

if typing.TYPE_CHECKING:
    import keras
    from keras.api import KerasTensor


def strip_batch_dim(equation: str):
    """Remove the batch dimension from the equation.

    Args:
        equation (str): The einsum equation.

    Returns:
        str: The einsum equation without the batch dimension.
    """

    _inps, out = equation.split('->')
    inp0, inp1 = _inps.split(',')
    if inp0.startswith('...'):
        assert out.startswith('...'), f'Error in eq: {equation}: Batch dim mismatch for the input and output.'
    else:
        assert inp0[0] == out[0], f'Error in eq: {equation}: Batch dim mismatch for the input and output.'
        assert inp0[0] not in inp1, f'Error in eq: {equation}: Batch dim is used in the kernel.'
    inp0, out = inp0[1:], out[1:]
    return f'{inp0},{inp1}->{out}'


@register
class KV3EinsumDenseHandler(KerasV3LayerHandler):
    handles = ('keras.src.layers.core.einsum_dense.EinsumDense',)

    def handle(
        self,
        layer: 'keras.layers.EinsumDense',
        in_tensors: Sequence['KerasTensor'],
        out_tensors: Sequence['KerasTensor'],
    ):
        import keras

        assert len(in_tensors) == 1, 'EinsumDense layer must have exactly one input tensor'
        assert len(out_tensors) == 1, 'EinsumDense layer must have exactly one output tensor'

        inp_shape: tuple[int, ...] = in_tensors[0].shape[1:]  # type: ignore
        out_shape: tuple[int, ...] = out_tensors[0].shape[1:]  # type: ignore

        # fmt: off
        assert all(d is not None for d in inp_shape), \
            f'Error when processing {layer.name}: EinsumDense layer requires fully defined input shapes'
        assert all(d is not None for d in out_shape), \
            f'Error when processing {layer.name}: EinsumDense layer requires fully defined output shapes'
        # fmt: on

        equation = strip_batch_dim(layer.equation)

        kernel = keras.ops.convert_to_numpy(layer.kernel)

        bias = None
        if layer.bias_axes:
            bias = keras.ops.convert_to_numpy(layer.bias)

        return {
            'class_name': 'EinsumDense',
            'equation': equation,
            'weight_data': kernel,
            'bias_data': bias,
            'inp_shape': inp_shape,
            'out_shape': out_shape,
        }
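A small self-contained NumPy check (not part of the commit) of the idea behind `strip_batch_dim`: for a named batch axis, dropping it from the equation leaves the per-sample computation unchanged, which is why the handler can hand the stripped equation to the parser. Equation and shapes are illustrative only:

import numpy as np

batched, stripped = 'btd,de->bte', 'td,de->te'  # 'b' is the batch axis
x = np.random.rand(4, 3, 5)                     # (batch, t, d)
w = np.random.rand(5, 2)                        # (d, e)

full = np.einsum(batched, x, w)
per_sample = np.stack([np.einsum(stripped, xi, w) for xi in x])
assert np.allclose(full, per_sample)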

hls4ml/model/layers.py

Lines changed: 65 additions & 1 deletion
@@ -27,10 +27,12 @@
     find_minimum_width,
 )
 from hls4ml.utils import attribute_descriptions as descriptions
+from hls4ml.utils.einsum_utils import parse_einsum
 from hls4ml.utils.string_utils import convert_to_snake_case
 
-
 # TODO move this to some utility module
+
+
 class classproperty:
     def __init__(self, func):
         self.func = func
@@ -1618,6 +1620,67 @@ def initialize(self):
         self.add_output_variable([len(self.get_attr('expression'))], [f'N_OUTPUTS_{self.index}'], var_name='y')
 
 
+class EinsumDense(Layer):
+    _expected_attributes = [
+        WeightAttribute('weight'),
+        WeightAttribute('bias'),
+        TypeAttribute('weight'),
+        TypeAttribute('bias'),
+        TypeAttribute('accum'),
+        Attribute('equation', value_type=str),
+        Attribute('inp_shape', value_type=tuple),
+        Attribute('out_shape', value_type=tuple),
+    ]
+
+    def initialize(self):
+        out_shape = self.attributes['out_shape']
+        if len(out_shape) > 1:
+            dims = [f'N_LAYER_{self.index}_D{i}' for i in range(1, len(out_shape) + 1)]
+        else:
+            dims = [f'N_LAYER_{self.index}']
+        self.add_output_variable(list(out_shape), dims)
+
+        kernel: np.ndarray = self.attributes.attributes['weight_data']
+        bias: np.ndarray | None = self.attributes.attributes['bias_data']
+        equation = self.attributes['equation']
+        inp_shape = self.attributes['inp_shape']
+        out_shape = self.attributes['out_shape']
+
+        recipe = parse_einsum(equation, inp_shape, kernel.shape)
+        inp_tpose_idxs, ker_tpose_idxs = recipe['in_transpose_idxs']
+        out_tpose_idxs = recipe['out_transpose_idxs']
+
+        # Pre-transpose kernel (and bias) to save a transpose in cpp. Shouldn't matter for latency strategy though.
+        # hls4ml dense acts like i,ij->j
+        # parser assumes ij,j->i, so we need to transpose the kernel to match
+        kernel = kernel.transpose(ker_tpose_idxs)
+        kernel = kernel.reshape(recipe['I'], recipe['L1'], recipe['C']).transpose(0, 2, 1)
+
+        # TODO: for weight-in-bram mode (resource), broadcasting the bias here should be avoided.
+        if bias is not None:
+            bias = np.broadcast_to(bias, out_shape).transpose(np.argsort(out_tpose_idxs))
+        else:
+            # The automatically created bias covers only the last dimension of the output shape,
+            # which is too small in general for einsum dense.
+            # The transpose only matches the shape used when a real bias is present; it has no other effect.
+            bias = np.zeros(out_shape).transpose(np.argsort(out_tpose_idxs))
+
+        self.attributes.attributes['weight_data'] = kernel
+        self.attributes.attributes['bias_data'] = bias
+        self.attributes['inp_tpose_idxs'] = inp_tpose_idxs
+        self.attributes['out_tpose_idxs'] = out_tpose_idxs
+        self.attributes['out_interpert_shape'] = recipe['out_interpert_shape']
+        self.attributes['n_free_data'] = recipe['L0']
+        self.attributes['n_free_kernel'] = recipe['L1']
+        self.attributes['n_inplace'] = recipe['I']
+        self.attributes['n_contract'] = recipe['C']
+        pf = self.attributes.attributes.get('parallelization_factor', recipe['L0'])
+        self.attributes['parallelization_factor'] = pf
+
+        self.add_weights(compression=self.model.config.get_compression(self))
+        self.add_bias()
+
+
 layer_map = {
     'Input': Input,
     'InputLayer': Input,
@@ -1686,6 +1749,7 @@ def initialize(self):
     'SymbolicExpression': SymbolicExpression,
     # TensorFlow-specific layers:
     'BiasAdd': BiasAdd,
+    'EinsumDense': EinsumDense,
 }
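A hedged NumPy sketch (not from the commit) of why `initialize` pre-transposes the kernel: per the comments above, the einsum parser treats the dense step as kernel-major ('ij,j->i') while the hls4ml dense kernel is input-major ('i,ij->j'), so storing the kernel transposed, shown here for a single in-place slice, gives the same result:

import numpy as np

C, L1 = 5, 3                      # n_contract, n_free_kernel
w_parser = np.random.rand(L1, C)  # layout the parser assumes: (n_free_kernel, n_contract)
x = np.random.rand(C)

y_parser = np.einsum('ij,j->i', w_parser, x)   # parser's view of the contraction
y_dense = np.einsum('i,ij->j', x, w_parser.T)  # hls4ml dense view, with the kernel transposed
assert np.allclose(y_parser, y_dense)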

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
#ifndef NNET_EINSUM_DENSE_H_
#define NNET_EINSUM_DENSE_H_

#include "hls_stream.h"
#include "nnet_common.h"
#include "nnet_dense_latency.h"
#include "nnet_dense_resource.h"
#include "nnet_function_stubs.h"
#include "nnet_helpers.h"
#include "nnet_mult.h"
#include "nnet_transpose.h"

namespace nnet {

struct einsum_dense_config {
    // Internal data type definitions

    typedef void tpose_inp_conf;
    typedef void tpose_out_conf;
    typedef void dense_conf;

    // Layer Sizes
    static const unsigned n_free_data = 1;
    static const unsigned n_free_kernel = 1;
    static const unsigned n_contract = 1;
    static const unsigned n_inplace = 1;

    // Resource reuse info
    static const unsigned io_type = io_parallel;
    static const unsigned strategy = latency;
    static const unsigned reuse_factor = 1;
    static const unsigned parallelization_factor = 1000; // Only useful when n_inplace > 1
    static const bool store_weights_in_bram = false;     // NOT USED

    // Product function to use
    template <class x_T, class y_T> using product = nnet::product::mult<x_T, y_T>;
};

template <class data_T, class res_T, typename CONFIG_T>
void einsum_dense(
    data_T data[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace],
    res_T res[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace],
    typename CONFIG_T::dense_conf::weight_t weights[CONFIG_T::n_free_kernel * CONFIG_T::n_contract * CONFIG_T::n_inplace],
    typename CONFIG_T::dense_conf::bias_t biases[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace]) {
    data_T inp_tpose[CONFIG_T::n_free_data * CONFIG_T::n_contract * CONFIG_T::n_inplace];
    res_T out_tpose[CONFIG_T::n_free_data * CONFIG_T::n_free_kernel * CONFIG_T::n_inplace];
    res_T out_buffer[CONFIG_T::n_free_kernel];
    #pragma HLS ARRAY_PARTITION variable = inp_tpose complete
    #pragma HLS ARRAY_PARTITION variable = out_tpose complete

    nnet::transpose<data_T, data_T, typename CONFIG_T::tpose_inp_conf>(data, inp_tpose);

    constexpr unsigned L0 = CONFIG_T::n_free_data;
    constexpr unsigned L1 = CONFIG_T::n_free_kernel;
    constexpr unsigned C = CONFIG_T::n_contract;
    constexpr unsigned I = CONFIG_T::n_inplace;

    for (unsigned l0 = 0; l0 < L0; l0++) {
        #pragma HLS UNROLL factor = CONFIG_T::parallelization_factor
        for (unsigned i = 0; i < I; i++) {
            #pragma HLS UNROLL
            // even w/o explicit distributed arithmetic optimization, latency kernels are partially implemented as such
            // so reusing the same multiplier for different weights doesn't really help... only full unrolling for now
            dense<data_T, res_T, typename CONFIG_T::dense_conf>(&inp_tpose[(i * L0 + l0) * C], out_buffer,
                                                                &weights[(i * L1 * C)], &biases[((i * L0 + l0) * L1)]);
            for (unsigned j = 0; j < L1; j++) {
                #pragma HLS UNROLL
                out_tpose[(i * L0 + l0) * L1 + j] = out_buffer[j];
            }
        }
    }

    nnet::transpose<res_T, res_T, typename CONFIG_T::tpose_out_conf>(out_tpose, res);
}

} // namespace nnet

#endif
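A Python reference (an assumption-laden sketch, not part of the commit) of the loop nest in `einsum_dense` above: it takes the already-transposed input buffer and assumes the (n_contract, n_free_kernel) per-slice weight layout produced in `layers.py`, issuing one dense call per in-place slice i and free-data index l0.

import numpy as np

def einsum_dense_ref(inp_tpose, weights, biases, L0, L1, C, I):
    """Mirror of the HLS loop nest: one dense (x @ w + b) per (i, l0) pair."""
    out_tpose = np.zeros(I * L0 * L1)
    for l0 in range(L0):
        for i in range(I):
            x = inp_tpose[(i * L0 + l0) * C:(i * L0 + l0 + 1) * C]
            w = weights[i * L1 * C:(i + 1) * L1 * C].reshape(C, L1)  # assumed (n_in, n_out) layout
            b = biases[(i * L0 + l0) * L1:(i * L0 + l0 + 1) * L1]
            out_tpose[(i * L0 + l0) * L1:(i * L0 + l0 + 1) * L1] = x @ w + b
    return out_tpose

# Example call with I=1, L0=3, C=5, L1=2 (array sizes follow the argument shapes above).
y = einsum_dense_ref(np.random.rand(15), np.random.rand(10), np.random.rand(6), L0=3, L1=2, C=5, I=1)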
