@@ -61,8 +61,7 @@ def build(
6161 debug = False ,
6262 ** kwargs ,
6363 ):
64- if target not in ["hw" , "hw_emu" , "sw_emu" ]:
65- raise Exception ("Invalid target, must be one of 'hw', 'hw_emu' or 'sw_emu'" )
64+ self ._validate_target (target )
6665
6766 if "linux" in sys .platform :
6867
@@ -113,12 +112,12 @@ def dat_to_numpy(self, model):
113112 y = np .loadtxt (output_file , dtype = float ).reshape (- 1 , expected_shape )
114113 return y
115114
116- def hardware_predict (self , model , x ):
115+ def hardware_predict (self , model , x , target = "hw" ):
116+ self ._validate_target (target )
117117 self .numpy_to_dat (model , x )
118-
119118 currdir = os .getcwd ()
120119 os .chdir (model .config .get_output_dir ())
121- os .system (" make run" )
120+ os .system (f"TARGET= { target } make run" )
122121 os .chdir (currdir )
123122
124123 return self .dat_to_numpy (model )
@@ -151,3 +150,7 @@ def _register_flows(self):
151150 ip_flow_requirements .insert (ip_flow_requirements .index ("vivado:apply_templates" ), template_flow )
152151
153152 self ._default_flow = register_flow ("ip" , None , requires = ip_flow_requirements , backend = self .name )
153+
154+ def _validate_target (self , target ):
155+ if target not in ["hw" , "hw_emu" , "sw_emu" ]:
156+ raise Exception ("Invalid target, must be one of 'hw', 'hw_emu' or 'sw_emu'" )
0 commit comments