Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions src/dcegm/interfaces/jit_large_arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
def split_structure_and_batch_info(model_structure, batch_info):
"""Splits the model structure and batch info into static parts, which we can not jit
compile and (large) arrays that we want to include in the function call for
jitting."""

struct_keys_not_for_jit = [
"discrete_states_names",
"state_names_without_stochastic",
"stochastic_states_names",
]
model_structure_non_jit = {
key: model_structure[key] for key in struct_keys_not_for_jit
}
model_structure_jit = model_structure.copy()
# Remove non-jittable items
for key in struct_keys_not_for_jit:
model_structure_jit.pop(key, None)

# Remove non-jittable items from batch_info
batch_info_jit = batch_info.copy()
batch_info_non_jit = {
"two_period_model": batch_info["two_period_model"],
}
batch_info_jit.pop("two_period_model", None)
# If it is not a two period model, there is more
if not batch_info["two_period_model"]:
batch_info_non_jit["n_segments"] = batch_info["n_segments"]
batch_info_jit.pop("n_segments", None)
for batch_id in range(batch_info_non_jit["n_segments"]):
batch_key = f"batches_info_segment_{batch_id}"
batch_info_non_jit[batch_key] = {}
batch_info_non_jit[batch_key]["batches_cover_all"] = batch_info[batch_key][
"batches_cover_all"
]
batch_info_jit[batch_key].pop("batches_cover_all", None)

return (
model_structure_jit,
batch_info_jit,
model_structure_non_jit,
batch_info_non_jit,
)


def merge_non_jit_and_jit_model_structure(model_structure_jit, model_structure_non_jit):
"""Generate one model_structure to handle inside the package functions."""
model_structure = {
**model_structure_jit,
**model_structure_non_jit,
}
return model_structure


def merg_non_jit_batch_info_and_jit_batch_info(batch_info_jit, batch_info_non_jit):
batch_info = {
**batch_info_jit,
"two_period_model": batch_info_non_jit["two_period_model"],
}
if not batch_info_non_jit["two_period_model"]:
batch_info["n_segments"] = batch_info_non_jit["n_segments"]
for batch_id in range(batch_info_non_jit["n_segments"]):
batch_key = f"batches_info_segment_{batch_id}"
batch_info[batch_key]["batches_cover_all"] = batch_info_non_jit[batch_key][
"batches_cover_all"
]
return batch_info
198 changes: 136 additions & 62 deletions src/dcegm/interfaces/model_class.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pickle as pkl
from functools import partial
from grp import struct_group
from typing import Callable, Dict

