1
- """Tests for global_circulation.structs."""
2
-
1
+ # Copyright 2021 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import pickle
3
15
from typing import Union
4
16
5
17
from absl .testing import absltest
10
22
import numpy as np
11
23
import tree_math
12
24
25
+ from tree_math ._src import test_util
26
+
27
+
13
28
ArrayLike = Union [jnp .ndarray , np .ndarray , float ]
14
29
15
30
@@ -19,7 +34,7 @@ class TestStruct:
19
34
b : ArrayLike
20
35
21
36
22
- class StructsTest (parameterized .TestCase ):
37
+ class StructsTest (test_util .TestCase ):
23
38
24
39
@parameterized .named_parameters (
25
40
dict (testcase_name = 'Scalars' , x = TestStruct (1. , 2. )),
@@ -89,6 +104,11 @@ def testJit(self, x, y, operation):
89
104
np .testing .assert_allclose (jitted .a , unjitted .a )
90
105
np .testing .assert_allclose (jitted .b , unjitted .b )
91
106
107
+ def testPickle (self ):
108
+ struct = TestStruct (1 , 2 )
109
+ restored = pickle .loads (pickle .dumps (struct ))
110
+ self .assertTreeEqual (struct , restored , check_dtypes = True )
111
+
92
112
93
113
if __name__ == '__main__' :
94
114
absltest .main ()
0 commit comments