Skip to content

Commit 8841039

Browse files
author
Quentin Berthet
committed
Add a target parameter to hardware_predict()
1 parent e82f2c9 commit 8841039

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

hls4ml/backends/vitis_accelerator/vitis_accelerator_backend.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ def build(
6161
debug=False,
6262
**kwargs,
6363
):
64-
if target not in ["hw", "hw_emu", "sw_emu"]:
65-
raise Exception("Invalid target, must be one of 'hw', 'hw_emu' or 'sw_emu'")
64+
self._validate_target(target)
6665

6766
if "linux" in sys.platform:
6867

@@ -113,12 +112,12 @@ def dat_to_numpy(self, model):
113112
y = np.loadtxt(output_file, dtype=float).reshape(-1, expected_shape)
114113
return y
115114

116-
def hardware_predict(self, model, x):
115+
def hardware_predict(self, model, x, target="hw"):
116+
self._validate_target(target)
117117
self.numpy_to_dat(model, x)
118-
119118
currdir = os.getcwd()
120119
os.chdir(model.config.get_output_dir())
121-
os.system("make run")
120+
os.system(f"TARGET={target} make run")
122121
os.chdir(currdir)
123122

124123
return self.dat_to_numpy(model)
@@ -151,3 +150,7 @@ def _register_flows(self):
151150
ip_flow_requirements.insert(ip_flow_requirements.index("vivado:apply_templates"), template_flow)
152151

153152
self._default_flow = register_flow("ip", None, requires=ip_flow_requirements, backend=self.name)
153+
154+
def _validate_target(self, target):
155+
if target not in ["hw", "hw_emu", "sw_emu"]:
156+
raise Exception("Invalid target, must be one of 'hw', 'hw_emu' or 'sw_emu'")

hls4ml/model/graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -865,13 +865,13 @@ class TraceData(ctypes.Structure):
865865
else:
866866
return output, trace_output
867867

868-
def hardware_predict(self, x):
868+
def hardware_predict(self, x, **kwargs):
869869
"""Currently only supported for VitisAccelerator backend"""
870870
backend = self.config.config.get('Backend', 'Vivado')
871871
if backend != 'VitisAccelerator':
872872
raise Exception(f"Function unsupported for {backend} backend")
873873

874-
return self.config.backend.hardware_predict(self, x)
874+
return self.config.backend.hardware_predict(self, x, **kwargs)
875875

876876
def build(self, **kwargs):
877877
"""Builds the generated project using HLS compiler.

0 commit comments

Comments
 (0)