
Commit 640fa0c

add v3 conv handlers
1 parent 79a13ef commit 640fa0c

2 files changed: 123 additions & 0 deletions


hls4ml/converters/keras_v3/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
+from . import conv  # noqa: F401
 from . import core  # noqa: F401
 from ._base import registry as layer_handlers

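The import above is load-bearing: it pulls in conv.py for its side effect, so the @register decorator below can add KV3ConvHandler to the shared registry. A minimal lookup sketch, assuming layer_handlers behaves like a dict keyed by the class paths listed in each handler's handles tuple (the _base module is not part of this diff):

from hls4ml.converters.keras_v3 import layer_handlers

# Hypothetical lookup by fully-qualified class path; the key is taken from
# KV3ConvHandler.handles, but the registry layout itself is an assumption.
handler = layer_handlers['keras.src.layers.convolutional.conv2d.Conv2D']
print(handler)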
hls4ml/converters/keras_v3/conv.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
import typing
from math import ceil
from typing import Sequence

import numpy as np

from ._base import KerasV3LayerHandler, register

if typing.TYPE_CHECKING:
    import keras
    from keras.api import KerasTensor


@register
class KV3ConvHandler(KerasV3LayerHandler):
    handles = (
        'keras.src.layers.convolutional.conv1d.Conv1D',
        'keras.src.layers.convolutional.conv2d.Conv2D',
        'keras.src.layers.convolutional.depthwise_conv1d.DepthwiseConv1D',
        'keras.src.layers.convolutional.depthwise_conv2d.DepthwiseConv2D',
        'keras.src.layers.convolutional.separable_conv1d.SeparableConv1D',
        'keras.src.layers.convolutional.separable_conv2d.SeparableConv2D',
    )

    def handle(
        self,
        layer: 'keras.layers.Conv1D|keras.layers.Conv2D|keras.layers.DepthwiseConv1D|keras.layers.DepthwiseConv2D',
        in_tensors: Sequence['KerasTensor'],
        out_tensors: Sequence['KerasTensor'],
    ):
        from keras.src.layers.convolutional.base_conv import BaseConv
        from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv
        from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv

        assert len(in_tensors) == 1, f"Layer {layer.name} has more than one input"
        assert len(out_tensors) == 1, f"Layer {layer.name} has more than one output"

        in_shape: tuple[int, ...] = in_tensors[0].shape[1:]  # type: ignore
        out_shape: tuple[int, ...] = out_tensors[0].shape[1:]  # type: ignore
        assert all(isinstance(x, int) for x in in_shape), f"Layer {layer.name} has non-fixed size input: {in_shape}"
        assert all(isinstance(x, int) for x in out_shape), f"Layer {layer.name} has non-fixed size output: {out_shape}"

        kernel = np.array(layer.kernel)
        if layer.use_bias:
            bias = np.array(layer.bias)
        else:
            bias = None

        ker_px_shape: tuple[int, ...] = layer.kernel_size
        data_format = layer.data_format

        # Split shapes into spatial (pixel) dims and channel dim according to the data format.
        if data_format == 'channels_last':
            *px_in_shape, ch_in = in_shape
            *px_out_shape, ch_out = out_shape
        else:
            ch_in, *px_in_shape = in_shape
            ch_out, *px_out_shape = out_shape

        # Derive explicit per-side padding from the Keras padding mode.
        if layer.padding == 'same':
            n_padding = [ceil(N / n) * n - N for N, n in zip(px_in_shape, ker_px_shape)]
            n_padding0 = [p // 2 for p in n_padding]
            n_padding1 = [p - p0 for p, p0 in zip(n_padding, n_padding0)]
        elif layer.padding == 'valid':
            n_padding0 = [0] * len(px_in_shape)
            n_padding1 = [0] * len(px_in_shape)
        elif layer.padding == 'causal':
            n_padding0 = [ker_px_shape[0] - 1] + [0] * (len(px_in_shape) - 1)
            n_padding1 = [0] * len(px_in_shape)
        else:
            raise ValueError(f"Invalid padding mode {layer.padding} for layer {layer.name}")

        config = {
            'data_format': data_format,
            'weight_data': kernel,
            'bias_data': bias,
            'n_filt': ch_out,
            'n_chan': ch_in,
        }

        if layer.rank == 1:
            config.update(
                {
                    'filt_width': ker_px_shape[0],
                    'stride_width': layer.strides[0],
                    'pad_left': n_padding0[0],
                    'pad_right': n_padding1[0],
                    'in_width': px_in_shape[0],
                    'out_width': px_out_shape[0],
                }
            )
        elif layer.rank == 2:
            config.update(
                {
                    'filt_height': ker_px_shape[0],
                    'filt_width': ker_px_shape[1],
                    'stride_height': layer.strides[0],
                    'stride_width': layer.strides[1],
                    'pad_top': n_padding0[0],
                    'pad_bottom': n_padding1[0],
                    'pad_left': n_padding0[1],
                    'pad_right': n_padding1[1],
                    'in_height': px_in_shape[0],
                    'in_width': px_in_shape[1],
                    'out_height': px_out_shape[0],
                    'out_width': px_out_shape[1],
                }
            )
        else:
            _cls = f"{layer.__class__.__module__}.{layer.__class__.__name__}"
            raise ValueError(f"Only 1D and 2D conv layers are supported, got {_cls} (rank={layer.rank})")

        # Depthwise and separable variants carry their kernels under dedicated keys.
        if isinstance(layer, BaseDepthwiseConv):
            config['depthwise_data'] = kernel
            config['depth_multiplier'] = layer.depth_multiplier
        elif isinstance(layer, BaseSeparableConv):
            config['depthwise_data'] = kernel
            config['pointwise_data'] = np.array(layer.pointwise_kernel)
            config['depth_multiplier'] = layer.depth_multiplier
        elif isinstance(layer, BaseConv):
            config['weight_data'] = kernel

        return config
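For reference, the explicit per-side padding that handle() derives for padding='same' follows the formula above; a small worked sketch with made-up shapes (8x8 input, 3x3 kernel), using the same expressions as the handler:

from math import ceil

# Hypothetical spatial shapes for illustration only.
px_in_shape = [8, 8]   # input height, width
ker_px_shape = [3, 3]  # kernel height, width

n_padding = [ceil(N / n) * n - N for N, n in zip(px_in_shape, ker_px_shape)]  # [1, 1]
n_padding0 = [p // 2 for p in n_padding]                                      # [0, 0] -> pad_top, pad_left
n_padding1 = [p - p0 for p, p0 in zip(n_padding, n_padding0)]                 # [1, 1] -> pad_bottom, pad_right
print(n_padding0, n_padding1)  # [0, 0] [1, 1]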
