Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ jobs:
flake8 .
- name: ufmt (formatting check)
run: |
echo "Checking for formatting issues..."
ufmt diff .
ufmt check .

pyre:
Expand Down
16 changes: 10 additions & 6 deletions balance/sample_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,14 +823,21 @@ def model(
self: "Sample",
) -> Dict[str, Any] | None:
"""
Returns the name of the model used to adjust Sample if adjusted.
Returns the adjustment model dictionary if Sample has been adjusted.
Otherwise returns None.

The ``_adjustment_model`` attribute is initialized as ``None`` at the class
level and is set to a dictionary containing model details when
:meth:`adjust` is called. This method simply returns that attribute,
which will be ``None`` for unadjusted samples.

Args:
self (Sample): Sample object.

Returns:
str or None: name of model used for adjusting Sample
Dict[str, Any] or None: Dictionary containing adjustment model details
(e.g., method name, fitted model, performance metrics) if the sample
has been adjusted, otherwise None.

Examples:
.. code-block:: python
Expand All @@ -848,10 +855,7 @@ def model(
sample.model() is None
# True
"""
if hasattr(self, "_adjustment_model"):
return self._adjustment_model
else:
return None
return self._adjustment_model

def model_matrix(self: "Sample") -> pd.DataFrame:
"""
Expand Down
21 changes: 7 additions & 14 deletions balance/stats_and_plots/weighted_comparisons_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,10 +653,6 @@ def seaborn_plot_dist(
# With limiting the y axis range to (0,1)
seaborn_plot_dist(dfs1, names=["self", "unadjusted", "target"], dist_type = "kde", ylim = (0,1))
"""
# Provide default names if not specified
if names is None:
names = [f"df_{i}" for i in range(len(dfs))]

