Skip to content

Commit 6d2c81c

Browse files
committed
Fixed error with multi-input neural networks and batches.
Also added a test for this case.
1 parent ca375a6 commit 6d2c81c

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

source/FAST/Algorithms/NeuralNetwork/NeuralNetwork.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ std::unordered_map<std::string, Tensor::pointer> NeuralNetwork::processInputData
179179

180180
if(m_batchSize == -1) {
181181
m_batchSize = dataList.getSize();
182-
} else {
182+
} else if(m_batchSize != dataList.getSize()) {
183183
throw Exception("Inconsistent batch size accross input nodes");
184184
}
185185
} else {

source/FAST/Algorithms/NeuralNetwork/Tests.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,31 @@ TEST_CASE("Multi input single output network", "[fast][neuralnetwork]") {
152152
}
153153
}
154154

155+
TEST_CASE("Multi input single output network with batch", "[fast][neuralnetwork]") {
156+
for(auto& engine : InferenceEngineManager::getEngineList()) {
157+
auto importer = ImageFileImporter::New();
158+
importer->setFilename(Config::getTestDataPath() + "US/JugularVein/US-2D_0.mhd");
159+
auto image = importer->runAndGetOutputData<Image>();
160+
auto batch1 = Batch::create({image, image});
161+
auto batch2 = Batch::create({image, image});
162+
163+
auto network = NeuralNetwork::New();
164+
network->setInferenceEngine(engine);
165+
network->load(join(Config::getTestDataPath(),
166+
"NeuralNetworkModels/multi_input_single_output." +
167+
getModelFileExtension(network->getInferenceEngine()->getPreferredModelFormat())));
168+
network->connect(0, batch1);
169+
network->connect(1, batch2);
170+
auto batch = network->runAndGetOutputData<Batch>();
171+
auto list = batch->get().getTensors();
172+
REQUIRE(list.size() == 2);
173+
auto data = list[0];
174+
// We are expecting a tensor output with dimensions (6)
175+
REQUIRE(data->getShape().getDimensions() == 1);
176+
CHECK(data->getShape()[0] == 6);
177+
}
178+
}
179+
155180
TEST_CASE("Single input multi output network", "[fast][neuralnetwork]") {
156181
for(auto& engine : InferenceEngineManager::getEngineList()) {
157182
#ifdef WIN32

0 commit comments

Comments
 (0)