|
14 | 14 | typedef {accum_t.name} accum_t; |
15 | 15 | typedef {bias_t.name} bias_t; |
16 | 16 | typedef {weight_t.name} weight_t; |
17 | | - typedef {out_t}:: value_type out_t; |
| 17 | + typedef {out_t} out_t; |
18 | 18 | template<class x_T, class y_T> |
19 | 19 | using product = nnet::product::{product_type}<x_T, y_T>; |
20 | 20 | }};\n""" |
@@ -68,6 +68,8 @@ def format(self, node): |
68 | 68 | mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width') |
69 | 69 | mult_params['n_out'] = node.get_attr('n_filt') |
70 | 70 | mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) |
| 71 | + mult_params['merged_relu'] = "true" if node.get_merged_relu() else "false" |
| 72 | + mult_params['out_t'] = node.get_output_variable().type.name |
71 | 73 | mult_config = self.mult_template.format(**mult_params) |
72 | 74 |
|
73 | 75 | return mult_config + '\n' + conv_config |
@@ -142,7 +144,14 @@ def format(self, node): |
142 | 144 | mult_params['n_out'] = node.get_attr('n_filt') |
143 | 145 | mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision) |
144 | 146 | mult_params['merged_relu'] = "true" if node.get_merged_relu() else "false" |
145 | | - mult_params['out_t'] = node.intermediate_op.type.name |
| 147 | + print(f"My out_t Class = {type(node.intermediate_op.type)}") |
| 148 | + # TODO: Need to figure out when to append ::value_type (when |
| 149 | + # node.intermediate_op's type is nnet::array but how to get that from a |
| 150 | + # layer class?) and when not to Try: I think only io_stream IOType uses |
| 151 | + # PackedType (io_parallel does not). Could grab IOType from layer |
| 152 | + # class?? Turns out this isn't all that's needed--unclear what else. |
| 153 | + # Also might need to add relu merge into dense_latency.h |
| 154 | + mult_params['out_t'] = node.intermediate_op.type.name + '::value_type' if node.model.config.get_config_value('IOType') == 'io_stream' else node.intermediate_op.type.name |
146 | 155 | mult_config = self.mult_template.format(**mult_params) |
147 | 156 |
|
148 | 157 | return mult_config + '\n' + conv_config |
|
0 commit comments