5
5
import json
6
6
7
7
import numpy as np
8
+ import jax
9
+ import pydantic .v1 as pd
8
10
9
11
from jax .tree_util import tree_flatten as jax_tree_flatten
10
12
from jax .tree_util import tree_unflatten as jax_tree_unflatten
16
18
class JaxObject (Tidy3dBaseModel ):
17
19
"""Abstract class that makes a :class:`.Tidy3dBaseModel` jax-compatible through inheritance."""
18
20
19
- """Shortcut to get names of all fields that have jax components."""
21
+ _tidy3d_class = Tidy3dBaseModel
22
+
23
+ """Shortcut to get names of fields with certain properties."""
20
24
21
25
@classmethod
22
- def get_jax_field_names (cls ) -> List [str ]:
23
- """Returns list of field names that have a ``jax_field_type ``."""
24
- adjoint_fields = []
26
+ def _get_field_names (cls , field_key : str ) -> List [str ]:
27
+ """Get all fields where ``field_key`` defined in the ``pydantic.Field ``."""
28
+ fields = []
25
29
for field_name , model_field in cls .__fields__ .items ():
26
- jax_field_type = model_field .field_info .extra .get ("jax_field" )
27
- if jax_field_type :
28
- adjoint_fields .append (field_name )
29
- return adjoint_fields
30
+ field_value = model_field .field_info .extra .get (field_key )
31
+ if field_value :
32
+ fields .append (field_name )
33
+ return fields
34
+
35
+ @classmethod
36
+ def get_jax_field_names (cls ) -> List [str ]:
37
+ """Returns list of field names where ``jax_field=True``."""
38
+ return cls ._get_field_names ("jax_field" )
39
+
40
+ @classmethod
41
+ def get_jax_leaf_names (cls ) -> List [str ]:
42
+ """Returns list of field names where ``stores_jax_for`` defined."""
43
+ return cls ._get_field_names ("stores_jax_for" )
44
+
45
+ @classmethod
46
+ def get_jax_field_names_all (cls ) -> List [str ]:
47
+ """Returns list of field names where ``jax_field=True`` or ``stores_jax_for`` defined."""
48
+ jax_field_names = cls .get_jax_field_names ()
49
+ jax_leaf_names = cls .get_jax_leaf_names ()
50
+ return list (set (jax_field_names + jax_leaf_names ))
51
+
52
+ @property
53
+ def jax_fields (self ) -> dict :
54
+ """Get dictionary of ``jax`` fields."""
55
+
56
+ # TODO: don't use getattr, define this dictionary better
57
+ jax_field_names = self .get_jax_field_names ()
58
+ return {key : getattr (self , key ) for key in jax_field_names }
30
59
31
60
"""Methods needed for jax to register arbitrary classes."""
32
61
33
62
def tree_flatten (self ) -> Tuple [list , dict ]:
34
- """How to flatten a :class:`.JaxObject` instance into a pytree."""
63
+ """How to flatten a :class:`.JaxObject` instance into a `` pytree`` ."""
35
64
children = []
36
65
aux_data = self .dict ()
37
- for field_name in self .get_jax_field_names ():
66
+
67
+ for field_name in self .get_jax_field_names_all ():
38
68
field = getattr (self , field_name )
39
69
sub_children , sub_aux_data = jax_tree_flatten (field )
40
70
children .append (sub_children )
41
71
aux_data [field_name ] = sub_aux_data
42
72
43
- def fix_polyslab (geo_dict : dict ) -> None :
44
- """Recursively Fix a dictionary possibly containing a polyslab geometry."""
45
- if geo_dict ["type" ] == "PolySlab" :
46
- vertices = geo_dict ["vertices" ]
47
- geo_dict ["vertices" ] = vertices .tolist ()
48
- elif geo_dict ["type" ] == "GeometryGroup" :
49
- for sub_geo_dict in geo_dict ["geometries" ]:
50
- fix_polyslab (sub_geo_dict )
51
- elif geo_dict ["type" ] == "ClipOperation" :
52
- fix_polyslab (geo_dict ["geometry_a" ])
53
- fix_polyslab (geo_dict ["geometry_b" ])
54
-
55
- def fix_monitor (mnt_dict : dict ) -> None :
56
- """Fix a frequency containing monitor."""
57
- if "freqs" in mnt_dict :
58
- freqs = mnt_dict ["freqs" ]
59
- if isinstance (freqs , np .ndarray ):
60
- mnt_dict ["freqs" ] = freqs .tolist ()
61
-
62
- # fixes bug with jax handling 2D numpy array in polyslab vertices
63
- if aux_data .get ("type" , "" ) == "JaxSimulation" :
64
- structures = aux_data ["structures" ]
65
- for _i , structure in enumerate (structures ):
66
- geometry = structure ["geometry" ]
67
- fix_polyslab (geometry )
68
- for monitor in aux_data ["monitors" ]:
69
- fix_monitor (monitor )
70
- for monitor in aux_data ["output_monitors" ]:
71
- fix_monitor (monitor )
73
+ def fix_numpy (value : Any ) -> Any :
74
+ """Recursively convert any ``numpy`` array in the value to nested list."""
75
+ if isinstance (value , (tuple , list )):
76
+ return [fix_numpy (val ) for val in value ]
77
+ if isinstance (value , np .ndarray ):
78
+ return value .tolist ()
79
+ if isinstance (value , dict ):
80
+ return {key : fix_numpy (val ) for key , val in value .items ()}
81
+ else :
82
+ return value
83
+
84
+ aux_data = fix_numpy (aux_data )
72
85
73
86
return children , aux_data
74
87
75
88
@classmethod
76
89
def tree_unflatten (cls , aux_data : dict , children : list ) -> JaxObject :
77
- """How to unflatten a pytree into a :class:`.JaxObject` instance."""
90
+ """How to unflatten a `` pytree`` into a :class:`.JaxObject` instance."""
78
91
self_dict = aux_data .copy ()
79
- for field_name , sub_children in zip (cls .get_jax_field_names (), children ):
92
+ for field_name , sub_children in zip (cls .get_jax_field_names_all (), children ):
80
93
sub_aux_data = aux_data [field_name ]
81
94
field = jax_tree_unflatten (sub_aux_data , sub_children )
82
95
self_dict [field_name ] = field
@@ -85,38 +98,110 @@ def tree_unflatten(cls, aux_data: dict, children: list) -> JaxObject:
85
98
86
99
"""Type conversion helpers."""
87
100
101
+ def to_tidy3d (self : JaxObject ) -> Tidy3dBaseModel :
102
+ """Convert :class:`.JaxObject` instance to :class:`.Tidy3dBaseModel` instance."""
103
+
104
+ self_dict = self .dict (exclude = self .exclude_fields_leafs_only )
105
+
106
+ for key in self .get_jax_field_names ():
107
+ sub_field = self .jax_fields [key ]
108
+
109
+ # TODO: simplify this logic
110
+ if isinstance (sub_field , (tuple , list )):
111
+ self_dict [key ] = [x .to_tidy3d () for x in sub_field ]
112
+ else :
113
+ self_dict [key ] = sub_field .to_tidy3d ()
114
+ # end TODO
115
+
116
+ return self ._tidy3d_class .parse_obj (self_dict )
117
+
88
118
@classmethod
89
119
def from_tidy3d (cls , tidy3d_obj : Tidy3dBaseModel ) -> JaxObject :
90
120
"""Convert :class:`.Tidy3dBaseModel` instance to :class:`.JaxObject`."""
91
121
obj_dict = tidy3d_obj .dict (exclude = {"type" })
122
+
123
+ for key in cls .get_jax_field_names ():
124
+ sub_field_type = cls .__fields__ [key ].type_
125
+ tidy3d_sub_field = getattr (tidy3d_obj , key )
126
+
127
+ # TODO: simplify this logic
128
+ if isinstance (tidy3d_sub_field , (tuple , list )):
129
+ obj_dict [key ] = [sub_field_type .from_tidy3d (x ) for x in tidy3d_sub_field ]
130
+ else :
131
+ obj_dict [key ] = sub_field_type .from_tidy3d (tidy3d_sub_field )
132
+ # end TODO
133
+
92
134
return cls .parse_obj (obj_dict )
93
135
136
+ @property
137
+ def exclude_fields_leafs_only (self ) -> set :
138
+ """Fields to exclude from ``self.dict()``, ``"type"`` and all ``jax`` leafs."""
139
+ return set (["type" ] + self .get_jax_leaf_names ())
140
+
141
+ """Accounting with jax and regular fields."""
142
+
143
+ @pd .root_validator (pre = True )
144
+ def handle_jax_kwargs (cls , values : dict ) -> dict :
145
+ """Pass jax inputs to the jax fields and pass untraced values to the regular fields."""
146
+
147
+ # for all jax-traced fields
148
+ for jax_name in cls .get_jax_leaf_names ():
149
+ # if a value was passed to the object for the regular field
150
+ orig_name = cls .__fields__ [jax_name ].field_info .extra .get ("stores_jax_for" )
151
+ val = values .get (orig_name )
152
+ if val is not None :
153
+
154
+ # try adding the sanitized (no trace) version to the regular field
155
+ try :
156
+ values [orig_name ] = jax .lax .stop_gradient (val )
157
+
158
+ # if it doesnt work, just pass the raw value (necessary to handle inf strings)
159
+ except TypeError :
160
+ values [orig_name ] = val
161
+
162
+ # if the jax name was not specified directly, use the original traced value
163
+ if jax_name not in values :
164
+ values [jax_name ] = val
165
+
166
+ return values
167
+
168
+ @pd .root_validator (pre = True )
169
+ def handle_array_jax_leafs (cls , values : dict ) -> dict :
170
+ """Convert jax_leafs that are passed as numpy arrays."""
171
+ for jax_name in cls .get_jax_leaf_names ():
172
+ val = values .get (jax_name )
173
+ if isinstance (val , np .ndarray ):
174
+ values [jax_name ] = val .tolist ()
175
+ return values
176
+
94
177
""" IO """
95
178
179
+ # TODO: replace with JaxObject json encoder
180
+
96
181
def _json (self , * args , ** kwargs ) -> str :
97
182
"""Overwritten method to get the json string to store in the files."""
98
183
99
184
json_string_og = super ()._json (* args , ** kwargs )
100
185
json_dict = json .loads (json_string_og )
101
186
102
- def strip_data_array (sub_dict : dict ) -> None :
103
- """Strip any elements of the dictionary with type "JaxDataArray", replace with tag."""
187
+ def strip_data_array (val : Any ) -> Any :
188
+ """Recursively strip any elements with type "JaxDataArray", replace with tag."""
104
189
105
- for key , val in sub_dict .items ():
190
+ if isinstance (val , dict ):
191
+ if "type" in val and val ["type" ] == "JaxDataArray" :
192
+ return JAX_DATA_ARRAY_TAG
193
+ return {k : strip_data_array (v ) for k , v in val .items ()}
106
194
107
- if isinstance (val , dict ):
108
- if "type" in val and val ["type" ] == "JaxDataArray" :
109
- sub_dict [key ] = JAX_DATA_ARRAY_TAG
110
- else :
111
- strip_data_array (val )
112
- elif isinstance (val , (list , tuple )):
113
- val_dict = dict (zip (range (len (val )), val ))
114
- strip_data_array (val_dict )
115
- sub_dict [key ] = list (val_dict .values ())
195
+ elif isinstance (val , (tuple , list )):
196
+ return [strip_data_array (v ) for v in val ]
116
197
117
- strip_data_array (json_dict )
198
+ return val
199
+
200
+ json_dict = strip_data_array (json_dict )
118
201
return json .dumps (json_dict )
119
202
203
+ # TODO: replace with implementing these in DataArray
204
+
120
205
def to_hdf5 (self , fname : str , custom_encoders : List [Callable ] = None ) -> None :
121
206
"""Exports :class:`JaxObject` instance to .hdf5 file.
122
207
0 commit comments