Skip to content

Commit cdfa4ad

Browse files
authored
BUG: Solver includes certain no parameter curves (#241) (#1167)
1 parent bf84f2f commit cdfa4ad

File tree

3 files changed

+42
-7
lines changed

3 files changed

+42
-7
lines changed

python/rateslib/solver.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,16 @@
2323
from pandas.errors import PerformanceWarning
2424

2525
from rateslib import defaults
26-
from rateslib.curves import CompositeCurve, Curve, MultiCsaCurve, ProxyCurve, _BaseCurve
26+
from rateslib.curves import (
27+
CompositeCurve,
28+
Curve,
29+
MultiCsaCurve,
30+
ProxyCurve,
31+
RolledCurve,
32+
ShiftedCurve,
33+
TranslatedCurve,
34+
_BaseCurve,
35+
)
2736
from rateslib.dual import Dual, Dual2, dual_solve, gradient
2837
from rateslib.dual.newton import _solver_result
2938
from rateslib.dual.utils import _dual_float
@@ -946,6 +955,16 @@ def grad_f_fT_Pbase(
946955
return grad_s_sT_Pbas
947956

948957

958+
NO_PARAMETER_CURVES = [
959+
ProxyCurve,
960+
CompositeCurve,
961+
MultiCsaCurve,
962+
RolledCurve,
963+
ShiftedCurve,
964+
TranslatedCurve,
965+
]
966+
967+
949968
class Solver(Gradients, _WithState):
950969
"""
951970
A numerical solver to determine node values on multiple pricing objects simultaneously.
@@ -1111,8 +1130,7 @@ def __init__(
11111130
self.curves = {
11121131
curve.id: curve
11131132
for curve in list(curves) + list(surfaces)
1114-
if type(curve) not in [ProxyCurve, CompositeCurve, MultiCsaCurve]
1115-
# Proxy and Composite curves have no parameters of their own
1133+
if type(curve) not in NO_PARAMETER_CURVES
11161134
}
11171135
self.variables = ()
11181136
for curve in self.curves.values():
@@ -1142,8 +1160,8 @@ def __init__(
11421160
{
11431161
curve.id: curve
11441162
for curve in curves
1145-
if type(curve) in [ProxyCurve, CompositeCurve, MultiCsaCurve]
1146-
# Proxy and Composite curves added to the collection without variables
1163+
if type(curve) in NO_PARAMETER_CURVES
1164+
# no parameter curves added to the collection without variables
11471165
},
11481166
)
11491167
curve_collection.extend(curves)

python/tests/curves/test_sw.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def test_init():
2626
ufr=4.2,
2727
)
2828
result = sw.rate(dt(2001, 1, 1), "1b")
29-
expected = 3.3882896759093173
29+
expected = 3.3906104222626796
3030
assert abs(result - expected) < 1e-5
31-
assert sw.meta.convention == "Act365.25"
31+
assert sw.meta.convention == Convention.Act365_25
3232

3333

3434
def test_cache():

python/tests/test_solver.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2822,3 +2822,20 @@ def test_solver_validation_control(self):
28222822
def test_objects_ad_attribute(obj):
28232823
result = getattr(obj, "_ad", None)
28242824
assert result is not None
2825+
2826+
2827+
@pytest.mark.parametrize("label", ["shift", "rolled", "translated"])
2828+
def test_curves_without_their_own_params(label):
2829+
curve = Curve({dt(2000, 1, 1): 1.0, dt(2001, 1, 1): 1.0}, id="curve")
2830+
_map = {
2831+
"shift": curve.shift(5, id="shift"),
2832+
"rolled": curve.roll(5, id="rolled"),
2833+
"translated": curve.translate(dt(2000, 1, 1), id="translated"),
2834+
}
2835+
2836+
sv = Solver(
2837+
curves=[curve, _map[label]],
2838+
instruments=[IRS(dt(2000, 2, 1), dt(2000, 3, 1), spec="usd_irs", curves=["curve", label])],
2839+
s=[2.0],
2840+
)
2841+
assert sv.result["status"] == "SUCCESS"

0 commit comments

Comments
 (0)