@@ -322,6 +322,34 @@ def write_project_header(self, model):
322322 if namespace is not None :
323323 newline += '}\n '
324324
325+ elif '// hls-fpga-machine-learning insert emulator-defines' in line :
326+ newline = line
327+
328+ if model .config .get_writer_config ().get ('WriteEmulationConstants' , False ):
329+ brams_def_str = ', ' .join ([b .definition_cpp (as_reference = False ) for b in model_brams ])
330+ brams_call_str = ', ' .join ([b .name for b in model_brams ])
331+
332+ if model .config .get_config_value ('IOType' ) == 'io_stream' :
333+ input_call_str = ', ' .join ([f'std::get<{ n } >(inputs)' for n in range (len (model_inputs ))])
334+ output_call_str = ', ' .join ([f'std::get<{ n } >(outputs)' for n in range (len (model_outputs ))])
335+ else :
336+ input_call_str = ', ' .join ([f'std::get<{ n } >(inputs).data()' for n in range (len (model_inputs ))])
337+ output_call_str = ', ' .join ([f'std::get<{ n } >(outputs).data()' for n in range (len (model_outputs ))])
338+
339+ newline += (
340+ f'\n inline void { model .config .get_project_name ()} _emulator('
341+ 'inputs_t& inputs, outputs_t& outputs' # the inputs_t should ideally be const
342+ )
343+ if len (model_brams ) > 0 :
344+ newline += ',\n ' + brams_def_str
345+ newline += ') {\n '
346+ newline += indent + model .config .get_project_name () + '(\n '
347+ newline += indent + indent + input_call_str + ',\n '
348+ newline += indent + indent + output_call_str
349+ if len (model_brams ) > 0 :
350+ newline += ',\n ' + indent + indent + brams_call_str
351+ newline += '\n ' + indent + ');\n }\n '
352+
325353 else :
326354 newline = line
327355 fout .write (newline )
@@ -385,6 +413,20 @@ def write_defines(self, model):
385413 if namespace is not None :
386414 newline += '}\n '
387415
416+ elif '// hls-fpga-machine-learning insert emulator-defines' in line :
417+ newline = line
418+
419+ if model .config .get_writer_config ().get ('WriteEmulationConstants' , False ):
420+ if model .config .get_config_value ('IOType' ) == 'io_stream' :
421+ input_types = [f'hls::stream<{ v .type .name } >' for v in model .get_input_variables ()]
422+ output_types = [f'hls::stream<{ v .type .name } >' for v in model .get_output_variables ()]
423+ else :
424+ input_types = [f'std::array<{ v .type .name } , { v .size_cpp ()} >' for v in model .get_input_variables ()]
425+ output_types = [f'std::array<{ v .type .name } , { v .size_cpp ()} >' for v in model .get_output_variables ()]
426+ input_types_str = ', ' .join (input_types )
427+ output_types_str = ', ' .join (output_types )
428+ newline += '\n ' + f'using inputs_t = std::tuple<{ input_types_str } >;'
429+ newline += '\n ' + f'using outputs_t = std::tuple<{ output_types_str } >;\n '
388430 else :
389431 newline = line
390432 fout .write (newline )
0 commit comments