Skip to content

Commit 4d163a3

Browse files
committed
Added to_dict and from_dict methods to OptimResults object (need to make OptimResults public to give access to from_dict method)
1 parent e8f0a7f commit 4d163a3

File tree

3 files changed

+61
-3
lines changed

3 files changed

+61
-3
lines changed

dfols/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,5 +43,5 @@
4343

4444
# Main solver & exit flags
4545
from .solver import *
46-
__all__ = ['solve']
46+
__all__ = ['solve', 'OptimResults']
4747

dfols/solver.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from math import sqrt
3333
import numpy as np
3434
import os
35+
import pandas as pd
3536
import scipy.linalg as LA
3637
import scipy.stats as STAT
3738
import warnings
@@ -41,7 +42,7 @@
4142
from .params import *
4243
from .util import *
4344

44-
__all__ = ['solve']
45+
__all__ = ['solve', 'OptimResults']
4546

4647
module_logger = logging.getLogger(__name__)
4748

@@ -102,6 +103,50 @@ def __str__(self):
102103
output += "%s\n" % self.msg
103104
output += "****************************\n"
104105
return output
106+
107+
def to_dict(self, replace_nan=True):
108+
# Convert to a serializable dict object suitable for saving in a json file
109+
# If replace_nan=True, convert all NaN entries to None
110+
soln_dict = {}
111+
soln_dict['x'] = self.x.tolist() if self.x is not None else None
112+
soln_dict['resid'] = self.resid.tolist() if self.resid is not None else None
113+
soln_dict['obj'] = float(self.obj)
114+
soln_dict['jacobian'] = self.jacobian.tolist() if self.jacobian is not None else None
115+
soln_dict['nf'] = int(self.nf)
116+
soln_dict['nx'] = int(self.nx)
117+
soln_dict['nruns'] = int(self.nruns)
118+
soln_dict['flag'] = int(self.flag)
119+
soln_dict['msg'] = str(self.msg)
120+
soln_dict['diagnostic_info'] = self.diagnostic_info.to_dict() if self.diagnostic_info is not None else None
121+
soln_dict['xmin_eval_num'] = int(self.xmin_eval_num)
122+
soln_dict['jacmin_eval_nums'] = self.jacmin_eval_nums.tolist() if self.jacmin_eval_nums is not None else None
123+
if replace_nan:
124+
return replace_nan_with_none(soln_dict)
125+
else:
126+
return soln_dict
127+
128+
@staticmethod
129+
def from_dict(soln_dict):
130+
# Take a dict object containing OptimResults information, and return the relevant OptimResults object
131+
# Input soln_dict should come from soln.to_dict()
132+
# Note: np.array(mylist, dtype=float) automatically converts None to NaN
133+
x = np.array(soln_dict['x'], dtype=float) if soln_dict['x'] is not None else None
134+
resid = np.array(soln_dict['resid'], dtype=float) if soln_dict['resid'] is not None else None
135+
obj = soln_dict['obj']
136+
jacobian = np.array(soln_dict['jacobian'], dtype=float) if soln_dict['jacobian'] is not None else None
137+
nf = soln_dict['nf']
138+
nx = soln_dict['nx']
139+
nruns = soln_dict['nruns']
140+
flag = soln_dict['flag']
141+
msg = soln_dict['msg']
142+
xmin_eval_num = soln_dict['xmin_eval_num']
143+
jacmin_eval_nums = np.array(soln_dict['jacmin_eval_nums'], dtype=int) if soln_dict['jacmin_eval_nums'] is not None else None
144+
145+
soln = OptimResults(x, resid, obj, jacobian, nf, nx, nruns, flag, msg, xmin_eval_num, jacmin_eval_nums)
146+
147+
if soln_dict['diagnostic_info'] is not None:
148+
soln.diagnostic_info = pd.DataFrame.from_dict(soln_dict['diagnostic_info'])
149+
return soln
105150

106151

107152
def solve_main(objfun, x0, argsf, xl, xu, projections, npt, rhobeg, rhoend, maxfun, nruns_so_far, nf_so_far, nx_so_far, nsamples, params,

dfols/util.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@
2626
from __future__ import absolute_import, division, print_function, unicode_literals
2727

2828
import logging
29+
import math
2930
import numpy as np
3031
import scipy.linalg as LA
3132
import sys
3233

3334

3435
__all__ = ['sumsq', 'eval_least_squares_with_regularisation', 'model_value', 'random_orthog_directions_within_bounds',
35-
'random_directions_within_bounds', 'apply_scaling', 'remove_scaling', 'pbox', 'pball', 'dykstra', 'qr_rank']
36+
'random_directions_within_bounds', 'apply_scaling', 'remove_scaling', 'pbox', 'pball', 'dykstra', 'qr_rank', 'replace_nan_with_none']
3637

3738
module_logger = logging.getLogger(__name__)
3839

@@ -268,3 +269,15 @@ def qr_rank(A,tol=1e-15):
268269
D = np.abs(np.diag(R))
269270
rank = np.sum(D > tol)
270271
return rank, D
272+
273+
274+
def replace_nan_with_none(d):
275+
# Replace Nan values in a dict/list with None (used for JSON serializing of OptimResults object)
276+
if isinstance(d, dict):
277+
return {k: replace_nan_with_none(v) for k, v in d.items()}
278+
elif isinstance(d, list):
279+
return [replace_nan_with_none(i) for i in d]
280+
elif isinstance(d, float) and math.isnan(d):
281+
return None
282+
else:
283+
return d

0 commit comments

Comments
 (0)