@@ -89,77 +89,25 @@ def run(self, source: Dataset) -> Dataset: # noqa: D102
8989 :param source: The source dataset.
9090 :return: The result dataset.
9191 """
92+ config : dict [str : dict [str :Any ]] = self .config .get (
93+ self ._args .source_type , {}
94+ )
9295 target : Dataset = Dataset (
9396 data_vars = source .data_vars ,
9497 coords = source .coords ,
9598 attrs = source .attrs ,
9699 )
97- config : dict [str : dict [str :Any ]] = self .config .get (
98- self ._args .source_type , {}
99- )
100100 for v , x in target .data_vars .items ():
101101 if v not in config or self ._args .selector == 0 :
102102 continue
103103 get_logger ().info (f"starting graph for variable: { v } " )
104- s : list [int ] = self .entropy (v , self .uuid )
105- a : dict [str :Any ] = config [v ]
106- f = Randomize (m = x .ndim , dist = a ["distribution" ], entropy = s )
107- if "uncertainty" in a :
108- u = (
109- target [a ["uncertainty" ]]
110- if isinstance (a ["uncertainty" ], str )
111- else DataArray (
112- data = da .full (
113- x .shape , a ["uncertainty" ], chunks = x .chunks
114- ),
115- coords = x .coords ,
116- dims = x .dims ,
117- attrs = {},
118- )
119- )
120- z = f .apply_to (
121- _decode (x .data , x .attrs ),
122- _decode (u .data , u .attrs ),
123- coverage = a .get ("coverage" , 1.0 ),
124- relative = a .get ("relative" , False ),
125- clip = a .get ("clip" , None ),
126- )
127- else :
128- b = target [a ["bias" ]]
129- r = target [a ["rmsd" ]]
130- z = f .apply_to (
131- _decode (x .data , x .attrs ),
132- _decode (r .data , r .attrs ),
133- _decode (b .data , b .attrs ),
134- clip = a .get ("clip" , None ),
135- )
136- target [v ] = DataArray (
137- data = _encode (z , x .attrs , x .dtype ),
138- coords = x .coords ,
139- dims = x .dims ,
140- attrs = x .attrs ,
141- )
142- if "actual_range" in target [v ].attrs :
143- target [v ].attrs ["actual_range" ] = np .array (
144- [
145- da .nanmin (z ).compute (),
146- da .nanmax (z ).compute (),
147- ],
148- dtype = z .dtype ,
149- )
150- target [v ].attrs ["entropy" ] = np .array (s , dtype = np .int64 )
151- if get_logger ().is_enabled (Logging .DEBUG ):
152- get_logger ().debug (f"entropy: { s } " )
153- get_logger ().debug (f"min: { da .nanmin (z ).compute () :.3f} " )
154- get_logger ().debug (f"max: { da .nanmax (z ).compute () :.3f} " )
155- get_logger ().debug (f"mean: { da .nanmean (z ).compute () :.3f} " )
156- get_logger ().debug (f"std: { da .nanstd (z ).compute () :.3f} " )
104+ self .randomize (source , target , v , x , config [v ])
157105 get_logger ().info (f"finished graph for variable: { v } " )
158106 return target
159107
160108 @property
161109 def config (self ) -> dict [str : dict [str :Any ]]:
162- """Returns the product type configuration."""
110+ """Returns the randomization configuration."""
163111 package = "kaleidoscope.config"
164112 name = "config.random.json"
165113 with resources .path (package , name ) as resource :
@@ -187,6 +135,86 @@ def entropy(self, name: str, uuid: str, n: int = 4) -> list[int]:
187135 g = DefaultGenerator (Philox (seed ))
188136 return [g .next () for _ in range (n )]
189137
def randomize(
    self,
    source: Dataset,
    target: Dataset,
    v: str,
    x: DataArray,
    config: dict[str, Any],
) -> None:
    """
    Creates the graph to randomize a variable.

    The result replaces the variable ``v`` in the target dataset. The
    configuration selects one of three modes:

    - ``"total"``: no random numbers are drawn; the variable is shifted
      by the difference between target and source of each referenced
      variable.
    - ``"uncertainty"``: the variable is randomized with an uncertainty
      which is either the name of a variable in the target dataset or
      a constant value.
    - otherwise: the variable is randomized using the variables named
      by the ``"rmsd"`` and ``"bias"`` configuration entries.

    :param source: The source dataset.
    :param target: The target dataset, which is modified in place.
    :param v: The name of the variable.
    :param x: The data of the variable.
    :param config: The randomization configuration.
    """
    # The entropy stays empty if the variable is derived from other
    # variables instead of being randomized itself.
    s: list[int] = []
    if "total" in config:
        z = _decode(x.data, x.attrs)
        # NOTE(review): presumably the referenced variables have been
        # processed before this one, so that target and source differ
        # - confirm against the order of the configuration.
        for ref in config["total"]:
            a = _decode(target[ref].data, target[ref].attrs)
            b = _decode(source[ref].data, source[ref].attrs)
            z = z + (a - b)
    else:
        s = self.entropy(v, self.uuid)
        f = Randomize(m=x.ndim, dist=config["distribution"], entropy=s)
        if "uncertainty" in config:
            if isinstance(config["uncertainty"], str):
                # the uncertainty is a variable of the target dataset
                u = target[config["uncertainty"]]
            else:
                # the uncertainty is a constant, expanded to the shape
                # and chunks of the variable
                u = DataArray(
                    data=da.full(
                        x.shape, config["uncertainty"], chunks=x.chunks
                    ),
                    coords=x.coords,
                    dims=x.dims,
                    attrs={},
                )
            z = f.apply_to(
                _decode(x.data, x.attrs),
                _decode(u.data, u.attrs),
                coverage=config.get("coverage", 1.0),
                relative=config.get("relative", False),
                clip=config.get("clip", None),
            )
        else:
            b = target[config["bias"]]
            r = target[config["rmsd"]]
            z = f.apply_to(
                _decode(x.data, x.attrs),
                _decode(r.data, r.attrs),
                _decode(b.data, b.attrs),
                clip=config.get("clip", None),
            )
    target[v] = DataArray(
        data=_encode(z, x.attrs, x.dtype),
        coords=x.coords,
        dims=x.dims,
        attrs=x.attrs,
    )
    has_range = "actual_range" in target[v].attrs
    debug = get_logger().is_enabled(Logging.DEBUG)
    if has_range or debug:
        # Each compute() evaluates the graph, so the extrema are
        # computed once and shared by the attribute and the log.
        z_min = da.nanmin(z).compute()
        z_max = da.nanmax(z).compute()
    if has_range:
        target[v].attrs["actual_range"] = np.array(
            [z_min, z_max],
            dtype=z.dtype,
        )
    if s:
        target[v].attrs["entropy"] = np.array(s, dtype=np.int64)
    if debug:
        get_logger().debug(f"entropy: {s} ")
        get_logger().debug(f"min: {z_min:.3f} ")
        get_logger().debug(f"max: {z_max:.3f} ")
        get_logger().debug(f"mean: {da.nanmean(z).compute():.3f} ")
        get_logger().debug(f"std: {da.nanstd(z).compute():.3f} ")
217+
190218 @property
191219 def uuid (self ) -> str :
192220 """
0 commit comments