|
1 | 1 | from pathlib import Path |
2 | 2 |
|
| 3 | +import numpy as np |
3 | 4 | from numpy.testing import assert_array_almost_equal as aaae |
4 | 5 |
|
5 | 6 | import dcegm |
@@ -37,3 +38,30 @@ def test_partial_solve_func(): |
37 | 38 | aaae(model_solved.policy, partial_sol["policy"]) |
38 | 39 | aaae(model_solved.value, partial_sol["value"]) |
39 | 40 | aaae(model_solved.endog_grid, partial_sol["endog_grid"]) |
| 41 | + |
| 42 | + state_choices = model_solved.model_structure["state_choice_space"] |
| 43 | + choices = state_choices[:, -1] |
| 44 | + states_dict = { |
| 45 | + state: state_choices[:, id] |
| 46 | + for id, state in enumerate( |
| 47 | + model_solved.model_structure["discrete_states_names"] |
| 48 | + ) |
| 49 | + } |
| 50 | + states_dict["assets_begin_of_period"] = model_solved.endog_grid[:, 5] |
| 51 | + value_states_all_choices = model_solved.choice_values_for_states(states=states_dict) |
| 52 | + |
| 53 | + # Take in each row the value corresponding to the choice made |
| 54 | + value_choices = value_states_all_choices[ |
| 55 | + np.arange(value_states_all_choices.shape[0]), choices |
| 56 | + ] |
| 57 | + |
| 58 | + aaae(model_solved.value[:, 5], value_choices) |
| 59 | + |
| 60 | + # Same for policies |
| 61 | + policy_states_all_choices = model_solved.choice_policies_for_states( |
| 62 | + states=states_dict |
| 63 | + ) |
| 64 | + policy_choices = policy_states_all_choices[ |
| 65 | + np.arange(policy_states_all_choices.shape[0]), choices |
| 66 | + ] |
| 67 | + aaae(model_solved.policy[:, 5], policy_choices) |
0 commit comments