Skip to content

Commit 803abdf

Browse files
committed
Disable casting of floats in evaluate() to ints unless cast_to_int set
1 parent 0eadbb0 commit 803abdf

File tree

4 files changed

+81
-30
lines changed

4 files changed

+81
-30
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@ name: Continuous builds
22

33
on:
44
push:
5-
branches: [ main, development, experimental ]
5+
branches: [ main, development, experimental, test* ]
66
pull_request:
7-
branches: [ main, development, experimental ]
7+
branches: [ main, development, experimental, test* ]
88

99
jobs:
1010

src/modelspec/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.2.8"
1+
__version__ = "0.2.9"
22

33
from .base_types import Base, define, has, field, fields, optional, instance_of, in_
44

src/modelspec/utils.py

Lines changed: 56 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,18 @@
99
from modelspec.base_types import print_
1010
from modelspec.base_types import EvaluableExpression
1111

12+
from random import Random
13+
from typing import Union
14+
1215
verbose = False
1316

1417

15-
def load_json(filename):
18+
def load_json(filename: str):
1619
"""
1720
Load a generic JSON file
21+
22+
Args:
23+
filename: The name of the JSON file to load
1824
"""
1925

2026
with open(filename) as f:
@@ -23,19 +29,25 @@ def load_json(filename):
2329
return data
2430

2531

26-
def load_yaml(filename):
32+
def load_yaml(filename: str):
2733
"""
2834
Load a generic YAML file
35+
36+
Args:
37+
filename: The name of the YAML file to load
2938
"""
3039
with open(filename) as f:
3140
data = yaml.load(f, Loader=yaml.SafeLoader)
3241

3342
return data
3443

3544

36-
def load_bson(filename):
45+
def load_bson(filename: str):
3746
"""
3847
Load a generic BSON file
48+
49+
Args:
50+
filename: The name of the BSON file to load
3951
"""
4052
with open(filename, "rb") as infile:
4153
data_encoded = infile.read()
@@ -211,11 +223,26 @@ def _params_info(parameters, multiline=False):
211223
FORMAT_TENSORFLOW = "tensorflow"
212224

213225

214-
def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=False):
226+
def evaluate(
227+
expr: Union[int, float, str, list, dict],
228+
parameters: dict = {},
229+
rng: Random = None,
230+
array_format: str = FORMAT_NUMPY,
231+
verbose: bool = False,
232+
cast_to_int: bool = False,
233+
):
215234
"""
216235
Evaluate a general string like expression (e.g. "2 * weight") using a dict
217236
of parameters (e.g. {'weight':10}). Returns floats, ints, etc. if that's what's
218237
given in expr
238+
239+
Args:
240+
expr: The expression to convert
241+
parameters: A dict of the parameters which can be substituted in to the expression
242+
rng: The random number generator to use
243+
array_format: numpy or tensorflow
244+
verbose: Print the calculations
245+
cast_to_int: return an int for float/string values if castable
219246
"""
220247

221248
if array_format == FORMAT_TENSORFLOW:
@@ -233,35 +260,39 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
233260
expr
234261
] # replace with the value in parameters & check whether it's float/int...
235262
if verbose:
236-
print_("Using for that param: %s" % _val_info(expr), verbose)
263+
print_(" Using for that param: %s" % _val_info(expr), verbose)
237264

238265
if type(expr) == str:
239266
try:
267+
print(1)
240268
if array_format == FORMAT_TENSORFLOW:
241269
expr = tf.constant(int(expr))
242270
else:
243271
expr = int(expr)
272+
print(2)
244273
except:
245-
pass
246-
try:
247-
if array_format == FORMAT_TENSORFLOW:
248-
expr = tf.constant(float(expr))
249-
else:
250-
expr = float(expr)
251-
except:
252-
pass
274+
275+
try:
276+
if array_format == FORMAT_TENSORFLOW:
277+
expr = tf.constant(float(expr))
278+
else:
279+
expr = float(expr)
280+
except:
281+
pass
253282

254283
if type(expr) == list:
255284
if verbose:
256-
print_("Returning a list in format: %s" % array_format, verbose)
285+
print_(" Returning a list in format: %s" % array_format, verbose)
257286
if array_format == FORMAT_TENSORFLOW:
258287
return tf.constant(expr, dtype=tf.float64)
259288
else:
260289
return np.array(expr)
261290

