|
14 | 14 | @pytest.fixture(scope='module') |
15 | 15 | def keras_model(): |
16 | 16 | model = Sequential() |
17 | | - model.add(Dense(10, activation='softmax', input_shape=(15,))) |
| 17 | + model.add(Dense(10, kernel_initializer='zeros', use_bias=False, input_shape=(15,))) |
18 | 18 | model.compile() |
19 | 19 | return model |
20 | 20 |
|
@@ -69,3 +69,40 @@ def test_write_weights_txt(keras_model, write_weights_txt, backend): |
69 | 69 |
|
70 | 70 | txt_written = os.path.exists(odir + '/firmware/weights/w2.txt') |
71 | 71 | assert txt_written == write_weights_txt |
| 72 | + |
| 73 | + |
| 74 | +@pytest.mark.skip(reason='Skipping for now as it needs the installation of the compiler.') |
| 75 | +@pytest.mark.parametrize('backend', ['Vivado', 'Vitis']) |
| 76 | +@pytest.mark.parametrize('tb_output_stream', ['stdout', 'file', 'both']) |
| 77 | +def test_tb_output_stream(capfd, keras_model, tb_output_stream, backend): |
| 78 | + |
| 79 | + config = hls4ml.utils.config_from_keras_model(keras_model, granularity='name') |
| 80 | + odir = str(test_root_path / f'hls4mlprj_tb_output_stream_{tb_output_stream}_{backend}') |
| 81 | + if os.path.exists(odir): |
| 82 | + shutil.rmtree(odir) |
| 83 | + |
| 84 | + hls_model = hls4ml.converters.convert_from_keras_model( |
| 85 | + keras_model, |
| 86 | + io_type='io_stream', |
| 87 | + hls_config=config, |
| 88 | + output_dir=odir, |
| 89 | + backend=backend, |
| 90 | + tb_output_stream=tb_output_stream, |
| 91 | + ) |
| 92 | + hls_model.build(csim=True, synth=False) |
| 93 | + |
| 94 | + # Check the output based on tb_output_stream |
| 95 | + tb_file_path = os.path.join(odir, 'tb_data/csim_results.log') |
| 96 | + |
| 97 | + with open(tb_file_path) as tb_file: |
| 98 | + tb_content = tb_file.read() |
| 99 | + if tb_output_stream in ['file', 'both']: |
| 100 | + assert len(tb_content) > 0, 'Testbench output file expected to contain model outputs, but is empty' |
| 101 | + else: |
| 102 | + assert len(tb_content) == 0, 'Testbench output file expected to be empty, but contains data' |
| 103 | + |
| 104 | + captured = capfd.readouterr() |
| 105 | + if tb_output_stream in ['stdout', 'both']: |
| 106 | + assert '0 0 0 0 0 0 0 0 0 0' in captured.out, 'Expected model output not found in stdout' |
| 107 | + else: |
| 108 | + assert '0 0 0 0 0 0 0 0 0 0' not in captured.out, 'Model output should not be printed to stdout' |
0 commit comments