Skip to content

Commit 5d8237f

Browse files
committed
add pytest
1 parent 0f103d4 commit 5d8237f

File tree

1 file changed

+311
-0
lines changed

1 file changed

+311
-0
lines changed

test/pytest/test_max_precision.py

Lines changed: 311 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,311 @@
1+
from collections import namedtuple
2+
3+
import pytest
4+
5+
from hls4ml.model.optimizer.passes.infer_precision import InferPrecisionTypes
6+
from hls4ml.model.types import (
7+
FixedPrecisionType,
8+
IntegerPrecisionType,
9+
NamedType,
10+
RoundingMode,
11+
SaturationMode,
12+
UnspecifiedPrecisionType,
13+
)
14+
15+
16+
class MockBackend:
17+
def convert_precision_string(self, precision_string):
18+
"""
19+
Simple mock that expects a FixedPrecisionType object or None
20+
to be passed directly for testing purposes, or a simple string parser.
21+
"""
22+
if isinstance(precision_string, (FixedPrecisionType, IntegerPrecisionType)):
23+
return precision_string
24+
return None
25+
26+
27+
class MockConfig:
28+
def __init__(self, max_precision=None, default_precision=None):
29+
self.model_precision = {}
30+
if max_precision:
31+
self.model_precision['maximum'] = max_precision
32+
if default_precision:
33+
self.model_precision['default'] = default_precision
34+
35+
self.backend = MockBackend()
36+
37+
38+
class MockModel:
39+
def __init__(self, max_precision=None):
40+
default = FixedPrecisionType(width=16, integer=6)
41+
self.config = MockConfig(max_precision, default)
42+
43+
44+
class MockVariable:
45+
def __init__(self, precision):
46+
self.type = namedtuple('Type', ['precision'])(precision)
47+
self.shape = [10, 10]
48+
49+
50+
class MockWeight:
51+
def __init__(self, precision):
52+
self.precision = precision
53+
self.nonzeros = 10
54+
55+
def update_precision(self, new_precision):
56+
self.precision = new_precision
57+
58+
59+
class MockNode:
60+
def __init__(self, class_name, name='test_node', max_precision=None, inputs=None):
61+
self.class_name = class_name
62+
self.name = name
63+
self.model = MockModel(max_precision)
64+
self.attributes = {
65+
'n_in': 10,
66+
'n_out': 10,
67+
'n_chan': 3,
68+
'filt_height': 3,
69+
'filt_width': 3,
70+
'pool_height': 2,
71+
'pool_width': 2,
72+
'op': 'multiply', # Default for merge tests
73+
'pool_op': 'average',
74+
}
75+
self.types = {
76+
'result_t': NamedType('result_t', UnspecifiedPrecisionType()),
77+
'accum_t': NamedType('accum_t', UnspecifiedPrecisionType()),
78+
'weight_t': NamedType('weight_t', FixedPrecisionType(8, 4)),
79+
'bias_t': NamedType('bias_t', FixedPrecisionType(8, 4)),
80+
'scale_t': NamedType('scale_t', FixedPrecisionType(8, 4)),
81+
'pointwise_t': NamedType('pointwise_t', FixedPrecisionType(8, 4)),
82+
}
83+
self.weights = {
84+
'weight': MockWeight(FixedPrecisionType(8, 4)),
85+
'bias': MockWeight(FixedPrecisionType(8, 4)),
86+
'scale': MockWeight(FixedPrecisionType(8, 4)),
87+
'pointwise': MockWeight(FixedPrecisionType(8, 4)),
88+
}
89+
90+
# Setup inputs
91+
self.inputs = inputs if inputs else ['input_1']
92+
self._input_vars = {'input_1': MockVariable(FixedPrecisionType(16, 6))}
93+
if len(self.inputs) > 1:
94+
self._input_vars['input_2'] = MockVariable(FixedPrecisionType(16, 6))
95+
96+
def get_attr(self, key, default=None):
97+
return self.attributes.get(key, default)
98+
99+
def get_input_variable(self, input_name=None):
100+
if input_name is None:
101+
return self._input_vars[self.inputs[0]]
102+
return self._input_vars.get(input_name)
103+
104+
def get_output_variable(self):
105+
return MockVariable(UnspecifiedPrecisionType())
106+
107+
108+
@pytest.fixture
109+
def optimizer():
110+
return InferPrecisionTypes()
111+
112+
113+
class TestApplyMaxPrecisionConstraints:
114+
"""
115+
Tests the logic of _apply_max_precision_constraints function directly.
116+
"""
117+
118+
def test_no_max_precision_set(self, optimizer):
119+
"""If 'maximum' is not in config, return precision unchanged."""
120+
node = MockNode('Dense', max_precision=None)
121+
122+
input_prec = FixedPrecisionType(width=20, integer=10)
123+
result = optimizer._apply_max_precision_constraints(node, input_prec)
124+
125+
assert result.width == 20
126+
assert result.integer == 10
127+
128+
def test_clamp_width(self, optimizer):
129+
"""Should reduce width if input > max."""
130+
max_prec = FixedPrecisionType(width=16, integer=10)
131+
node = MockNode('Dense', max_precision=max_prec)
132+
133+
input_prec = FixedPrecisionType(width=32, integer=10)
134+
result = optimizer._apply_max_precision_constraints(node, input_prec)
135+
136+
assert result.width == 16
137+
assert result.integer == 10
138+
139+
def test_clamp_integer(self, optimizer):
140+
"""Should reduce integer bits if input > max."""
141+
max_prec = FixedPrecisionType(width=32, integer=5)
142+
node = MockNode('Dense', max_precision=max_prec)
143+
144+
input_prec = FixedPrecisionType(width=32, integer=10)
145+
result = optimizer._apply_max_precision_constraints(node, input_prec)
146+
147+
assert result.width == 32
148+
assert result.integer == 5
149+
150+
def test_signedness_inheritance(self, optimizer):
151+
"""Should always adopt the signedness of the maximum precision."""
152+
# Max is Unsigned (signed=0)
153+
max_prec = FixedPrecisionType(width=32, integer=10, signed=0)
154+
node = MockNode('Dense', max_precision=max_prec)
155+
156+
# Input is Signed
157+
input_prec = FixedPrecisionType(width=32, integer=10, signed=1)
158+
result = optimizer._apply_max_precision_constraints(node, input_prec)
159+
160+
assert result.signed == 0
161+
162+
def test_mode_inheritance_from_max(self, optimizer):
163+
"""If Max specifies rounding/sat modes, they should override input."""
164+
max_prec = FixedPrecisionType(
165+
16, 6, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT, saturation_bits=2
166+
)
167+
node = MockNode('Dense', max_precision=max_prec)
168+
169+
# Input has different modes
170+
input_prec = FixedPrecisionType(16, 6, rounding_mode=RoundingMode.TRN, saturation_mode=SaturationMode.WRAP)
171+
172+
result = optimizer._apply_max_precision_constraints(node, input_prec)
173+
174+
assert result.rounding_mode == RoundingMode.RND
175+
assert result.saturation_mode == SaturationMode.SAT
176+
assert result.saturation_bits == 2
177+
178+
def test_mode_preservation_when_max_is_none(self, optimizer):
179+
"""If Max modes are default, input modes should be preserved."""
180+
# Create a max precision where modes are initialized with defaults
181+
max_prec = FixedPrecisionType(16, 6)
182+
183+
node = MockNode('Dense', max_precision=max_prec)
184+
185+
input_prec = FixedPrecisionType(16, 6, rounding_mode=RoundingMode.RND_ZERO, saturation_mode=SaturationMode.SAT_SYM)
186+
187+
result = optimizer._apply_max_precision_constraints(node, input_prec)
188+
189+
assert result.rounding_mode == RoundingMode.RND_ZERO
190+
assert result.saturation_mode == SaturationMode.SAT_SYM
191+
192+
193+
class TestInferPrecision:
194+
"""
195+
Tests that _infer_precision calls apply_max_constraints for specific layers.
196+
We verify this by setting a strict Max constraint and asserting the result_t
197+
complies with it.
198+
"""
199+
200+
# Define a strict constraint
201+
STRICT_MAX = FixedPrecisionType(width=4, integer=2, signed=True)
202+
203+
@pytest.mark.parametrize(
204+
'layer_class',
205+
[
206+
'Dense',
207+
'Conv1D',
208+
'Conv2D',
209+
'PointwiseConv2D',
210+
'DepthwiseConv2D',
211+
],
212+
)
213+
def test_common_precision_layers(self, optimizer, layer_class):
214+
"""Tests layers that use _infer_common_precision."""
215+
node = MockNode(layer_class, max_precision=self.STRICT_MAX)
216+
217+
node._input_vars['input_1'] = MockVariable(FixedPrecisionType(32, 16, signed=1))
218+
219+
types_to_infer = ['result_t', 'accum_t']
220+
optimizer._infer_precision(node, types_to_infer)
221+
222+
res_prec = node.types['result_t'].precision
223+
assert res_prec.width == 4
224+
assert res_prec.integer == 2
225+
226+
def test_batch_normalization(self, optimizer):
227+
"""Tests BN layer inference."""
228+
node = MockNode('BatchNormalization', max_precision=self.STRICT_MAX)
229+
node._input_vars['input_1'] = MockVariable(FixedPrecisionType(32, 16))
230+
231+
types_to_infer = ['result_t']
232+
optimizer._infer_precision(node, types_to_infer)
233+
234+
res_prec = node.types['result_t'].precision
235+
assert res_prec.width == 4
236+
assert res_prec.integer == 2
237+
238+
def test_merge_multiply(self, optimizer):
239+
"""Tests Merge layer with Multiply op."""
240+
node = MockNode('Merge', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])
241+
node.attributes['op'] = 'multiply'
242+
243+
node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
244+
node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10))
245+
246+
types_to_infer = ['result_t']
247+
optimizer._infer_precision(node, types_to_infer)
248+
249+
res_prec = node.types['result_t'].precision
250+
assert res_prec.width == 4
251+
assert res_prec.integer == 2
252+
253+
def test_merge_add(self, optimizer):
254+
"""Tests Merge layer with Add op."""
255+
node = MockNode('Merge', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])
256+
node.attributes['op'] = 'add'
257+
258+
node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
259+
node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10))
260+
261+
types_to_infer = ['result_t']
262+
optimizer._infer_precision(node, types_to_infer)
263+
264+
res_prec = node.types['result_t'].precision
265+
assert res_prec.width == 4
266+
assert res_prec.integer == 2
267+
268+
def test_concatenate_same_input_precisions(self, optimizer):
269+
"""
270+
Tests Concatenate layer. If precisions of both inputs are the same,
271+
max precision is ignored (see _infer_cat_precision function).
272+
"""
273+
node = MockNode('Concatenate', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])
274+
275+
node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
276+
node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10))
277+
278+
types_to_infer = ['result_t']
279+
optimizer._infer_precision(node, types_to_infer)
280+
281+
res_prec = node.types['result_t'].precision
282+
assert res_prec.width == 20
283+
assert res_prec.integer == 10
284+
285+
def test_concatenate_different_input_precisions(self, optimizer):
286+
"""Tests Concatenate layer."""
287+
node = MockNode('Concatenate', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])
288+
289+
node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
290+
node._input_vars['input_2'] = MockVariable(FixedPrecisionType(16, 6))
291+
292+
types_to_infer = ['result_t']
293+
optimizer._infer_precision(node, types_to_infer)
294+
295+
res_prec = node.types['result_t'].precision
296+
assert res_prec.width == 4
297+
assert res_prec.integer == 2
298+
299+
def test_dot(self, optimizer):
300+
"""Tests Dot layer."""
301+
node = MockNode('Dot', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])
302+
303+
node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
304+
node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10))
305+
306+
types_to_infer = ['result_t']
307+
optimizer._infer_precision(node, types_to_infer)
308+
309+
res_prec = node.types['result_t'].precision
310+
assert res_prec.width == 4
311+
assert res_prec.integer == 2

0 commit comments

Comments
 (0)