Skip to content

Commit f1f997f

Browse files
authored
Merge pull request #33 from bbcho/v0285
build for numpy 2.0
2 parents a563661 + d2118cd commit f1f997f

File tree

4 files changed

+267
-253
lines changed

4 files changed

+267
-253
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def get_ext_filename(self, ext_name):
5050

5151
setuptools.setup(
5252
name="risktools",
53-
version="0.2.8.4",
53+
version="0.2.8.5",
5454
author="Ben Cho",
5555
license="gpl-3.0", # Chose a license from here: https://help.github.com/articles/licensing-a-repository
5656
author_email="[email protected]",

src/risktools/_main_functions.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
register_matplotlib_converters as _register_matplotlib_converters,
1616
)
1717
import seaborn as _sns
18+
import time
1819

1920
from ._pa import *
2021

@@ -284,14 +285,18 @@ def returns(df, ret_type="abs", period_return=1, spread=False):
284285
if ret_type == "abs":
285286
df = df.groupby(level=0, group_keys=False).apply(lambda x: x.diff())
286287
elif ret_type == "rel":
287-
df = df.groupby(level=0, group_keys=False).apply(lambda x: x / x.shift(period_return) - 1)
288+
df = df.groupby(level=0, group_keys=False).apply(
289+
lambda x: x / x.shift(period_return) - 1
290+
)
288291
elif ret_type == "log":
289292
if df[df < 0].count().sum() > 0:
290293
warnings.warn(
291294
"Negative values passed to log returns. You will likely get NaN values using log returns",
292295
RuntimeWarning,
293296
)
294-
df = df.groupby(level=0, group_keys=False).apply(lambda x: _np.log(x / x.shift(period_return)))
297+
df = df.groupby(level=0, group_keys=False).apply(
298+
lambda x: _np.log(x / x.shift(period_return))
299+
)
295300
else:
296301
raise ValueError("ret_type is not valid")
297302

@@ -704,7 +709,7 @@ def crr_euro(s=100, x=100, sigma=0.2, Rf=0.1, T=1, n=5, type="call"):
704709

705710
for i in range(0, n + 1):
706711
for j in range(0, i + 1):
707-
asset[i, j] = s * (u ** j) * (d ** (i - j))
712+
asset[i, j] = s * (u**j) * (d ** (i - j))
708713

709714
# create matrix of the same dims as asset price tree
710715
option = _np.zeros([n + 1, n + 1])
@@ -809,7 +814,7 @@ def stl_decomposition(
809814
return res
810815

811816

812-
def get_eia_df(tables, key, version=2):
817+
def get_eia_df(tables, key, version=2, sleep=1):
813818
"""
814819
Function for download data from the US Government EIA and return it as a pandas dataframe/series
815820
@@ -824,7 +829,7 @@ def get_eia_df(tables, key, version=2):
824829
EIA key.
825830
version : 1 | 2
826831
API version to use, can be either 1 or 2. By default 2. As of Nov 2022, EIA is no
827-
longer support v1 of the API.
832+
longer support v1 of the API.
828833
829834
Returns
830835
-------
@@ -839,7 +844,7 @@ def get_eia_df(tables, key, version=2):
839844
if int(version) == 1:
840845
return _get_eia_df_v1(tables, key)
841846
else:
842-
return _get_eia_df_v2(tables, key)
847+
return _get_eia_df_v2(tables, key, sleep=sleep)
843848

844849

845850
def _get_eia_df_v1(tables, key):
@@ -878,7 +883,8 @@ def _get_eia_df_v1(tables, key):
878883
url = r"http://api.eia.gov/series/?api_key={}&series_id={}&out=json".format(
879884
key, tbl
880885
)
881-
tmp = json.loads(requests.get(url).text)
886+
r = requests.get(url)
887+
tmp = json.loads(r.text)
882888

883889
tf = _pd.DataFrame(tmp["series"][0]["data"], columns=["date", "value"])
884890
tf["table_name"] = tmp["series"][0]["name"]
@@ -892,7 +898,7 @@ def _get_eia_df_v1(tables, key):
892898
return eia
893899

894900

895-
def _get_eia_df_v2(tables, key):
901+
def _get_eia_df_v2(tables, key, sleep):
896902
"""
897903
Function for download data from the US Government EIA and return it as a pandas dataframe/series.
898904
API version 2.
@@ -926,17 +932,27 @@ def _get_eia_df_v2(tables, key):
926932

927933
for tbl in tables:
928934
url = f"http://api.eia.gov/v2/seriesid/{tbl}?api_key={key}"
929-
tmp = json.loads(requests.get(url).text)
930935

931-
tf = _pd.DataFrame(tmp['response']['data'], columns=["period","series-description", "value"])
936+
try:
937+
r = requests.get(url)
938+
tmp = json.loads(r.text)
939+
tf = _pd.DataFrame(
940+
tmp["response"]["data"],
941+
columns=["period", "series-description", "value"],
942+
)
943+
except:
944+
print(f"Error in table {tbl}")
945+
print(r.text)
946+
continue
932947
tf["series_id"] = tbl
933948
eia = _pd.concat([eia, tf], axis=0)
949+
time.sleep(sleep)
934950
# eia = eia.append(tf)
935-
936-
eia = eia.rename(columns={'period':'date','series-description':'table_name'})
951+
952+
eia = eia.rename(columns={"period": "date", "series-description": "table_name"})
937953
eia.loc[eia.date.str.len() < 7, "date"] += "01"
938954
eia.date = _pd.to_datetime(eia.date)
939-
return eia[['date','value','table_name','series_id']]
955+
return eia[["date", "value", "table_name", "series_id"]]
940956

941957

942958
def _check_df(df):
@@ -949,7 +965,7 @@ def _check_df(df):
949965

950966
def infer_freq(x, multiplier=False):
951967
"""
952-
Function to infer the frequency of a time series. Improvement over
968+
Function to infer the frequency of a time series. Improvement over
953969
pandas.infer_freq as it can handle missing days/holidays. Note that
954970
for business days it will return 'D' vs 'B'
955971
@@ -991,4 +1007,3 @@ def infer_freq(x, multiplier=False):
9911007
return 4
9921008
elif freq[0] == "A":
9931009
return 1
994-

src/risktools/_sims.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,16 @@ def is_iterable(x):
3131
def make_into_array(x, N):
3232
# make an array of same size as N+1
3333
if is_iterable(x):
34+
# if isinstance(x, _pd.DataFrame) == False:
35+
# x = _pd.DataFrame(x)
3436
if len(x.shape) == 2:
3537
# if a 2D array is passed, return it as is
3638
# good for stocastic volatility matrix
37-
x = _np.vstack((x[0], x))
39+
x = _np.vstack((x.iloc[0], x))
3840
return x
3941

4042
if x.shape[0] == N:
41-
x = _np.append(x[0], x)
43+
x = _np.append(x.iloc[0], x)
4244
else:
4345
raise ValueError(
4446
"if mu is passed as an iterable, it must be of length int(T/dt)"
@@ -612,7 +614,7 @@ def fitOU(spread, dt=1 / 252, log_price=False, method="OLS", verbose=False):
612614
where 252 is the number of business days in a year. Default is 1/252.
613615
Only used if method is "OLS".
614616
log_price : bool
615-
If True, the spread is assumed to be log prices and the log of the spread is taken.
617+
If True, the log of the spread is taken.
616618
Default is False.
617619
method : ['OLS', 'MLE']
618620
Method to use for parameter estimation. Default is 'OLS'.

0 commit comments

Comments
 (0)