Skip to content

Commit a0e6197

Browse files
committed
1 parent 7e5dd10 commit a0e6197

File tree

5 files changed

+61
-85
lines changed

5 files changed

+61
-85
lines changed

chainladder/core/dunders.py

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -60,73 +60,84 @@ def _compatibility_check(self, x, y):
6060
return x, y
6161

6262
def _prep_index(self, x, y):
63-
""" Preps index and column axes for arithmetic """
6463
if x.kdims.shape[0] == 1 and y.kdims.shape[0] > 1:
65-
# Broadcast x to y
6664
x.kdims = y.kdims
6765
x.key_labels = y.key_labels
6866
return x, y
6967
if x.kdims.shape[0] > 1 and y.kdims.shape[0] == 1:
70-
# Broadcast y to x
7168
y.kdims = x.kdims
7269
y.key_labels = x.key_labels
7370
return x, y
7471
if x.kdims.shape[0] == y.kdims.shape[0] == 1 and x.key_labels != y.key_labels:
75-
# Broadcast to the triangle with a larger multi-index
7672
kdims = x.kdims if len(x.key_labels) > len(y.key_labels) else y.kdims
77-
y.kdims = x.kdims = kdims
7873
key_labels = x.key_labels if len(x.key_labels) > len(y.key_labels) else y.key_labels
79-
y.key_labels = x.key_labels = key_labels
74+
x.kdims = y.kdims = kdims
75+
x.key_labels = y.key_labels = key_labels
8076
return x, y
81-
a, b = set(x.key_labels), set(y.key_labels)
82-
common = a.intersection(b)
83-
if common in [a, b] and (a != b or (a == b and x.kdims.shape[0] != y.kdims.shape[0])):
84-
# If index labels are subset of other triangle index labels
85-
x = x.groupby(list(common))
86-
y = y.groupby(list(common))
87-
return x, y
88-
if common not in [a, b]:
89-
raise ValueError('Index broadcasting is ambiguous between', str(a), 'and', str(b))
90-
if (
91-
x.key_labels == y.key_labels
92-
and x.kdims.shape[0] == y.kdims.shape[0]
93-
and y.kdims.shape[0] > 1
94-
and not x.kdims is y.kdims
95-
and not x.index.equals(y.index)
96-
):
97-
# Make sure exact but unsorted index labels works
98-
x = x.sort_index()
99-
try:
100-
y = y.loc[x.index]
101-
except:
77+
78+
# Use sets for faster operations
79+
x_labels = set(x.key_labels)
80+
y_labels = set(y.key_labels)
81+
common = x_labels.intersection(y_labels)
82+
83+
if common == x_labels or common == y_labels:
84+
if x_labels != y_labels or x.kdims.shape[0] != y.kdims.shape[0]:
10285
x = x.groupby(list(common))
10386
y = y.groupby(list(common))
87+
elif x.kdims.shape[0] > 1 and not np.array_equal(x.kdims, y.kdims) and not x.index.equals(y.index):
88+
x = x.sort_index()
89+
try:
90+
y = y.loc[x.index]
91+
except:
92+
x = x.groupby(list(common))
93+
y = y.groupby(list(common))
94+
return x, y
95+
96+
if common != x_labels and common != y_labels:
97+
raise ValueError('Index broadcasting is ambiguous between ' + str(x_labels) + ' and ' + str(y_labels))
98+
10499
return x, y
105100

106101
def _prep_columns(self, x, y):
107102
x_backend, y_backend = x.array_backend, y.array_backend
103+
108104
if len(x.columns) == 1 and len(y.columns) > 1:
109105
x.vdims = y.vdims
110106
elif len(y.columns) == 1 and len(x.columns) > 1:
111107
y.vdims = x.vdims
112-
elif len(y.columns) == 1 and len(x.columns) == 1 and x.columns != y.columns:
108+
elif len(y.columns) == len(x.columns) == 1 and x.columns != y.columns:
113109
y.vdims = x.vdims
114-
elif x.shape[1] == y.shape[1] and np.all(x.columns == y.columns):
115-
pass
110+
elif x.shape[1] == y.shape[1] and np.array_equal(x.columns, y.columns):
111+
return x, y
116112
else:
117-
col_union = list(x.columns) + [
118-
item for item in y.columns if item not in x.columns
119-
]
120-
for item in [item for item in col_union if item not in x.columns]:
121-
x[item] = 0
122-
x = x[col_union]
123-
for item in [item for item in col_union if item not in y.columns]:
124-
y[item] = 0
125-
y = y[col_union]
126-
x, y = (
127-
x.set_backend(x_backend, inplace=True),
128-
y.set_backend(y_backend, inplace=True),
129-
)
113+
# Use sets for faster operations
114+
x_cols = set(x.columns)
115+
y_cols = set(y.columns)
116+
117+
# Find columns to add to each triangle
118+
cols_to_add_to_x = y_cols - x_cols
119+
cols_to_add_to_y = x_cols - y_cols
120+
121+
# Create new columns only if necessary
122+
if cols_to_add_to_x:
123+
new_x_cols = list(x.columns) + list(cols_to_add_to_x)
124+
x = x.reindex(columns=new_x_cols, fill_value=0)
125+
126+
if cols_to_add_to_y:
127+
new_y_cols = list(y.columns) + list(cols_to_add_to_y)
128+
y = y.reindex(columns=new_y_cols, fill_value=0)
129+
130+
# Ensure both triangles have the same column order
131+
final_cols = list(x_cols | y_cols)
132+
x = x[final_cols]
133+
y = y[final_cols]
134+
135+
# Reset backends only if they've changed
136+
if x.array_backend != x_backend:
137+
x = x.set_backend(x_backend, inplace=True)
138+
if y.array_backend != y_backend:
139+
y = y.set_backend(y_backend, inplace=True)
140+
130141
return x, y
131142

