from collections import namedtuple

import pytest

from hls4ml.model.optimizer.passes.infer_precision import InferPrecisionTypes
from hls4ml.model.types import (
    FixedPrecisionType,
    IntegerPrecisionType,
    NamedType,
    RoundingMode,
    SaturationMode,
    UnspecifiedPrecisionType,
)


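# Backend stand-in: convert_precision_string is the only backend hook these tests rely on.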
class MockBackend:
    def convert_precision_string(self, precision_string):
        """
        Simplified stand-in for the backend's precision-string parser: precision
        objects passed in directly are returned unchanged, anything else yields None.
        """
        if isinstance(precision_string, (FixedPrecisionType, IntegerPrecisionType)):
            return precision_string
        return None


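# Minimal config stand-in holding the 'maximum'/'default' model precisions and a backend.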
class MockConfig:
    def __init__(self, max_precision=None, default_precision=None):
        self.model_precision = {}
        if max_precision:
            self.model_precision['maximum'] = max_precision
        if default_precision:
            self.model_precision['default'] = default_precision

        self.backend = MockBackend()


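# Model stand-in with a fixed default precision (16 bits, 6 integer) and an optional maximum.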
class MockModel:
    def __init__(self, max_precision=None):
        default = FixedPrecisionType(width=16, integer=6)
        self.config = MockConfig(max_precision, default)


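# Tensor-variable stand-in exposing .type.precision and a fixed shape.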
class MockVariable:
    def __init__(self, precision):
        self.type = namedtuple('Type', ['precision'])(precision)
        self.shape = [10, 10]


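# Weight stand-in exposing .precision, .nonzeros and an update_precision() setter.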
class MockWeight:
    def __init__(self, precision):
        self.precision = precision
        self.nonzeros = 10

    def update_precision(self, new_precision):
        self.precision = new_precision


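# Layer-node stand-in carrying the attributes, named types, weights and input variables the tests need.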
class MockNode:
    def __init__(self, class_name, name='test_node', max_precision=None, inputs=None):
        self.class_name = class_name
        self.name = name
        self.model = MockModel(max_precision)
        self.attributes = {
            'n_in': 10,
            'n_out': 10,
            'n_chan': 3,
            'filt_height': 3,
            'filt_width': 3,
            'pool_height': 2,
            'pool_width': 2,
            'op': 'multiply',  # Default for merge tests
            'pool_op': 'average',
        }
        self.types = {
            'result_t': NamedType('result_t', UnspecifiedPrecisionType()),
            'accum_t': NamedType('accum_t', UnspecifiedPrecisionType()),
            'weight_t': NamedType('weight_t', FixedPrecisionType(8, 4)),
            'bias_t': NamedType('bias_t', FixedPrecisionType(8, 4)),
            'scale_t': NamedType('scale_t', FixedPrecisionType(8, 4)),
            'pointwise_t': NamedType('pointwise_t', FixedPrecisionType(8, 4)),
        }
        self.weights = {
            'weight': MockWeight(FixedPrecisionType(8, 4)),
            'bias': MockWeight(FixedPrecisionType(8, 4)),
            'scale': MockWeight(FixedPrecisionType(8, 4)),
            'pointwise': MockWeight(FixedPrecisionType(8, 4)),
        }

        # Set up inputs; a second input variable is created for two-input layers
        self.inputs = inputs if inputs else ['input_1']
        self._input_vars = {'input_1': MockVariable(FixedPrecisionType(16, 6))}
        if len(self.inputs) > 1:
            self._input_vars['input_2'] = MockVariable(FixedPrecisionType(16, 6))

    def get_attr(self, key, default=None):
        return self.attributes.get(key, default)

    def get_input_variable(self, input_name=None):
        if input_name is None:
            return self._input_vars[self.inputs[0]]
        return self._input_vars.get(input_name)

    def get_output_variable(self):
        return MockVariable(UnspecifiedPrecisionType())


@pytest.fixture
def optimizer():
    return InferPrecisionTypes()


class TestApplyMaxPrecisionConstraints:
    """
    Tests the logic of the _apply_max_precision_constraints function directly.
    """
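
    # Behaviour exercised by the tests below:
    #   - width and integer bits are clamped to the configured 'maximum'
    #   - signedness is always taken from the 'maximum' precision
    #   - rounding/saturation modes set on the 'maximum' override the input's,
    #     otherwise the input's modes are preserved
    #   - with no 'maximum' configured, the input precision is returned unchanged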

    def test_no_max_precision_set(self, optimizer):
        """If 'maximum' is not in the config, the precision is returned unchanged."""
        node = MockNode('Dense', max_precision=None)

        input_prec = FixedPrecisionType(width=20, integer=10)
        result = optimizer._apply_max_precision_constraints(node, input_prec)

        assert result.width == 20
        assert result.integer == 10

    def test_clamp_width(self, optimizer):
        """Should reduce the width if the input width exceeds the max."""
        max_prec = FixedPrecisionType(width=16, integer=10)
        node = MockNode('Dense', max_precision=max_prec)

        input_prec = FixedPrecisionType(width=32, integer=10)
        result = optimizer._apply_max_precision_constraints(node, input_prec)

        assert result.width == 16
        assert result.integer == 10

    def test_clamp_integer(self, optimizer):
        """Should reduce the integer bits if the input integer bits exceed the max."""
        max_prec = FixedPrecisionType(width=32, integer=5)
        node = MockNode('Dense', max_precision=max_prec)

        input_prec = FixedPrecisionType(width=32, integer=10)
        result = optimizer._apply_max_precision_constraints(node, input_prec)

        assert result.width == 32
        assert result.integer == 5

    def test_signedness_inheritance(self, optimizer):
        """Should always adopt the signedness of the maximum precision."""
        # Max is unsigned
        max_prec = FixedPrecisionType(width=32, integer=10, signed=False)
        node = MockNode('Dense', max_precision=max_prec)

        # Input is signed
        input_prec = FixedPrecisionType(width=32, integer=10, signed=True)
        result = optimizer._apply_max_precision_constraints(node, input_prec)

        assert not result.signed

    def test_mode_inheritance_from_max(self, optimizer):
        """If the max specifies rounding/saturation modes, they should override the input's."""
        max_prec = FixedPrecisionType(
            16, 6, rounding_mode=RoundingMode.RND, saturation_mode=SaturationMode.SAT, saturation_bits=2
        )
        node = MockNode('Dense', max_precision=max_prec)

        # Input has different modes
        input_prec = FixedPrecisionType(16, 6, rounding_mode=RoundingMode.TRN, saturation_mode=SaturationMode.WRAP)

        result = optimizer._apply_max_precision_constraints(node, input_prec)

        assert result.rounding_mode == RoundingMode.RND
        assert result.saturation_mode == SaturationMode.SAT
        assert result.saturation_bits == 2

    def test_mode_preservation_when_max_is_none(self, optimizer):
        """If the max modes are left at their defaults, the input modes should be preserved."""
        # Create a max precision whose modes are initialized with defaults
        max_prec = FixedPrecisionType(16, 6)

        node = MockNode('Dense', max_precision=max_prec)

        input_prec = FixedPrecisionType(
            16, 6, rounding_mode=RoundingMode.RND_ZERO, saturation_mode=SaturationMode.SAT_SYM
        )

        result = optimizer._apply_max_precision_constraints(node, input_prec)

        assert result.rounding_mode == RoundingMode.RND_ZERO
        assert result.saturation_mode == SaturationMode.SAT_SYM


class TestInferPrecision:
    """
    Tests that _infer_precision applies _apply_max_precision_constraints for specific layers.
    We verify this by setting a strict max constraint and asserting that the inferred
    result_t complies with it.
    """

    # Define a strict constraint
    STRICT_MAX = FixedPrecisionType(width=4, integer=2, signed=True)
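    # A 4-bit, 2-integer maximum is far narrower than anything inferred from the 16/32-bit
    # mock inputs, so a (width=4, integer=2) result_t demonstrates the constraint was applied.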

    @pytest.mark.parametrize(
        'layer_class',
        [
            'Dense',
            'Conv1D',
            'Conv2D',
            'PointwiseConv2D',
            'DepthwiseConv2D',
        ],
    )
    def test_common_precision_layers(self, optimizer, layer_class):
        """Tests layers that use _infer_common_precision."""
        node = MockNode(layer_class, max_precision=self.STRICT_MAX)

        node._input_vars['input_1'] = MockVariable(FixedPrecisionType(32, 16, signed=True))

        types_to_infer = ['result_t', 'accum_t']
        optimizer._infer_precision(node, types_to_infer)

        res_prec = node.types['result_t'].precision
        assert res_prec.width == 4
        assert res_prec.integer == 2

    def test_batch_normalization(self, optimizer):
        """Tests BatchNormalization layer inference."""
        node = MockNode('BatchNormalization', max_precision=self.STRICT_MAX)
        node._input_vars['input_1'] = MockVariable(FixedPrecisionType(32, 16))

        types_to_infer = ['result_t']
        optimizer._infer_precision(node, types_to_infer)

        res_prec = node.types['result_t'].precision
        assert res_prec.width == 4
        assert res_prec.integer == 2

    def test_merge_multiply(self, optimizer):
        """Tests Merge layer with the multiply op."""
        node = MockNode('Merge', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])
        node.attributes['op'] = 'multiply'

        node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
        node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10))

        types_to_infer = ['result_t']
        optimizer._infer_precision(node, types_to_infer)

        res_prec = node.types['result_t'].precision
        assert res_prec.width == 4
        assert res_prec.integer == 2

    def test_merge_add(self, optimizer):
        """Tests Merge layer with the add op."""
        node = MockNode('Merge', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])
        node.attributes['op'] = 'add'

        node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
        node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10))

        types_to_infer = ['result_t']
        optimizer._infer_precision(node, types_to_infer)

        res_prec = node.types['result_t'].precision
        assert res_prec.width == 4
        assert res_prec.integer == 2

    def test_concatenate_same_input_precisions(self, optimizer):
        """
        Tests the Concatenate layer. If the precisions of both inputs are the same,
        the max precision is ignored (see the _infer_cat_precision function).
        """
        node = MockNode('Concatenate', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])

        node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
        node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10))

        types_to_infer = ['result_t']
        optimizer._infer_precision(node, types_to_infer)

        res_prec = node.types['result_t'].precision
        assert res_prec.width == 20
        assert res_prec.integer == 10

    def test_concatenate_different_input_precisions(self, optimizer):
        """Tests the Concatenate layer with differing input precisions; the max constraint is applied."""
        node = MockNode('Concatenate', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])

        node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
        node._input_vars['input_2'] = MockVariable(FixedPrecisionType(16, 6))

        types_to_infer = ['result_t']
        optimizer._infer_precision(node, types_to_infer)

        res_prec = node.types['result_t'].precision
        assert res_prec.width == 4
        assert res_prec.integer == 2

    def test_dot(self, optimizer):
        """Tests Dot layer."""
        node = MockNode('Dot', max_precision=self.STRICT_MAX, inputs=['input_1', 'input_2'])

        node._input_vars['input_1'] = MockVariable(FixedPrecisionType(20, 10))
        node._input_vars['input_2'] = MockVariable(FixedPrecisionType(20, 10))

        types_to_infer = ['result_t']
        optimizer._infer_precision(node, types_to_infer)

        res_prec = node.types['result_t'].precision
        assert res_prec.width == 4
        assert res_prec.integer == 2