Skip to content

Commit b6adcca

Browse files
committed
[MINOR] Fix python neural network tests (race condition random.seed)
1 parent edfce10 commit b6adcca

File tree

2 files changed

+66
-69
lines changed

2 files changed

+66
-69
lines changed

src/main/python/tests/nn/test_neural_network.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,13 @@
2525
from tests.nn.neural_network import NeuralNetwork
2626
from systemds.script_building.script import DMLScript
2727

28-
# Seed for the input matrix
29-
np.random.seed(42)
30-
31-
3228
class TestNeuralNetwork(unittest.TestCase):
3329
sds: SystemDSContext = None
3430

3531
@classmethod
3632
def setUpClass(cls):
3733
cls.sds = SystemDSContext()
34+
np.random.seed(42)
3835
cls.X = np.random.rand(6, 1)
3936
cls.exp_out = np.array([
4037
-0.37768756, -0.47785831, -0.95870362,

src/test/java/org/apache/sysds/test/component/matrix/libMatrixReorg/RollTest.java

Lines changed: 65 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -42,82 +42,82 @@
4242
*/
4343
@RunWith(Parameterized.class)
4444
public class RollTest {
45-
private final int shift;
45+
private final int shift;
4646

47-
// Input matrices
48-
private MatrixBlock inputSparse;
49-
private MatrixBlock inputDense;
47+
// Input matrices
48+
private MatrixBlock inputSparse;
49+
private MatrixBlock inputDense;
5050

51-
/**
52-
* Constructor for parameterized test cases.
53-
*
54-
* @param rows Number of rows in the test matrix.
55-
* @param cols Number of columns in the test matrix.
56-
* @param sparsity Sparsity level of the test matrix (0.0 to 1.0).
57-
* @param shift Shift value for the roll operation.
58-
*/
59-
public RollTest(int rows, int cols, double sparsity, int shift) {
60-
this.shift = shift;
51+
/**
52+
* Constructor for parameterized test cases.
53+
*
54+
* @param rows Number of rows in the test matrix.
55+
* @param cols Number of columns in the test matrix.
56+
* @param sparsity Sparsity level of the test matrix (0.0 to 1.0).
57+
* @param shift Shift value for the roll operation.
58+
*/
59+
public RollTest(int rows, int cols, double sparsity, int shift) {
60+
this.shift = shift;
6161

62-
// Generate a MatrixBlock with the given parameters
63-
inputSparse = TestUtils.generateTestMatrixBlock(rows, cols, 0, 10, sparsity, 1);
64-
inputSparse.recomputeNonZeros();
62+
// Generate a MatrixBlock with the given parameters
63+
inputSparse = TestUtils.generateTestMatrixBlock(rows, cols, 0, 10, sparsity, 1);
64+
inputSparse.recomputeNonZeros();
6565

66-
inputDense = new MatrixBlock(rows, cols, false); // false indicates dense
67-
inputDense.copy(inputSparse, false); // Copy without maintaining sparsity
68-
inputDense.recomputeNonZeros();
69-
}
66+
inputDense = new MatrixBlock(rows, cols, false); // false indicates dense
67+
inputDense.copy(inputSparse, false); // Copy without maintaining sparsity
68+
inputDense.recomputeNonZeros();
69+
}
7070

71-
/**
72-
* Defines the parameters for the test cases.
73-
* Each Object[] contains {rows, cols, sparsity, shift}.
74-
*
75-
* @return Collection of test parameters.
76-
*/
77-
@Parameters(name = "Rows: {0}, Cols: {1}, Sparsity: {2}, Shift: {3}")
78-
public static Collection<Object[]> data() {
79-
List<Object[]> tests = new ArrayList<>();
71+
/**
72+
* Defines the parameters for the test cases.
73+
* Each Object[] contains {rows, cols, sparsity, shift}.
74+
*
75+
* @return Collection of test parameters.
76+
*/
77+
@Parameters(name = "Rows: {0}, Cols: {1}, Sparsity: {2}, Shift: {3}")
78+
public static Collection<Object[]> data() {
79+
List<Object[]> tests = new ArrayList<>();
8080

81-
// Define various sizes, sparsity levels, and shift values
82-
int[] rows = {1, 19, 1001, 2017};
83-
int[] cols = {1, 17, 1001, 2017};
84-
double[] sparsities = {0.01, 0.1, 0.7, 1.0};
85-
int[] shifts = {0, 1, 5, 10, 15};
81+
// Define various sizes, sparsity levels, and shift values
82+
int[] rows = {1, 19, 1001, 2017};
83+
int[] cols = {1, 17, 1001, 2017};
84+
double[] sparsities = {0.01, 0.1, 0.7, 1.0};
85+
int[] shifts = {0, 1, 5, 10, 15};
8686

87-
// Generate all combinations of sizes, sparsities, and shifts
88-
for (int row : rows) {
89-
for (int col : cols) {
90-
for (double sparsity : sparsities) {
91-
for (int shift : shifts) {
92-
tests.add(new Object[]{row, col, sparsity, shift});
93-
}
94-
}
95-
}
96-
}
97-
return tests;
98-
}
87+
// Generate all combinations of sizes, sparsities, and shifts
88+
for (int row : rows) {
89+
for (int col : cols) {
90+
for (double sparsity : sparsities) {
91+
for (int shift : shifts) {
92+
tests.add(new Object[]{row, col, sparsity, shift});
93+
}
94+
}
95+
}
96+
}
97+
return tests;
98+
}
9999

100-
/**
101-
* The actual test method that performs the roll operation on both
102-
* sparse and dense matrices and compares the results.
103-
*/
104-
@Test
105-
public void test() {
106-
try {
107-
IndexFunction op = new RollIndex(shift);
108-
MatrixBlock outputDense = inputDense.reorgOperations(
100+
/**
101+
* The actual test method that performs the roll operation on both
102+
* sparse and dense matrices and compares the results.
103+
*/
104+
@Test
105+
public void test() {
106+
try {
107+
IndexFunction op = new RollIndex(shift);
108+
MatrixBlock outputDense = inputDense.reorgOperations(
109109
new ReorgOperator(op), new MatrixBlock(), 0, 0, 0);
110-
MatrixBlock outputSparse = inputSparse.reorgOperations(
110+
MatrixBlock outputSparse = inputSparse.reorgOperations(
111111
new ReorgOperator(op), new MatrixBlock(), 0, 0, 0);
112-
outputSparse.sparseToDense();
112+
outputSparse.sparseToDense();
113113

114-
// Compare the dense representations of both outputs
115-
TestUtils.compareMatrices(outputSparse, outputDense, 1e-9,
114+
// Compare the dense representations of both outputs
115+
TestUtils.compareMatrices(outputSparse, outputDense, 1e-9,
116116
"Compare Sparse and Dense Roll Results");
117117

118-
} catch (Exception e) {
119-
e.printStackTrace();
120-
fail("Exception occurred during roll function test: " + e.getMessage());
121-
}
122-
}
118+
} catch (Exception e) {
119+
e.printStackTrace();
120+
fail("Exception occurred during roll function test: " + e.getMessage());
121+
}
122+
}
123123
}

0 commit comments

Comments
 (0)