|
1 | 1 | import glob |
2 | 2 | import os |
| 3 | +import stat |
| 4 | +from pathlib import Path |
3 | 5 | from shutil import copyfile, copytree, rmtree |
4 | 6 |
|
5 | 7 | from hls4ml.backends import get_backend |
@@ -56,49 +58,48 @@ def write_build_script(self, model): |
56 | 58 | model (ModelGraph): the hls4ml model. |
57 | 59 | """ |
58 | 60 |
|
59 | | - filedir = os.path.dirname(os.path.abspath(__file__)) |
60 | | - |
61 | | - # build_prj.tcl |
62 | | - f = open(f'{model.config.get_output_dir()}/project.tcl', 'w') |
63 | | - f.write('variable project_name\n') |
64 | | - f.write(f'set project_name "{model.config.get_project_name()}"\n') |
65 | | - f.write('variable backend\n') |
66 | | - f.write('set backend "vivado"\n') |
67 | | - f.write('variable part\n') |
68 | | - f.write('set part "{}"\n'.format(model.config.get_config_value('Part'))) |
69 | | - f.write('variable clock_period\n') |
70 | | - f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod'))) |
71 | | - f.write('variable clock_uncertainty\n') |
72 | | - f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '0%'))) |
73 | | - f.write('variable version\n') |
74 | | - f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0'))) |
75 | | - f.close() |
| 61 | + filedir = Path(__file__).parent |
| 62 | + |
| 63 | + # project.tcl |
| 64 | + prj_tcl_dst = Path(f'{model.config.get_output_dir()}/project.tcl') |
| 65 | + with open(prj_tcl_dst, 'w') as f: |
| 66 | + f.write('variable project_name\n') |
| 67 | + f.write(f'set project_name "{model.config.get_project_name()}"\n') |
| 68 | + f.write('variable backend\n') |
| 69 | + f.write('set backend "vivado"\n') |
| 70 | + f.write('variable part\n') |
| 71 | + f.write('set part "{}"\n'.format(model.config.get_config_value('Part'))) |
| 72 | + f.write('variable clock_period\n') |
| 73 | + f.write('set clock_period {}\n'.format(model.config.get_config_value('ClockPeriod'))) |
| 74 | + f.write('variable clock_uncertainty\n') |
| 75 | + f.write('set clock_uncertainty {}\n'.format(model.config.get_config_value('ClockUncertainty', '0%'))) |
| 76 | + f.write('variable version\n') |
| 77 | + f.write('set version "{}"\n'.format(model.config.get_config_value('Version', '1.0.0'))) |
76 | 78 |
|
77 | 79 | # build_prj.tcl |
78 | | - srcpath = os.path.join(filedir, '../templates/vivado/build_prj.tcl') |
| 80 | + srcpath = (filedir / '../templates/vivado/build_prj.tcl').resolve() |
79 | 81 | dstpath = f'{model.config.get_output_dir()}/build_prj.tcl' |
80 | 82 | copyfile(srcpath, dstpath) |
81 | 83 |
|
82 | 84 | # vivado_synth.tcl |
83 | | - srcpath = os.path.join(filedir, '../templates/vivado/vivado_synth.tcl') |
| 85 | + srcpath = (filedir / '../templates/vivado/vivado_synth.tcl').resolve() |
84 | 86 | dstpath = f'{model.config.get_output_dir()}/vivado_synth.tcl' |
85 | 87 | copyfile(srcpath, dstpath) |
86 | 88 |
|
87 | 89 | # build_lib.sh |
88 | | - f = open(os.path.join(filedir, '../templates/symbolic/build_lib.sh')) |
89 | | - fout = open(f'{model.config.get_output_dir()}/build_lib.sh', 'w') |
90 | | - |
91 | | - for line in f.readlines(): |
92 | | - line = line.replace('myproject', model.config.get_project_name()) |
93 | | - line = line.replace('mystamp', model.config.get_config_value('Stamp')) |
94 | | - line = line.replace('mylibspath', model.config.get_config_value('HLSLibsPath')) |
95 | | - |
96 | | - if 'LDFLAGS=' in line and not os.path.exists(model.config.get_config_value('HLSLibsPath')): |
97 | | - line = 'LDFLAGS=\n' |
98 | | - |
99 | | - fout.write(line) |
100 | | - f.close() |
101 | | - fout.close() |
| 90 | + build_lib_src = (filedir / '../templates/symbolic/build_lib.sh').resolve() |
| 91 | + build_lib_dst = Path(f'{model.config.get_output_dir()}/build_lib.sh').resolve() |
| 92 | + with open(build_lib_src) as src, open(build_lib_dst, 'w') as dst: |
| 93 | + for line in src.readlines(): |
| 94 | + line = line.replace('myproject', model.config.get_project_name()) |
| 95 | + line = line.replace('mystamp', model.config.get_config_value('Stamp')) |
| 96 | + line = line.replace('mylibspath', model.config.get_config_value('HLSLibsPath')) |
| 97 | + |
| 98 | + if 'LDFLAGS=' in line and not os.path.exists(model.config.get_config_value('HLSLibsPath')): |
| 99 | + line = 'LDFLAGS=\n' |
| 100 | + |
| 101 | + dst.write(line) |
| 102 | + build_lib_dst.chmod(build_lib_dst.stat().st_mode | stat.S_IEXEC) |
102 | 103 |
|
103 | 104 | def write_hls(self, model): |
104 | 105 | print('Writing HLS project') |
|
0 commit comments