|
| 1 | +import warnings |
| 2 | + |
| 3 | +from hls4ml.model.layers import Layer, Softmax |
| 4 | +from hls4ml.model.optimizer import OptimizerPass |
| 5 | + |
| 6 | + |
| 7 | +class FixSoftmaxTableSize(OptimizerPass): |
| 8 | + def match(self, node): |
| 9 | + return isinstance(node, Softmax) |
| 10 | + |
| 11 | + def transform(self, model, node: Layer): |
| 12 | + inp_layer = node.get_input_node() # type: ignore |
| 13 | + if not isinstance(inp_layer, Layer): |
| 14 | + raise RuntimeError(f'Softmax layer {node.name} does not have an input layer') |
| 15 | + |
| 16 | + input_bw: int = inp_layer.get_attr('result_t').precision.width # type: ignore |
| 17 | + table_bw: int = node.get_attr('inv_table_t').precision.width # type: ignore |
| 18 | + table_size = int(node.get_attr('table_size')) # type: ignore |
| 19 | + |
| 20 | + backend = model.config.config['Backend'] |
| 21 | + |
| 22 | + # Somehow, Intel want one extra bits for the table. |
| 23 | + # I don't know why but if not simulation will crash with segmentation fault. |
| 24 | + backend_limitation = -1 if backend == 'Quartus' else 0 |
| 25 | + |
| 26 | + if 2 ** (min(input_bw, table_bw) + backend_limitation) < table_size: |
| 27 | + # If table size is too large w.r.t. input bitwidth and table bitwidth, |
| 28 | + # reduce table size to avoid undefined behavior when cutting indices from, |
| 29 | + # fixed point number. |
| 30 | + node.set_attr('table_size', str(2 ** (min(input_bw, table_bw) + backend_limitation))) |
| 31 | + if 2**input_bw < table_size: |
| 32 | + # The warning message does not have to be looking like this, but you are asking |
| 33 | + # 125 characters long line. |
| 34 | + warnings.warn( |
| 35 | + ( |
| 36 | + f"Softmax layer {node.name} table size is too large for input" |
| 37 | + f"bitwidth {input_bw}. Setting table size to {2**input_bw}." |
| 38 | + "To avoid this warning, please increase input bitwidth or" |
| 39 | + "decrease table size." |
| 40 | + ), |
| 41 | + stacklevel=1, |
| 42 | + ) |
| 43 | + if 2**table_bw < table_size: |
| 44 | + warnings.warn( |
| 45 | + ( |
| 46 | + f"Softmax layer {node.name} table size is too large for input" |
| 47 | + f"bitwidth {input_bw}. Setting table size to {2**input_bw}." |
| 48 | + "To avoid this warning, please increase input bitwidth or" |
| 49 | + "decrease table size." |
| 50 | + ), |
| 51 | + stacklevel=1, |
| 52 | + ) |
| 53 | + if backend == 'Quartus': |
| 54 | + warnings.warn( |
| 55 | + ( |
| 56 | + "Quartus backend's table size is half of 2^min(input_bw-1,table_bw-1)" |
| 57 | + " instead of 2^min(input_bw,table_bw)." |
| 58 | + ), |
| 59 | + stacklevel=1, |
| 60 | + ) |
| 61 | + return False |
| 62 | + |
| 63 | + |
| 64 | +def register_softmax__table_size_fix(backend): |
| 65 | + backend.register_pass('fix_softmax_table_size', FixSoftmaxTableSize) |
0 commit comments