262291
if type(expr) == np.ndarray:
263292
if verbose:
264-
print_("Returning a numpy array in format: %s" % array_format, verbose)
293+
print_(
294+
" Returning a numpy array in format: %s" % array_format, verbose
295+
)
265296
if array_format == FORMAT_TENSORFLOW:
266297
return tf.convert_to_tensor(expr, dtype=tf.float64)
267298
else:
@@ -270,22 +301,22 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
270301
if "Tensor" in type(expr).__name__:
271302
if verbose:
272303
print_(
273-
"Returning a tensorflow Tensor in format: %s" % array_format,
304+
" Returning a tensorflow Tensor in format: %s" % array_format,
274305
verbose,
275306
)
276307
if array_format == FORMAT_NUMPY:
277308
return expr.numpy()
278309
else:
279310
return expr
280311

281-
if int(expr) == expr:
312+
if int(expr) == expr and cast_to_int:
282313
if verbose:
283-
print_("Returning int: %s" % int(expr), verbose)
314+
print_(" Returning int: %s" % int(expr), verbose)
284315
return int(expr)
285316
else: # will have failed if not number
286317
if verbose:
287-
print_("Returning float: %s" % expr, verbose)
288-
return float(expr)
318+
print_(" Returning {}: {}".format(type(expr), expr), verbose)
319+
return expr
289320
except:
290321
try:
291322
if rng:
@@ -299,7 +330,7 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
299330

300331
if verbose:
301332
print_(
302-
"Trying to eval [%s] with Python using %s..."
333+
" Trying to eval [%s] with Python using %s..."
303334
% (expr, parameters.keys()),
304335
verbose,
305336
)
@@ -308,13 +339,14 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
308339

309340
if verbose:
310341
print_(
311-
"Evaluated with Python: {} = {}".format(expr, _val_info(v)), verbose
342+
" Evaluated with Python: {} = {}".format(expr, _val_info(v)),
343+
verbose,
312344
)
313345

314346
if (type(v) == float or type(v) == str) and int(v) == v:
315347

316348
if verbose:
317-
print_("Returning int: %s" % int(v), verbose)
349+
print_(" Returning int: %s" % int(v), verbose)
318350

319351
if array_format == FORMAT_TENSORFLOW:
320352
return tf.constant(int(v))
@@ -323,7 +355,7 @@ def evaluate(expr, parameters={}, rng=None, array_format=FORMAT_NUMPY, verbose=F
323355
return v
324356
except Exception as e:
325357
if verbose:
326-
print_(f"Returning without altering: {expr} (error: {e})", verbose)
358+
print_(f" Returning without altering: {expr} (error: {e})", verbose)
327359
return expr
328360

329361

tests/test_utils.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,16 @@ def test_evaluate(self):
1919
params = {"p": 33}
2020
assert evaluate("p+p", params, verbose=True) == 66
2121

22+
print("======")
2223
assert type(evaluate("33")) == int
23-
assert type(evaluate("33.0")) == int
24+
assert type(evaluate("33", cast_to_int=True)) == int
25+
assert type(evaluate("33.0")) == float
26+
assert type(evaluate("33.0", cast_to_int=True)) == int
27+
2428
assert type(evaluate("33.1")) == float
25-
assert type(evaluate("33.1a")) == str
29+
assert type(evaluate("33.1a", verbose=True)) == str
2630

27-
assert type(evaluate("33.1a")) == str
31+
assert type(evaluate("a")) == str
2832

2933
import random
3034

@@ -41,6 +45,21 @@ def test_evaluate(self):
4145

4246
assert evaluate("a+b", params, verbose=True)[2] == 3
4347

48+
params = {"a1": np.array([1]), "b": np.array([1, 1, 3])}
49+
50+
a1_b = evaluate("a1+b", params, verbose=True)
51+
assert a1_b[2] == 4
52+
53+
params = {"A": np.ones([2, 2]), "B": np.ones([2, 2])}
54+
55+
AplusB = evaluate("A+B", params, verbose=True)
56+
assert AplusB[0, 0] == 2
57+
assert AplusB.shape == (2, 2)
58+
59+
AtimesB = evaluate("A*B", params, verbose=True)
60+
assert AtimesB[0, 0] == 1
61+
assert AtimesB.shape == (2, 2)
62+
4463
def test_val_info_tuple(self):
4564
print(_val_info((1, 2)))
4665
print(_val_info((("test", 1), 2)))

0 commit comments

Comments
 (0)