Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions causallib/estimation/base_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,21 +103,22 @@ def compute_weight_matrix(self, X, a, use_stabilized=None, **kwargs):
raise NotImplementedError

@staticmethod
def _compute_stratified_weighted_aggregate(y, sample_weight=None,
def _compute_stratified_weighted_aggregate(y, sample_weight=None, normalize=True,
stratify_by=None, treatment_values=None):
"""
Calculates aggregation of `y` weighted by `sample_weights` stratified by `stratify_by` variable.

Args:
y (pd.Series): The variable to aggregate (num_subjects,).
sample_weight (pd.Series|None): Individual (sample) weights calculated.
Used to achieved unbiased average outcome.
If not provided, gives equal weights to every example.
Used to achieved unbiased average outcome.
If not provided, gives equal weights to every example.
normalize (bool): Whether to normalize the weights to sum to 1 within each strata.
stratify_by (pd.Series|None): Categorical variable to stratify according to (num_subjects,).
Namely, aggregate within subgroups sharing the same values.
If not provided, the aggregation is on the entire
treatment_values (Any): Subset of values to stratify on from `stratify_by`.
If not supplied, all available stratification values are used.
If not supplied, all available stratification values are used.

Returns:
pd.Series[Any, float]: Series which index are treatment values, and the values are numbers - the
Expand All @@ -133,11 +134,16 @@ def _compute_stratified_weighted_aggregate(y, sample_weight=None,
res = {}
for treatment_value in treatment_values:
subgroup_mask = stratify_by == treatment_value
aggregated_value = np.average(y[subgroup_mask], weights=sample_weight[subgroup_mask])
if normalize:
aggregated_value = np.average(y[subgroup_mask], weights=sample_weight[subgroup_mask])
else:
aggregated_value = np.sum(y[subgroup_mask] * sample_weight[subgroup_mask])/len(y)

res[treatment_value] = aggregated_value
res = pd.Series(res)
return res


def evaluate_balancing(self, X, a, y, w):
pass # TODO: implement: (1) table one with smd (2) gather lots of metric (ks, kl, smd) (3) plot CDF of each feature

Expand Down
13 changes: 11 additions & 2 deletions causallib/estimation/ipw.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,14 +201,17 @@ def compute_propensity_matrix(self, X, a=None, clip_min=None, clip_max=None):

return probabilities

def estimate_population_outcome(self, X, a, y, w=None, treatment_values=None):
def estimate_population_outcome(self, X, a, y, ipw_estimator, w=None, treatment_values=None):
"""
Calculates weighted population outcome for each subgroup stratified by treatment assignment.

Args:
X (pd.DataFrame): Covariate matrix of size (num_subjects, num_features).
a (pd.Series): Treatment assignment of size (num_subjects,).
y (pd.Series): Observed outcome of size (num_subjects,).
ipw_estimator (str): The type of estimator to use for the inverse probability weighting.
Horvitz-Thompson estimator is unbiased but high variance, while Hajek estimator
is biased but lower variance.
w (pd.Series | None): Individual (sample) weights calculated. Used to achieved unbiased average outcome.
If not provided, will be calculated on the data.
treatment_values (Any): Desired treatment value/s to stratify upon.
Expand All @@ -221,8 +224,14 @@ def estimate_population_outcome(self, X, a, y, w=None, treatment_values=None):
"""
if w is None:
w = self.compute_weights(X, a)
res = self._compute_stratified_weighted_aggregate(y, sample_weight=w, stratify_by=a,
if ipw_estimator == 'Horvitz-Thompson':
res = self._compute_stratified_weighted_aggregate(y, sample_weight=w, stratify_by=a, normalize=False, treatment_values=treatment_values)
elif ipw_estimator == 'Hajek':
res = self._compute_stratified_weighted_aggregate(y, sample_weight=w, normalize=True, stratify_by=a,
treatment_values=treatment_values)
else:
raise ValueError(f"Unknown ipw_estimator: {ipw_estimator}, not implemented in the current package version.")

return res

@staticmethod
Expand Down