Skip to content

Commit cd915eb

Browse files
committed
Use smaller model in Conv1D test
1 parent e9bb7ff commit cd915eb

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

hls4ml/backends/quartus/quartus_backend.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,8 @@ def init_conv1d(self, layer):
254254
# - Winograd - use Winograd, if possible
255255
layer.set_attr('implementation', layer.model.config.get_layer_config_value(layer, 'Implementation', 'combination'))
256256

257+
layer.set_attr('n_partitions', 1) #TODO Not used yet as there is no codegen implementation of CNNs for Quartus backend
258+
257259
@layer_optimizer(Conv2D)
258260
def init_conv2d(self, layer):
259261
# This can happen if we assign weights of Dense layer to 1x1 Conv2D
@@ -280,6 +282,8 @@ def init_conv2d(self, layer):
280282
# - im2col - specifically use im2col
281283
# - Winograd - use Winograd, if possible
282284
layer.set_attr('implementation', layer.model.config.get_layer_config_value(layer, 'Implementation', 'combination'))
285+
286+
layer.set_attr('n_partitions', 1) #TODO Not used yet as there is no codegen implementation of CNNs for Quartus backend
283287

284288
@layer_optimizer(LSTM)
285289
def init_lstm(self, layer):

test/pytest/test_conv1d.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,16 @@
1010

1111
@pytest.fixture(scope='module')
1212
def data():
13-
X = np.random.rand(100,100,7)
13+
X = np.random.rand(100,10,4)
1414
return X
1515

1616
@pytest.fixture(scope='module')
1717
def keras_model():
18-
model_path = example_model_path / 'keras/KERAS_conv1d.json'
18+
model_path = example_model_path / 'keras/KERAS_conv1d_small.json'
1919
with model_path.open('r') as f:
2020
jsons = f.read()
2121
model = model_from_json(jsons)
22-
model.load_weights(example_model_path / 'keras/KERAS_conv1d_weights.h5')
22+
model.load_weights(example_model_path / 'keras/KERAS_conv1d_small_weights.h5')
2323
return model
2424

2525
@pytest.fixture

0 commit comments

Comments
 (0)