Skip to content

Commit 0ae059e

Browse files
shoyertree-math authors
authored andcommitted
Fix pickling of tree_math.struct
PiperOrigin-RevId: 529802186
1 parent 0af9679 commit 0ae059e

File tree

2 files changed

+38
-4
lines changed

2 files changed

+38
-4
lines changed

tree_math/_src/structs.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,16 @@
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
114
"""Helpers for constructing data classes that are JAX and tree-math enabled."""
215

316
import dataclasses
@@ -62,5 +75,6 @@ def tree_unflatten(cls, _, children):
6275
'asdict': asdict,
6376
'astuple': astuple,
6477
'tree_flatten': tree_flatten,
65-
'tree_unflatten': tree_unflatten})
78+
'tree_unflatten': tree_unflatten,
79+
'__module__': cls.__module__})
6680
return jax.tree_util.register_pytree_node_class(cls_as_struct)

tree_math/_src/structs_test.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
1-
"""Tests for global_circulation.structs."""
2-
1+
# Copyright 2021 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import pickle
315
from typing import Union
416

517
from absl.testing import absltest
@@ -10,6 +22,9 @@
1022
import numpy as np
1123
import tree_math
1224

25+
from tree_math._src import test_util
26+
27+
1328
ArrayLike = Union[jnp.ndarray, np.ndarray, float]
1429

1530

@@ -19,7 +34,7 @@ class TestStruct:
1934
b: ArrayLike
2035

2136

22-
class StructsTest(parameterized.TestCase):
37+
class StructsTest(test_util.TestCase):
2338

2439
@parameterized.named_parameters(
2540
dict(testcase_name='Scalars', x=TestStruct(1., 2.)),
@@ -89,6 +104,11 @@ def testJit(self, x, y, operation):
89104
np.testing.assert_allclose(jitted.a, unjitted.a)
90105
np.testing.assert_allclose(jitted.b, unjitted.b)
91106

107+
def testPickle(self):
108+
struct = TestStruct(1, 2)
109+
restored = pickle.loads(pickle.dumps(struct))
110+
self.assertTreeEqual(struct, restored, check_dtypes=True)
111+
92112

93113
if __name__ == '__main__':
94114
absltest.main()

0 commit comments

Comments
 (0)