Skip to content

Commit e284b54

Browse files
committed
Subsampling connections fix
1 parent 6c2d718 commit e284b54

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

nn-core/src/main/java/com/github/neuralnetworks/architecture/ConnectionFactory.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ public Conv2DConnection conv2d(Layer inputLayer, Layer outputLayer, int inputFea
6969
return result;
7070
}
7171

72-
public Subsampling2DConnection subsampling2D(Layer inputLayer, Layer outputLayer, int inputFeatureMapColumns, int inputFeatureMapRows, int subsamplingRegionRows, int subsamplingRegionCols, int filters) {
73-
return new Subsampling2DConnection(inputLayer, outputLayer, inputFeatureMapColumns, inputFeatureMapRows, subsamplingRegionRows, subsamplingRegionCols, filters);
72+
public Subsampling2DConnection subsampling2D(Layer inputLayer, Layer outputLayer, int inputFeatureMapRows, int inputFeatureMapColumns, int subsamplingRegionRows, int subsamplingRegionCols, int filters) {
73+
return new Subsampling2DConnection(inputLayer, outputLayer, inputFeatureMapRows, inputFeatureMapColumns, subsamplingRegionRows, subsamplingRegionCols, filters);
7474
}
7575

7676
public boolean useSharedWeights() {

nn-core/src/main/java/com/github/neuralnetworks/architecture/Subsampling2DConnection.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ public class Subsampling2DConnection extends ConnectionsImpl {
1414
protected int outputFeatureMapRows;
1515
protected int filters;
1616

17-
public Subsampling2DConnection(Layer inputLayer, Layer outputLayer, int inputFeatureMapColumns, int inputFeatureMapRows, int subsamplingRegionRows, int subsamplingRegionCols, int filters) {
17+
public Subsampling2DConnection(Layer inputLayer, Layer outputLayer, int inputFeatureMapRows, int inputFeatureMapColumns, int subsamplingRegionRows, int subsamplingRegionCols, int filters) {
1818
super(inputLayer, outputLayer);
19-
this.inputFeatureMapColumns = inputFeatureMapColumns;
2019
this.inputFeatureMapRows = inputFeatureMapRows;
20+
this.inputFeatureMapColumns = inputFeatureMapColumns;
2121
this.filters = filters;
2222
setDimensions(subsamplingRegionRows, subsamplingRegionCols);
2323
}

nn-core/src/test/java/com/github/neuralnetworks/test/CNNTest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,7 +654,7 @@ public void testCNNMLPFF() {
654654
ValuesProvider cnnvp = TensorFactory.tensorProvider(cnn, 1, Environment.getInstance().getUseDataSharedMemory());
655655
Tensor cnnin = cnnvp.get(cnn.getInputLayer());
656656
cnnin.set(0.2f, 0, 0, 0, 0);
657-
cnnin.set(0.6f, 0, 0, 1, 0);
657+
cnnin.set(0.6f, 0, 1, 0, 0);
658658

659659
// MLP
660660
NeuralNetworkImpl mlp = NNFactory.mlpSigmoid(new int[] { 2, 1 }, false);

0 commit comments

Comments
 (0)