Skip to content

Commit 6630bd8

Browse files
committed
test: add io tests
1 parent c0ac379 commit 6630bd8

File tree

1 file changed

+170
-0
lines changed

1 file changed

+170
-0
lines changed

brainpy/base/tests/test_io.py

Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# -*- coding: utf-8 -*-
2+
3+
4+
import brainpy as bp
5+
import brainpy.math as bm
6+
import unittest
7+
8+
9+
class TestIO1(unittest.TestCase):
10+
def __init__(self, *args, **kwargs):
11+
super(TestIO1, self).__init__(*args, **kwargs)
12+
13+
rng = bm.random.RandomState()
14+
15+
class IO1(bp.dyn.DynamicalSystem):
16+
def __init__(self):
17+
super(IO1, self).__init__()
18+
19+
self.a = bm.Variable(bm.zeros(1))
20+
self.b = bm.Variable(bm.ones(3))
21+
self.c = bm.Variable(bm.ones((3, 4)))
22+
self.d = bm.Variable(bm.ones((2, 3, 4)))
23+
24+
class IO2(bp.dyn.DynamicalSystem):
25+
def __init__(self):
26+
super(IO2, self).__init__()
27+
28+
self.a = bm.Variable(rng.rand(3))
29+
self.b = bm.Variable(rng.randn(10))
30+
31+
io1 = IO1()
32+
io2 = IO2()
33+
io1.a2 = io2.a
34+
io1.b2 = io2.b
35+
io2.a2 = io1.a
36+
io2.b2 = io2.b
37+
38+
self.net = bp.dyn.Container(io1, io2)
39+
40+
print(self.net.vars().keys())
41+
print(self.net.vars().unique().keys())
42+
43+
def test_h5(self):
44+
bp.base.save_as_h5('io_test_tmp.h5', self.net.vars())
45+
bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
46+
47+
bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars())
48+
bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
49+
50+
def test_h5_postfix(self):
51+
with self.assertRaises(ValueError):
52+
bp.base.save_as_h5('io_test_tmp.h52', self.net.vars())
53+
with self.assertRaises(ValueError):
54+
bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
55+
56+
def test_npz(self):
57+
bp.base.save_as_npz('io_test_tmp.npz', self.net.vars())
58+
bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True)
59+
60+
bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
61+
bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)
62+
63+
def test_npz_postfix(self):
64+
with self.assertRaises(ValueError):
65+
bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars())
66+
with self.assertRaises(ValueError):
67+
bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)
68+
69+
def test_pkl(self):
70+
bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars())
71+
bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)
72+
73+
bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars())
74+
bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)
75+
76+
def test_pkl_postfix(self):
77+
with self.assertRaises(ValueError):
78+
bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
79+
with self.assertRaises(ValueError):
80+
bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)
81+
82+
def test_mat(self):
83+
bp.base.save_as_mat('io_test_tmp.mat', self.net.vars())
84+
bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True)
85+
86+
def test_mat_postfix(self):
87+
with self.assertRaises(ValueError):
88+
bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars())
89+
with self.assertRaises(ValueError):
90+
bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)
91+
92+
93+
class TestIO2(unittest.TestCase):
94+
def __init__(self, *args, **kwargs):
95+
super(TestIO2, self).__init__(*args, **kwargs)
96+
97+
rng = bm.random.RandomState()
98+
99+
class IO1(bp.dyn.DynamicalSystem):
100+
def __init__(self):
101+
super(IO1, self).__init__()
102+
103+
self.a = bm.Variable(bm.zeros(1))
104+
self.b = bm.Variable(bm.ones(3))
105+
self.c = bm.Variable(bm.ones((3, 4)))
106+
self.d = bm.Variable(bm.ones((2, 3, 4)))
107+
108+
class IO2(bp.dyn.DynamicalSystem):
109+
def __init__(self):
110+
super(IO2, self).__init__()
111+
112+
self.a = bm.Variable(rng.rand(3))
113+
self.b = bm.Variable(rng.randn(10))
114+
115+
io1 = IO1()
116+
io2 = IO2()
117+
118+
self.net = bp.dyn.Container(io1, io2)
119+
120+
print(self.net.vars().keys())
121+
print(self.net.vars().unique().keys())
122+
123+
def test_h5(self):
124+
bp.base.save_as_h5('io_test_tmp.h5', self.net.vars())
125+
bp.base.load_by_h5('io_test_tmp.h5', self.net, verbose=True)
126+
127+
bp.base.save_as_h5('io_test_tmp.hdf5', self.net.vars())
128+
bp.base.load_by_h5('io_test_tmp.hdf5', self.net, verbose=True)
129+
130+
def test_h5_postfix(self):
131+
with self.assertRaises(ValueError):
132+
bp.base.save_as_h5('io_test_tmp.h52', self.net.vars())
133+
with self.assertRaises(ValueError):
134+
bp.base.load_by_h5('io_test_tmp.h52', self.net, verbose=True)
135+
136+
def test_npz(self):
137+
bp.base.save_as_npz('io_test_tmp.npz', self.net.vars())
138+
bp.base.load_by_npz('io_test_tmp.npz', self.net, verbose=True)
139+
140+
bp.base.save_as_npz('io_test_tmp_compressed.npz', self.net.vars(), compressed=True)
141+
bp.base.load_by_npz('io_test_tmp_compressed.npz', self.net, verbose=True)
142+
143+
def test_npz_postfix(self):
144+
with self.assertRaises(ValueError):
145+
bp.base.save_as_npz('io_test_tmp.npz2', self.net.vars())
146+
with self.assertRaises(ValueError):
147+
bp.base.load_by_npz('io_test_tmp.npz2', self.net, verbose=True)
148+
149+
def test_pkl(self):
150+
bp.base.save_as_pkl('io_test_tmp.pkl', self.net.vars())
151+
bp.base.load_by_pkl('io_test_tmp.pkl', self.net, verbose=True)
152+
153+
bp.base.save_as_pkl('io_test_tmp.pickle', self.net.vars())
154+
bp.base.load_by_pkl('io_test_tmp.pickle', self.net, verbose=True)
155+
156+
def test_pkl_postfix(self):
157+
with self.assertRaises(ValueError):
158+
bp.base.save_as_pkl('io_test_tmp.pkl2', self.net.vars())
159+
with self.assertRaises(ValueError):
160+
bp.base.load_by_pkl('io_test_tmp.pkl2', self.net, verbose=True)
161+
162+
def test_mat(self):
163+
bp.base.save_as_mat('io_test_tmp.mat', self.net.vars())
164+
bp.base.load_by_mat('io_test_tmp.mat', self.net, verbose=True)
165+
166+
def test_mat_postfix(self):
167+
with self.assertRaises(ValueError):
168+
bp.base.save_as_mat('io_test_tmp.mat2', self.net.vars())
169+
with self.assertRaises(ValueError):
170+
bp.base.load_by_mat('io_test_tmp.mat2', self.net, verbose=True)

0 commit comments

Comments
 (0)