Skip to content

Commit 125aba1

Browse files
committed
More tests
1 parent e1d8e81 commit 125aba1

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from pathlib import Path
22

3+
import numpy as np
34
from numpy.testing import assert_array_almost_equal as aaae
45

56
import dcegm
@@ -37,3 +38,30 @@ def test_partial_solve_func():
3738
aaae(model_solved.policy, partial_sol["policy"])
3839
aaae(model_solved.value, partial_sol["value"])
3940
aaae(model_solved.endog_grid, partial_sol["endog_grid"])
41+
42+
state_choices = model_solved.model_structure["state_choice_space"]
43+
choices = state_choices[:, -1]
44+
states_dict = {
45+
state: state_choices[:, id]
46+
for id, state in enumerate(
47+
model_solved.model_structure["discrete_states_names"]
48+
)
49+
}
50+
states_dict["assets_begin_of_period"] = model_solved.endog_grid[:, 5]
51+
value_states_all_choices = model_solved.choice_values_for_states(states=states_dict)
52+
53+
# Take in each row the value corresponding to the choice made
54+
value_choices = value_states_all_choices[
55+
np.arange(value_states_all_choices.shape[0]), choices
56+
]
57+
58+
aaae(model_solved.value[:, 5], value_choices)
59+
60+
# Same for policies
61+
policy_states_all_choices = model_solved.choice_policies_for_states(
62+
states=states_dict
63+
)
64+
policy_choices = policy_states_all_choices[
65+
np.arange(policy_states_all_choices.shape[0]), choices
66+
]
67+
aaae(model_solved.policy[:, 5], policy_choices)

0 commit comments

Comments
 (0)