# Set default dist_type
dist_type_resolved: Literal["qq", "hist", "kde", "ecdf"]
if dist_type is None:
Expand All @@ -671,10 +667,6 @@ def seaborn_plot_dist(
if names is None:
names = [f"df_{i}" for i in range(len(dfs))]

# Type narrowing for names parameter
if names is None:
names = []

# Choose set of variables to plot
variables = choose_variables(*(d["df"] for d in dfs), variables=variables)
logger.debug(f"plotting variables {variables}")
Expand Down Expand Up @@ -1348,12 +1340,13 @@ def naming_legend(object_name: str, names_of_dfs: List[str]) -> str:
naming_legend('self', ['self', 'target']) #'sample'
naming_legend('other_name', ['self', 'target']) #'other_name'
"""
if object_name in names_of_dfs:
return {
"unadjusted": "sample",
"self": "adjusted" if "unadjusted" in names_of_dfs else "sample",
"target": "population",
}[object_name]
name_mapping = {
"unadjusted": "sample",
"self": "adjusted" if "unadjusted" in names_of_dfs else "sample",
"target": "population",
}
if object_name in name_mapping:
return name_mapping[object_name]
else:
return object_name

Expand Down
59 changes: 25 additions & 34 deletions balance/stats_and_plots/weighted_comparisons_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from balance.stats_and_plots.weights_stats import _check_weights_are_valid
from balance.util import _safe_groupby_apply, _safe_replace_and_infer
from balance.utils.input_validation import (
_coerce_to_numeric_and_validate,
_extract_series_and_weights,
_is_discrete_series,
)
Expand Down Expand Up @@ -826,16 +827,16 @@ def emd(
)
)
else:
sample_vals = pd.to_numeric(sample_series, errors="coerce").dropna()
target_vals = pd.to_numeric(target_series, errors="coerce").dropna()
if sample_vals.empty or target_vals.empty:
raise ValueError("Numeric columns must contain at least one value.")
sample_w_numeric = sample_w[sample_series.index.isin(sample_vals.index)]
target_w_numeric = target_w[target_series.index.isin(target_vals.index)]
sample_vals, sample_w_numeric = _coerce_to_numeric_and_validate(
sample_series, sample_w, "Sample numeric column"
)
target_vals, target_w_numeric = _coerce_to_numeric_and_validate(
target_series, target_w, "Target numeric column"
)
out_dict[col] = float(
wasserstein_distance(
sample_vals.to_numpy(),
target_vals.to_numpy(),
sample_vals,
target_vals,
u_weights=sample_w_numeric,
v_weights=target_w_numeric,
)
Expand Down Expand Up @@ -962,21 +963,17 @@ def cvmd(
np.sum((sample_cdf - target_cdf) ** 2 * combined_pmf.to_numpy())
)
else:
sample_vals = pd.to_numeric(sample_series, errors="coerce").dropna()
target_vals = pd.to_numeric(target_series, errors="coerce").dropna()
if sample_vals.empty or target_vals.empty:
raise ValueError("Numeric columns must contain at least one value.")
sample_w_numeric = sample_w[sample_series.index.isin(sample_vals.index)]
target_w_numeric = target_w[target_series.index.isin(target_vals.index)]

sample_sorted, sample_cdf = _weighted_ecdf(
sample_vals.to_numpy(), sample_w_numeric
sample_vals, sample_w_numeric = _coerce_to_numeric_and_validate(
sample_series, sample_w, "Sample numeric column"
)
target_sorted, target_cdf = _weighted_ecdf(
target_vals.to_numpy(), target_w_numeric
target_vals, target_w_numeric = _coerce_to_numeric_and_validate(
target_series, target_w, "Target numeric column"
)

sample_sorted, sample_cdf = _weighted_ecdf(sample_vals, sample_w_numeric)
target_sorted, target_cdf = _weighted_ecdf(target_vals, target_w_numeric)
combined_values, combined_weights = _combined_weights(
np.concatenate((sample_vals.to_numpy(), target_vals.to_numpy())),
np.concatenate((sample_vals, target_vals)),
np.concatenate((sample_w_numeric, target_w_numeric)),
)
sample_eval = _evaluate_ecdf(sample_sorted, sample_cdf, combined_values)
Expand Down Expand Up @@ -1101,22 +1098,16 @@ def ks(
)
out_dict[col] = float(np.max(np.abs(sample_cdf - target_cdf)))
else:
sample_vals = pd.to_numeric(sample_series, errors="coerce").dropna()
target_vals = pd.to_numeric(target_series, errors="coerce").dropna()
if sample_vals.empty or target_vals.empty:
raise ValueError("Numeric columns must contain at least one value.")
sample_w_numeric = sample_w[sample_series.index.isin(sample_vals.index)]
target_w_numeric = target_w[target_series.index.isin(target_vals.index)]

sample_sorted, sample_cdf = _weighted_ecdf(
sample_vals.to_numpy(), sample_w_numeric
sample_vals, sample_w_numeric = _coerce_to_numeric_and_validate(
sample_series, sample_w, "Sample numeric column"
)
target_sorted, target_cdf = _weighted_ecdf(
target_vals.to_numpy(), target_w_numeric
)
combined_values = np.unique(
np.concatenate((sample_vals.to_numpy(), target_vals.to_numpy()))
target_vals, target_w_numeric = _coerce_to_numeric_and_validate(
target_series, target_w, "Target numeric column"
)

sample_sorted, sample_cdf = _weighted_ecdf(sample_vals, sample_w_numeric)
target_sorted, target_cdf = _weighted_ecdf(target_vals, target_w_numeric)
combined_values = np.unique(np.concatenate((sample_vals, target_vals)))
sample_eval = _evaluate_ecdf(sample_sorted, sample_cdf, combined_values)
target_eval = _evaluate_ecdf(target_sorted, target_cdf, combined_values)
out_dict[col] = float(np.max(np.abs(sample_eval - target_eval)))
Expand Down
8 changes: 2 additions & 6 deletions balance/stats_and_plots/weighted_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,8 @@ def ci_of_weighted_mean(
var_weighed_mean_of_v = var_of_weighted_mean(v, w, inf_rm)
z_value = norm.ppf((1 + conf_level) / 2)

if isinstance(v, pd.Series):
ci_index = v.index
elif isinstance(v, pd.DataFrame):
ci_index = v.columns
else:
ci_index = None
# After _prepare_weighted_stat_args, v is always a DataFrame
ci_index = v.columns

ci = pd.Series(
[
Expand Down
67 changes: 63 additions & 4 deletions balance/utils/input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,56 @@ def _extract_series_and_weights(
return filtered_series, filtered_weights


def _coerce_to_numeric_and_validate(
series: pd.Series,
weights: np.ndarray,
label: str,
) -> tuple[np.ndarray, np.ndarray]:
"""
Convert series to numeric, drop NaN values, and validate non-empty.

This function handles series that may contain values that cannot be
converted to numeric (e.g., non-numeric strings in an object dtype series).
It coerces such values to NaN and drops them, then validates that at least
one valid numeric value remains.

Args:
series (pd.Series): Input series to convert to numeric.
weights (np.ndarray): Weights aligned to the series.
label (str): Label for error messages.

Returns:
Tuple[np.ndarray, np.ndarray]: Numeric values and corresponding weights.

Raises:
ValueError: If no valid numeric values remain after conversion.

Examples:
.. code-block:: python

import numpy as np
import pandas as pd
from balance.utils.input_validation import _coerce_to_numeric_and_validate

vals, w = _coerce_to_numeric_and_validate(
pd.Series([1.0, 2.0, 3.0]),
np.array([1.0, 1.0, 2.0]),
"example",
)
vals.tolist()
# [1.0, 2.0, 3.0]
w.tolist()
# [1.0, 1.0, 2.0]
"""
numeric_series = pd.to_numeric(series, errors="coerce").dropna()
if numeric_series.empty:
raise ValueError(
f"{label} must contain at least one valid numeric value after conversion."
)
numeric_weights = weights[series.index.isin(numeric_series.index)]
return numeric_series.to_numpy(), numeric_weights


def _is_discrete_series(series: pd.Series) -> bool:
"""
Determine whether a series should be treated as discrete for comparisons.
Expand Down Expand Up @@ -183,10 +233,19 @@ def _check_weighting_methods_input(
# This is so to avoid various cyclic imports (since various files call sample_class, and then sample_class also calls these files)
# TODO: (p2) move away from this method once we restructure Sample and BalanceDF objects...
def _isinstance_sample(obj: Any) -> bool:
try:
from balance import sample_class
except ImportError:
return False
"""Check if an object is an instance of Sample.

The import is done inside the function to avoid circular import issues at
module load time. Since this module is part of the balance package, the
import will always succeed when the function is called.

Args:
obj: The object to check.

Returns:
bool: True if obj is a Sample instance, False otherwise.
"""
from balance import sample_class

return isinstance(obj, sample_class.Sample)

Expand Down
Loading
Loading