import jax
Expand All @@ -15,6 +16,11 @@
get_n_state_choice_period,
validate_stochastic_transition,
)
from dcegm.interfaces.jit_large_arrays import (
merg_non_jit_batch_info_and_jit_batch_info,
merge_non_jit_and_jit_model_structure,
split_structure_and_batch_info,
)
from dcegm.interfaces.sol_interface import model_solved
from dcegm.law_of_motion import calc_cont_grids_next_period
from dcegm.likelihood import create_individual_likelihood_function
Expand Down Expand Up @@ -124,51 +130,6 @@ def __init__(
else:
self.alternative_sim_funcs = None

def set_alternative_sim_funcs(
self, alternative_sim_specifications, alternative_specs=None
):
if alternative_specs is None:
self.alternative_sim_specs = self.model_specs
alternative_specs_without_jax = self.specs_without_jax
else:
self.alternative_sim_specs = jax.tree_util.tree_map(
try_jax_array, alternative_specs
)
alternative_specs_without_jax = alternative_specs

alternative_sim_funcs = generate_alternative_sim_functions(
model_specs=alternative_specs_without_jax,
model_specs_jax=self.alternative_sim_specs,
**alternative_sim_specifications,
)
self.alternative_sim_funcs = alternative_sim_funcs

def backward_induction_inner_jit(self, params):
return backward_induction(
params=params,
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
income_shock_weights=self.income_shock_weights,
model_config=self.model_config,
batch_info=self.batch_info,
model_funcs=self.model_funcs,
model_structure=self.model_structure,
)

def get_fast_solve_func(self):
backward_jit = jax.jit(
partial(
backward_induction,
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
income_shock_weights=self.income_shock_weights,
model_config=self.model_config,
batch_info=self.batch_info,
model_funcs=self.model_funcs,
model_structure=self.model_structure,
)
)

return backward_jit

def solve(self, params, load_sol_path=None, save_sol_path=None):
"""Solve a discrete-continuous life-cycle model using the DC-EGM algorithm.

Expand Down Expand Up @@ -198,8 +159,14 @@ def solve(self, params, load_sol_path=None, save_sol_path=None):
if load_sol_path is not None:
sol_dict = pkl.load(open(load_sol_path, "rb"))
else:
value, policy, endog_grid = self.backward_induction_inner_jit(
params_processed
value, policy, endog_grid = backward_induction(
params=params_processed,
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
income_shock_weights=self.income_shock_weights,
model_config=self.model_config,
model_funcs=self.model_funcs,
model_structure=self.model_structure,
batch_info=self.batch_info,
)
sol_dict = {
"value": value,
Expand Down Expand Up @@ -245,8 +212,14 @@ def solve_and_simulate(
if load_sol_path is not None:
sol_dict = pkl.load(open(load_sol_path, "rb"))
else:
value, policy, endog_grid = self.backward_induction_inner_jit(
params_processed
value, policy, endog_grid = backward_induction(
params=params_processed,
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
income_shock_weights=self.income_shock_weights,
model_config=self.model_config,
model_funcs=self.model_funcs,
model_structure=self.model_structure,
batch_info=self.batch_info,
)

sol_dict = {
Expand Down Expand Up @@ -274,14 +247,69 @@ def solve_and_simulate(
sim_df = create_simulation_df(sim_dict)
return sim_df

def get_solve_func(self):
"""Create a fast function for solving that is jit compiled in the first call."""

(
model_structure_for_jit,
batch_info_for_jit,
model_structure_non_jit,
batch_info_non_jit,
) = split_structure_and_batch_info(self.model_structure, self.batch_info)

def solve_function_to_jit(params, model_structure_jit, batch_info_jit):
params_processed = process_params(params, self.params_check_info)

# Merge back parts together. The non_jit objects are fixed in the closure.
model_structure = merge_non_jit_and_jit_model_structure(
model_structure_jit, model_structure_non_jit
)
batch_info = merg_non_jit_batch_info_and_jit_batch_info(
batch_info_jit, batch_info_non_jit
)

# Solve the model.
value, policy, endog_grid = backward_induction(
params=params_processed,
model_structure=model_structure,
batch_info=batch_info,
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
income_shock_weights=self.income_shock_weights,
model_config=self.model_config,
model_funcs=self.model_funcs,
)

return value, policy, endog_grid

solve_func = jax.jit(solve_function_to_jit)

# Generate the function. The user only needs to provide params, but we call with the objects for jit.
def solve_function(params):
"""Solve the model for given params."""
value, policy, endog_grid = solve_func(
params, model_structure_for_jit, batch_info_for_jit
)
model_solved_class = model_solved(
model=self,
params=params,
value=value,
policy=policy,
endog_grid=endog_grid,
)
return model_solved_class

return solve_function

def get_solve_and_simulate_func(
self,
states_initial,
seed,
slow_version=False,
):
"""Create a fast function for solving and simulation that is jit compiled in the
first call."""

sim_func = lambda params, value, policy, endog_gid: simulate_all_periods(
# Fix everything except params, solution of the model and model_structure which contains large arrays.
sim_func = lambda params, value, policy, endog_gid, model_structure: simulate_all_periods(
states_initial=states_initial,
n_periods=self.model_config["n_periods"],
params=params,
Expand All @@ -290,34 +318,59 @@ def get_solve_and_simulate_func(
policy_solved=policy,
value_solved=value,
model_config=self.model_config,
model_structure=self.model_structure,
model_structure=model_structure,
model_funcs=self.model_funcs,
alt_model_funcs_sim=self.alternative_sim_funcs,
)

def solve_and_simulate_function_to_jit(params):
(
model_structure_for_jit,
batch_info_for_jit,
model_structure_non_jit,
batch_info_non_jit,
) = split_structure_and_batch_info(self.model_structure, self.batch_info)

def solve_and_simulate_function_to_jit(
params, model_structure_jit, batch_info_jit
):
params_processed = process_params(params, self.params_check_info)
# Solve the model
value, policy, endog_grid = self.backward_induction_inner_jit(
params_processed

# Merge back parts together. The non_jit objects are fixed in the closure.
model_structure = merge_non_jit_and_jit_model_structure(
model_structure_jit, model_structure_non_jit
)
batch_info = merg_non_jit_batch_info_and_jit_batch_info(
batch_info_jit, batch_info_non_jit
)

# Solve the model.
value, policy, endog_grid = backward_induction(
params=params_processed,
model_structure=model_structure,
batch_info=batch_info,
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
income_shock_weights=self.income_shock_weights,
model_config=self.model_config,
model_funcs=self.model_funcs,
)

sim_dict = sim_func(
params=params_processed,
value=value,
policy=policy,
endog_gid=endog_grid,
model_structure=model_structure,
)

return sim_dict

if slow_version:
solve_simulate_func = solve_and_simulate_function_to_jit
else:
solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit)
solve_simulate_func = jax.jit(solve_and_simulate_function_to_jit)

# Generate the function. The user only needs to provide params, but we call with the objects for jit.
def solve_and_simulate_function(params):
sim_dict = solve_simulate_func(params)
sim_dict = solve_simulate_func(
params, model_structure_for_jit, batch_info_for_jit
)
df = create_simulation_df(sim_dict)
return df

Expand All @@ -335,11 +388,13 @@ def create_experimental_ll_func(
):

return create_individual_likelihood_function(
income_shock_draws_unscaled=self.income_shock_draws_unscaled,
income_shock_weights=self.income_shock_weights,
batch_info=self.batch_info,
model_structure=self.model_structure,
model_config=self.model_config,
model_funcs=self.model_funcs,
model_specs=self.model_specs,
backwards_induction_inner_jit=self.backward_induction_inner_jit,
observed_states=observed_states,
observed_choices=observed_choices,
params_all=params_all,
Expand Down Expand Up @@ -475,3 +530,22 @@ def solve_partially(self, params, n_periods, return_candidates=False):
n_periods=n_periods,
return_candidates=return_candidates,
)

def set_alternative_sim_funcs(
self, alternative_sim_specifications, alternative_specs=None
):
if alternative_specs is None:
self.alternative_sim_specs = self.model_specs
alternative_specs_without_jax = self.specs_without_jax
else:
self.alternative_sim_specs = jax.tree_util.tree_map(
try_jax_array, alternative_specs
)
alternative_specs_without_jax = alternative_specs

alternative_sim_funcs = generate_alternative_sim_functions(
model_specs=alternative_specs_without_jax,
model_specs_jax=self.alternative_sim_specs,
**alternative_sim_specifications,
)
self.alternative_sim_funcs = alternative_sim_funcs
Loading
Loading