@@ -91,6 +91,8 @@ def write_axi_wrapper(self, model):
9191 newline = '#include "{}_axi.h"\n ' .format (model .config .get_project_name ())
9292 for b in model_brams :
9393 newline += '#include "weights/{}.h"\n ' .format (b .name )
94+ newline += '\n '
95+ newline += '#include "parameters.h"\n '
9496 elif '//hls-fpga-machine-learning insert local vars' in line :
9597 newline = ''
9698 if self .vivado_accelerator_config .get_interface () == 'axi_stream' :
@@ -128,6 +130,16 @@ def write_axi_wrapper(self, model):
128130 newline += indent + '#pragma HLS INTERFACE ap_ctrl_none port=return\n '
129131 if model .config .get_config_value ("IOType" ) == 'io_stream' :
130132 newline += indent + '#pragma HLS DATAFLOW\n '
133+ elif '//hls-fpga-machine-learning insert load weights' in line :
134+ newline = ''
135+ for layer in model .get_layers ():
136+ for w in layer .get_weights ():
137+ if w .weight_class == 'CompressedWeightVariable' :
138+ newline += indent + ' nnet::load_compressed_weights_from_txt<{}, {}>({}, "{}.txt");\n ' .format (w .type .name , w .nonzeros , w .name , w .name )
139+ elif w .weight_class == 'ExponentWeightVariable' :
140+ newline += indent + ' nnet::load_exponent_weights_from_txt<{}, {}>({}, "{}.txt");\n ' .format (w .type .name , w .data_length , w .name , w .name )
141+ else :
142+ newline += indent + ' nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n ' .format (w .type .name , w .data_length , w .name , w .name )
131143 elif '//hls-fpga-machine-learning insert enqueue' in line :
132144 io_type = model .config .get_config_value ("IOType" )
133145 if io_type == 'io_parallel' :
@@ -143,10 +155,12 @@ def write_axi_wrapper(self, model):
143155 newline += indent + '}\n '
144156 elif io_type == 'io_stream' :
145157 newline = ''
158+ newline += 'LOAD_INPUT_OUTER_LOOP:\n '
146159 newline += indent + 'for(unsigned i = 0; i < N_IN / {input_t}::size; ++i) {{\n '
147160 # newline += indent + indent + '#pragma HLS PIPELINE\n'
148161 newline += indent + indent + '{input_t} ctype;\n '
149162 newline += indent + indent + '#pragma HLS DATA_PACK variable=ctype\n '
163+ newline += 'LOAD_INPUT_INNER_LOOP:\n '
150164 newline += indent + indent + 'for(unsigned j = 0; j < {input_t}::size; j++) {{\n '
151165 # newline += indent + indent + indent + '#pragma HLS UNROLL\n'
152166 if self .vivado_accelerator_config .get_interface () == 'axi_stream' :
@@ -173,9 +187,11 @@ def write_axi_wrapper(self, model):
173187 newline += indent + '}\n '
174188 elif io_type == 'io_stream' :
175189 newline = ''
190+ newline += 'STORE_OUTPUT_OUTER_LOOP:\n '
176191 newline += indent + 'for(unsigned i = 0; i < N_OUT / {result_t}::size; ++i) {{\n '
177192 # newline += indent + indent + '#pragma HLS PIPELINE\n'
178193 newline += indent + indent + '{result_t} ctype = out_local.read();\n '
194+ newline += 'STORE_OUTPUT_INNER_LOOP:\n '
179195 newline += indent + indent + 'for(unsigned j = 0; j < {result_t}::size; j++) {{\n '
180196 # newline += indent + indent + indent + '#pragma HLS UNROLL\n'
181197 if self .vivado_accelerator_config .get_interface () == 'axi_stream' :
@@ -192,6 +208,35 @@ def write_axi_wrapper(self, model):
192208 f .close ()
193209 fout .close ()
194210
211+ def modify_project_cpp (self , model ):
212+ '''
213+ Modify the build_prj.tcl and build_lib.sh scripts to add the extra wrapper files and set the top function
214+ '''
215+ filedir = os .path .dirname (os .path .abspath (__file__ ))
216+ oldfile = '{}/firmware/{}.cpp' .format (model .config .get_output_dir (), model .config .get_project_name ())
217+ newfile = '{}/build_prj_axi.tcl' .format (model .config .get_output_dir ())
218+ f = open (oldfile , 'r' )
219+ fout = open (newfile , 'w' )
220+
221+ for line in f .readlines ():
222+ if '#pragma HLS INTERFACE axis port=' in line :
223+ newline = ''
224+ elif '#pragma HLS INTERFACE bram port=' in line :
225+ newline = ''
226+ elif 'nnet::load_weights_from_txt' in line :
227+ newline = ''
228+ elif 'nnet::load_exponent_weights_from_txt' in line :
229+ newline = ''
230+ elif 'nnet::load_compressed_weights_from_txt' in line :
231+ newline = ''
232+ else :
233+ newline = line
234+ fout .write (newline )
235+
236+ f .close ()
237+ fout .close ()
238+ os .rename (newfile , oldfile )
239+
195240 def modify_build_script (self , model ):
196241 '''
197242 Modify the build_prj.tcl and build_lib.sh scripts to add the extra wrapper files and set the top function
@@ -373,6 +418,7 @@ def write_hls(self, model):
373418 self .write_driver (model )
374419 self .write_wrapper_test (model )
375420 self .write_axi_wrapper (model )
421+ self .modify_project_cpp (model )
376422 self .modify_build_script (model )
377423 self .write_new_tar (model )
378424
0 commit comments