132143
def _prep_origin_development(self, obj, other):

chainladder/core/tests/test_triangle.py

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import chainladder as cl
22
import pandas as pd
3-
import polars as pl
43
import numpy as np
5-
import copy
64
import pytest
75
import io
86
from datetime import datetime
@@ -746,9 +744,7 @@ def test_halfyear_development():
746744
["2012-01-01", "2013-12-31", "incurred", 200.0],
747745
]
748746

749-
df_polars = pl.DataFrame(data)
750-
df_polars.columns = ["origin", "val_date", "idx", "value"]
751-
747+
752748
assert (
753749
type(
754750
cl.Triangle(
@@ -760,33 +756,4 @@ def test_halfyear_development():
760756
cumulative=True,
761757
)
762758
)
763-
) == cl.Triangle
764-
765-
assert (
766-
type(
767-
cl.Triangle(
768-
data=df_polars,
769-
index="idx",
770-
columns="value",
771-
origin="origin",
772-
development="val_date",
773-
cumulative=True,
774-
)
775-
)
776-
) == cl.Triangle
777-
778-
assert cl.Triangle(
779-
data=pd.DataFrame(data, columns=["origin", "val_date", "idx", "value"]),
780-
index="idx",
781-
columns="value",
782-
origin="origin",
783-
development="val_date",
784-
cumulative=True,
785-
) == cl.Triangle(
786-
data=df_polars,
787-
index="idx",
788-
columns="value",
789-
origin="origin",
790-
development="val_date",
791-
cumulative=True,
792-
)
759+
) == cl.Triangle

chainladder/development/learning.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(self, estimator_ml=None, y_ml=None, autoregressive=False,
5353
self.y_ml=y_ml
5454
self.weight_ml = weight_ml
5555
self.autoregressive=autoregressive
56-
self.fit_incrementals=fit_incrementals
56+
self.fit_incrementals = fit_incrementals
5757

5858
def _get_y_names(self):
5959
""" private function to get the response column name"""
@@ -153,7 +153,7 @@ def fit(self, X, y=None, sample_weight=None):
153153
Parameters
154154
----------
155155
X : Triangle-like
156-
Set of LDFs to which the munich adjustment will be applied.
156+
Set of LDFs to which the estimator will be applied.
157157
y : None
158158
Ignored, use y_ml to set a response variable for the ML algorithm
159159
sample_weight : None
@@ -180,7 +180,7 @@ def fit(self, X, y=None, sample_weight=None):
180180
self.df_ = df
181181
# Fit model
182182
self.estimator_ml.fit(df, self.y_ml_.fit_transform(df).squeeze())
183-
#return self
183+
#return selffit_incrementals
184184
self.triangle_ml_ = self._get_triangle_ml(df)
185185
return self
186186

chainladder/workflow/voting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ def fit(self, X, y, sample_weight=None):
124124
self.estimators_ = Parallel(n_jobs=self.n_jobs)(
125125
delayed(_fit_single_estimator)(
126126
clone(clf), X, y,
127-
sample_weight=sample_weight,
127+
fit_params=dict(sample_weight=sample_weight),
128128
message_clsname='VotingChainladder',
129129
message=self._log_message(names[idx],
130130
idx + 1, len(clfs))

environment-dev.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,11 @@ dependencies:
1414
- ipykernel
1515

1616
- pandas
17-
- polars
1817
- scikit-learn
1918
- sparse
20-
- numba
2119
- dill
2220
- patsy
23-
- matplotlib
21+
- matplotlib-base
2422

2523
# testing
2624
- lxml

0 commit comments

Comments
 (0)