Skip to content

Commit 5fbdae8

Browse files
committed
Hardcode weights loading (ensures weights loading works from any dir)
1 parent afed23b commit 5fbdae8

File tree

4 files changed

+21
-23
lines changed

4 files changed

+21
-23
lines changed

hls4ml/model/graph.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -805,32 +805,24 @@ def predict(self, x):
805805
n_inputs = len(self.get_input_variables())
806806
n_outputs = len(self.get_output_variables())
807807

808-
curr_dir = os.getcwd()
809-
os.chdir(self.config.get_output_dir() + '/firmware')
810-
811808
output = []
812809
if n_samples == 1 and n_inputs == 1:
813810
x = [x]
814811

815-
try:
816-
for i in range(n_samples):
817-
predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()]
818-
if n_inputs == 1:
819-
inp = [np.asarray(x[i])]
820-
else:
821-
inp = [np.asarray(xj[i]) for xj in x]
822-
argtuple = inp
823-
argtuple += predictions
824-
argtuple = tuple(argtuple)
825-
top_function(*argtuple)
826-
output.append(predictions)
827-
828-
# Convert to list of numpy arrays (one for each output)
829-
output = [
830-
np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs)
831-
]
832-
finally:
833-
os.chdir(curr_dir)
812+
for i in range(n_samples):
813+
predictions = [np.zeros(yj.size(), dtype=ctype) for yj in self.get_output_variables()]
814+
if n_inputs == 1:
815+
inp = [np.asarray(x[i])]
816+
else:
817+
inp = [np.asarray(xj[i]) for xj in x]
818+
argtuple = inp
819+
argtuple += predictions
820+
argtuple = tuple(argtuple)
821+
top_function(*argtuple)
822+
output.append(predictions)
823+
824+
# Convert to list of numpy arrays (one for each output)
825+
output = [np.asarray([output[i_sample][i_output] for i_sample in range(n_samples)]) for i_output in range(n_outputs)]
834826

835827
if n_samples == 1 and n_outputs == 1:
836828
return output[0][0]

hls4ml/templates/catapult/myproject_bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <algorithm>
77
#include <map>
88

9-
static std::string s_weights_dir = "weights";
9+
// hls-fpga-machine-learning insert weights dir
1010

1111
const char *get_weights_dir() { return s_weights_dir.c_str(); }
1212

hls4ml/writer/catapult_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,9 @@ def write_bridge(self, model):
676676
newline = line.replace('MYPROJECT', format(model.config.get_project_name().upper()))
677677
elif 'myproject' in line:
678678
newline = line.replace('myproject', format(model.config.get_project_name()))
679+
elif '// hls-fpga-machine-learning insert weights dir' in line:
680+
weights_dir = (Path(fout.name).parent / 'firmware/weights').resolve()
681+
newline = f'static std::string s_weights_dir = "{weights_dir}";\n'
679682
elif '// hls-fpga-machine-learning insert bram' in line:
680683
newline = line
681684
for bram in model_brams:

hls4ml/writer/vivado_writer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,10 +725,13 @@ def write_build_script(self, model):
725725
# build_lib.sh
726726
build_lib_src = (filedir / '../templates/vivado/build_lib.sh').resolve()
727727
build_lib_dst = Path(f'{model.config.get_output_dir()}/build_lib.sh').resolve()
728+
weights_dir = (build_lib_dst.parent / 'firmware/weights').resolve()
728729
with open(build_lib_src) as src, open(build_lib_dst, 'w') as dst:
729730
for line in src.readlines():
730731
line = line.replace('myproject', model.config.get_project_name())
731732
line = line.replace('mystamp', model.config.get_config_value('Stamp'))
733+
if line.startswith('WEIGHTS_DIR='):
734+
line = f'WEIGHTS_DIR=\\""{weights_dir}\\""\n'
732735

733736
dst.write(line)
734737
build_lib_dst.chmod(build_lib_dst.stat().st_mode | stat.S_IEXEC)

0 commit comments

Comments
 (0)