Skip to content

Commit 06f9cda

Browse files
committed
format
1 parent 4854423 commit 06f9cda

File tree

2 files changed

+30
-21
lines changed

2 files changed

+30
-21
lines changed

hls4ml/converters/onnx_to_hls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def get_input_shape(graph, node):
7676
def get_constant_value(graph, constant_name):
7777
tensor = next((x for x in graph.initializer if x.name == constant_name), None)
7878
from onnx import numpy_helper
79+
7980
return numpy_helper.to_array(tensor)
8081

8182

@@ -274,6 +275,7 @@ def onnx_to_hls(config):
274275
print('Interpreting Model ...')
275276

276277
import onnx
278+
277279
onnx_model = onnx.load(config['OnnxModel']) if isinstance(config['OnnxModel'], str) else config['OnnxModel']
278280

279281
layer_list, input_layers, output_layers = parse_onnx_model(onnx_model)

hls4ml/writer/oneapi_writer.py

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,10 @@ def write_project_cpp(self, model):
102102
project_name = model.config.get_project_name()
103103

104104
filedir = os.path.dirname(os.path.abspath(__file__))
105-
with open(os.path.join(filedir, '../templates/oneapi/firmware/myproject.cpp')) as f, open(
106-
f'{model.config.get_output_dir()}/src/firmware/{project_name}.cpp', 'w'
107-
) as fout:
105+
with (
106+
open(os.path.join(filedir, '../templates/oneapi/firmware/myproject.cpp')) as f,
107+
open(f'{model.config.get_output_dir()}/src/firmware/{project_name}.cpp', 'w') as fout,
108+
):
108109
model_inputs = model.get_input_variables()
109110
model_outputs = model.get_output_variables()
110111
model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram']
@@ -207,9 +208,10 @@ def write_project_header(self, model):
207208
project_name = model.config.get_project_name()
208209

209210
filedir = os.path.dirname(os.path.abspath(__file__))
210-
with open(os.path.join(filedir, '../templates/oneapi/firmware/myproject.h')) as f, open(
211-
f'{model.config.get_output_dir()}/src/firmware/{project_name}.h', 'w'
212-
) as fout:
211+
with (
212+
open(os.path.join(filedir, '../templates/oneapi/firmware/myproject.h')) as f,
213+
open(f'{model.config.get_output_dir()}/src/firmware/{project_name}.h', 'w') as fout,
214+
):
213215
model_inputs = model.get_input_variables()
214216
model_outputs = model.get_output_variables()
215217
# model_brams = [var for var in model.get_weight_variables() if var.storage.lower() == 'bram']
@@ -254,9 +256,10 @@ def write_defines(self, model):
254256
model (ModelGraph): the hls4ml model.
255257
"""
256258
filedir = os.path.dirname(os.path.abspath(__file__))
257-
with open(os.path.join(filedir, '../templates/oneapi/firmware/defines.h')) as f, open(
258-
f'{model.config.get_output_dir()}/src/firmware/defines.h', 'w'
259-
) as fout:
259+
with (
260+
open(os.path.join(filedir, '../templates/oneapi/firmware/defines.h')) as f,
261+
open(f'{model.config.get_output_dir()}/src/firmware/defines.h', 'w') as fout,
262+
):
260263
for line in f.readlines():
261264
# Insert numbers
262265
if '// hls-fpga-machine-learning insert numbers' in line:
@@ -298,9 +301,10 @@ def write_parameters(self, model):
298301
model (ModelGraph): the hls4ml model.
299302
"""
300303
filedir = os.path.dirname(os.path.abspath(__file__))
301-
with open(os.path.join(filedir, '../templates/oneapi/firmware/parameters.h')) as f, open(
302-
f'{model.config.get_output_dir()}/src/firmware/parameters.h', 'w'
303-
) as fout:
304+
with (
305+
open(os.path.join(filedir, '../templates/oneapi/firmware/parameters.h')) as f,
306+
open(f'{model.config.get_output_dir()}/src/firmware/parameters.h', 'w') as fout,
307+
):
304308
for line in f.readlines():
305309
if '// hls-fpga-machine-learning insert includes' in line:
306310
newline = line
@@ -376,9 +380,10 @@ def write_test_bench(self, model):
376380
output_predictions, f'{model.config.get_output_dir()}/tb_data/tb_output_predictions.dat'
377381
)
378382

379-
with open(os.path.join(filedir, '../templates/oneapi/myproject_test.cpp')) as f, open(
380-
f'{model.config.get_output_dir()}/src/{project_name}_test.cpp', 'w'
381-
) as fout:
383+
with (
384+
open(os.path.join(filedir, '../templates/oneapi/myproject_test.cpp')) as f,
385+
open(f'{model.config.get_output_dir()}/src/{project_name}_test.cpp', 'w') as fout,
386+
):
382387
for line in f.readlines():
383388
indent = ' ' * (len(line) - len(line.lstrip(' ')))
384389

@@ -434,9 +439,10 @@ def write_bridge(self, model):
434439
indent = ' '
435440

436441
filedir = os.path.dirname(os.path.abspath(__file__))
437-
with open(os.path.join(filedir, '../templates/oneapi/myproject_bridge.cpp')) as f, open(
438-
f'{model.config.get_output_dir()}/src/{project_name}_bridge.cpp', 'w'
439-
) as fout:
442+
with (
443+
open(os.path.join(filedir, '../templates/oneapi/myproject_bridge.cpp')) as f,
444+
open(f'{model.config.get_output_dir()}/src/{project_name}_bridge.cpp', 'w') as fout,
445+
):
440446
for line in f.readlines():
441447
if 'MYPROJECT' in line:
442448
newline = line.replace('MYPROJECT', format(project_name.upper()))
@@ -511,9 +517,10 @@ def write_build_script(self, model):
511517
# Makefile
512518
filedir = os.path.dirname(os.path.abspath(__file__))
513519
device = model.config.get_config_value('Part')
514-
with open(os.path.join(filedir, '../templates/oneapi/CMakeLists.txt')) as f, open(
515-
f'{model.config.get_output_dir()}/CMakeLists.txt', 'w'
516-
) as fout:
520+
with (
521+
open(os.path.join(filedir, '../templates/oneapi/CMakeLists.txt')) as f,
522+
open(f'{model.config.get_output_dir()}/CMakeLists.txt', 'w') as fout,
523+
):
517524
for line in f.readlines():
518525
line = line.replace('myproject', model.config.get_project_name())
519526
line = line.replace('mystamp', model.config.get_config_value('Stamp'))

0 commit comments

Comments
 (0)