Skip to content

Commit 4d44359

Browse files
GiuseppeDiGuglielmojmduarte
authored andcommitted
Update pragmas (fix a bug), load weights for simulation in the wrapper, and add _LOOP labels
1 parent a3f134f commit 4d44359

File tree

2 files changed

+54
-0
lines changed

2 files changed

+54
-0
lines changed

hls4ml/templates/vivado_accelerator/myproject_axi.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ void myproject(
99

1010
//hls-fpga-machine-learning insert local vars
1111

12+
#ifndef __SYNTHESIS__
13+
static bool loaded_weights = false;
14+
if (!loaded_weights) {
15+
//hls-fpga-machine-learning insert load weights
16+
loaded_weights = true;
17+
}
18+
#endif
19+
1220
//hls-fpga-machine-learning insert enqueue
1321

1422
//hls-fpga-machine-learning insert call

hls4ml/writer/vivado_accelerator_writer.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,8 @@ def write_axi_wrapper(self, model):
9191
newline = '#include "{}_axi.h"\n'.format(model.config.get_project_name())
9292
for b in model_brams:
9393
newline += '#include "weights/{}.h"\n'.format(b.name)
94+
newline += '\n'
95+
newline += '#include "parameters.h"\n'
9496
elif '//hls-fpga-machine-learning insert local vars' in line:
9597
newline = ''
9698
if self.vivado_accelerator_config.get_interface() == 'axi_stream':
@@ -128,6 +130,16 @@ def write_axi_wrapper(self, model):
128130
newline += indent + '#pragma HLS INTERFACE ap_ctrl_none port=return\n'
129131
if model.config.get_config_value("IOType") == 'io_stream':
130132
newline += indent + '#pragma HLS DATAFLOW\n'
133+
elif '//hls-fpga-machine-learning insert load weights' in line:
134+
newline = ''
135+
for layer in model.get_layers():
136+
for w in layer.get_weights():
137+
if w.weight_class == 'CompressedWeightVariable':
138+
newline += indent + ' nnet::load_compressed_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format(w.type.name, w.nonzeros, w.name, w.name)
139+
elif w.weight_class == 'ExponentWeightVariable':
140+
newline += indent + ' nnet::load_exponent_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format(w.type.name, w.data_length, w.name, w.name)
141+
else:
142+
newline += indent + ' nnet::load_weights_from_txt<{}, {}>({}, "{}.txt");\n'.format(w.type.name, w.data_length, w.name, w.name)
131143
elif '//hls-fpga-machine-learning insert enqueue' in line:
132144
io_type = model.config.get_config_value("IOType")
133145
if io_type == 'io_parallel':
@@ -143,10 +155,12 @@ def write_axi_wrapper(self, model):
143155
newline += indent + '}\n'
144156
elif io_type == 'io_stream':
145157
newline = ''
158+
newline += 'LOAD_INPUT_OUTER_LOOP:\n'
146159
newline += indent + 'for(unsigned i = 0; i < N_IN / {input_t}::size; ++i) {{\n'
147160
# newline += indent + indent + '#pragma HLS PIPELINE\n'
148161
newline += indent + indent + '{input_t} ctype;\n'
149162
newline += indent + indent + '#pragma HLS DATA_PACK variable=ctype\n'
163+
newline += 'LOAD_INPUT_INNER_LOOP:\n'
150164
newline += indent + indent + 'for(unsigned j = 0; j < {input_t}::size; j++) {{\n'
151165
# newline += indent + indent + indent + '#pragma HLS UNROLL\n'
152166
if self.vivado_accelerator_config.get_interface() == 'axi_stream':
@@ -173,9 +187,11 @@ def write_axi_wrapper(self, model):
173187
newline += indent + '}\n'
174188
elif io_type == 'io_stream':
175189
newline = ''
190+
newline += 'STORE_OUTPUT_OUTER_LOOP:\n'
176191
newline += indent + 'for(unsigned i = 0; i < N_OUT / {result_t}::size; ++i) {{\n'
177192
# newline += indent + indent + '#pragma HLS PIPELINE\n'
178193
newline += indent + indent + '{result_t} ctype = out_local.read();\n'
194+
newline += 'STORE_OUTPUT_INNER_LOOP:\n'
179195
newline += indent + indent + 'for(unsigned j = 0; j < {result_t}::size; j++) {{\n'
180196
# newline += indent + indent + indent + '#pragma HLS UNROLL\n'
181197
if self.vivado_accelerator_config.get_interface() == 'axi_stream':
@@ -192,6 +208,35 @@ def write_axi_wrapper(self, model):
192208
f.close()
193209
fout.close()
194210

211+
def modify_project_cpp(self, model):
212+
'''
213+
Modify the build_prj.tcl and build_lib.sh scripts to add the extra wrapper files and set the top function
214+
'''
215+
filedir = os.path.dirname(os.path.abspath(__file__))
216+
oldfile = '{}/firmware/{}.cpp'.format(model.config.get_output_dir(), model.config.get_project_name())
217+
newfile = '{}/build_prj_axi.tcl'.format(model.config.get_output_dir())
218+
f = open(oldfile, 'r')
219+
fout = open(newfile, 'w')
220+
221+
for line in f.readlines():
222+
if '#pragma HLS INTERFACE axis port=' in line:
223+
newline = ''
224+
elif '#pragma HLS INTERFACE bram port=' in line:
225+
newline = ''
226+
elif 'nnet::load_weights_from_txt' in line:
227+
newline = ''
228+
elif 'nnet::load_exponent_weights_from_txt' in line:
229+
newline = ''
230+
elif 'nnet::load_compressed_weights_from_txt' in line:
231+
newline = ''
232+
else:
233+
newline = line
234+
fout.write(newline)
235+
236+
f.close()
237+
fout.close()
238+
os.rename(newfile, oldfile)
239+
195240
def modify_build_script(self, model):
196241
'''
197242
Modify the build_prj.tcl and build_lib.sh scripts to add the extra wrapper files and set the top function
@@ -373,6 +418,7 @@ def write_hls(self, model):
373418
self.write_driver(model)
374419
self.write_wrapper_test(model)
375420
self.write_axi_wrapper(model)
421+
self.modify_project_cpp(model)
376422
self.modify_build_script(model)
377423
self.write_new_tar(model)
378424

0 commit comments

Comments
 (0)