Skip to content

Commit 347a6bd

Browse files
committed
WIP trying to resolve out_t issues with the mult configs
1 parent 979aed3 commit 347a6bd

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

hls4ml/backends/vivado/passes/convolution_templates.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
typedef {accum_t.name} accum_t;
1515
typedef {bias_t.name} bias_t;
1616
typedef {weight_t.name} weight_t;
17-
typedef {out_t}:: value_type out_t;
17+
typedef {out_t} out_t;
1818
template<class x_T, class y_T>
1919
using product = nnet::product::{product_type}<x_T, y_T>;
2020
}};\n"""
@@ -68,6 +68,8 @@ def format(self, node):
6868
mult_params['n_in'] = node.get_attr('n_chan') * node.get_attr('filt_width')
6969
mult_params['n_out'] = node.get_attr('n_filt')
7070
mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
71+
mult_params['merged_relu'] = "true" if node.get_merged_relu() else "false"
72+
mult_params['out_t'] = node.get_output_variable().type.name
7173
mult_config = self.mult_template.format(**mult_params)
7274

7375
return mult_config + '\n' + conv_config
@@ -142,7 +144,14 @@ def format(self, node):
142144
mult_params['n_out'] = node.get_attr('n_filt')
143145
mult_params['product_type'] = get_backend('vivado').product_type(node.get_input_variable().type.precision, node.get_weights('weight').type.precision)
144146
mult_params['merged_relu'] = "true" if node.get_merged_relu() else "false"
145-
mult_params['out_t'] = node.intermediate_op.type.name
147+
print(f"My out_t Class = {type(node.intermediate_op.type)}")
148+
# TODO: Need to figure out when to append ::value_type (when
149+
# node.intermediate_op's type is nnet::array but how to get that from a
150+
# layer class?) and when not to Try: I think only io_stream IOType uses
151+
# PackedType (io_parallel does not). Could grab IOType from layer
152+
# class?? Turns out this isn't all that's needed--unclear what else.
153+
# Also might need to add relu merge into dense_latency.h
154+
mult_params['out_t'] = node.intermediate_op.type.name + '::value_type' if node.model.config.get_config_value('IOType') == 'io_stream' else node.intermediate_op.type.name
146155
mult_config = self.mult_template.format(**mult_params)
147156

148157
return mult_config + '\n' + conv_config

hls4ml/backends/vivado/passes/core_templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
typedef {bias_t.name} bias_t;
2020
typedef {weight_t.name} weight_t;
2121
typedef {index_t.name} index_t;
22-
typedef {out_t}:: value_type out_t;
22+
typedef {out_t} out_t;
2323
template<class x_T, class y_T>
2424
using product = nnet::product::{product_type}<x_T, y_T>;
2525
}};\n"""

0 commit comments

Comments
 (0)