Skip to content

Commit 2578cce

Browse files
authored
Merge branch 'main' into vs_module
2 parents a5dd60f + 8e672f2 commit 2578cce

40 files changed

+7727
-5138
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@ jobs:
2424
- name: Setup environment
2525
run: pip install -e .[test]
2626
- name: Run doctests
27-
run: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py
27+
run: pytest --doctest-modules --ignore=causalpy/tests/ causalpy/ --config-file=causalpy/tests/conftest.py --no-cov
2828
- name: Run extra tests
29-
run: pytest docs/source/.codespell/test_notebook_to_markdown.py
29+
run: pytest docs/source/.codespell/test_notebook_to_markdown.py --no-cov
3030
- name: Run tests
3131
run: pytest --cov-report=xml --no-cov-on-fail
3232
- name: Check codespell for notebooks

.github/workflows/uml.yml

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,25 @@ jobs:
5858
- name: Create PR for changes
5959
if: steps.changes.outputs.changes_exist == 'true'
6060
run: |
61-
git checkout -b update-uml-diagrams
61+
BRANCH_NAME="update-uml-diagrams"
62+
63+
# Delete the branch locally and remotely if it exists
64+
git branch -D $BRANCH_NAME 2>/dev/null || true
65+
git push origin --delete $BRANCH_NAME 2>/dev/null || true
66+
67+
# Create and push the new branch
68+
git checkout -b $BRANCH_NAME
6269
git commit -m "Update UML Diagrams"
63-
git push -u origin update-uml-diagrams
70+
git push -u origin $BRANCH_NAME
71+
72+
# Create PR (will fail gracefully if PR already exists)
6473
gh pr create \
6574
--base main \
6675
--title "Update UML Diagrams" \
6776
--body "This PR updates the UML diagrams
6877
This PR was created automatically by the [UML workflow](https://github.com/pymc-labs/CausalPy/blob/main/.github/workflows/uml.yml).
6978
See the logs [here](https://github.com/pymc-labs/CausalPy/actions/workflows/uml.yml) for more details." \
7079
--label "no releasenotes" \
71-
--reviewer drbenvincent
80+
--reviewer drbenvincent || echo "PR may already exist"
7281
env:
7382
GH_TOKEN: ${{ github.token }}

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,3 +14,5 @@ dist/
1414
docs/build/
1515
docs/jupyter_execute/
1616
docs/source/api/generated/
17+
18+
.cursor/

.pre-commit-config.yaml

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,24 @@ repos:
2121
- id: end-of-file-fixer
2222
exclude_types: [svg]
2323
- id: check-yaml
24+
- id: check-toml
25+
- id: check-json
2426
- id: check-added-large-files
2527
exclude: &exclude_pattern '(iv_weak_instruments|its_lift_test)\.ipynb'
2628
args: ["--maxkb=1500"]
29+
- id: check-merge-conflict
30+
- id: check-case-conflict
31+
- id: mixed-line-ending
2732
- repo: https://github.com/astral-sh/ruff-pre-commit
28-
rev: v0.14.5
33+
rev: v0.14.8
2934
hooks:
3035
# Run the linter
3136
- id: ruff
3237
types_or: [ python, pyi, jupyter ]
3338
args: [ --fix ]
39+
# Exclude docs/ to avoid applying strict linting rules to example notebooks
40+
# Remove this exclusion if you want to enforce strict rules on documentation
41+
exclude: ^docs/
3442
# Run the formatter
3543
- id: ruff-format
3644
types_or: [ python, pyi, jupyter ]
@@ -49,9 +57,13 @@ repos:
4957
# Support pyproject.toml configuration
5058
- tomli
5159
- repo: https://github.com/pre-commit/mirrors-mypy
52-
rev: v1.18.2
60+
rev: v1.19.0
5361
hooks:
5462
- id: mypy
5563
args: [--ignore-missing-imports]
5664
files: ^causalpy/
5765
additional_dependencies: [numpy>=1.20, pandas-stubs]
66+
- repo: https://github.com/abravalheri/validate-pyproject
67+
rev: v0.24.1
68+
hooks:
69+
- id: validate-pyproject

AGENTS.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,15 @@
4141
- **Custom exceptions**: Use project-specific exceptions from `causalpy.custom_exceptions`: `FormulaException`, `DataException`, `BadIndexException`
4242
- **File organization**: Experiments in `causalpy/experiments/`, PyMC models in `causalpy/pymc_models.py`, scikit-learn models in `causalpy/skl_models.py`
4343

44+
## Code quality checks
45+
46+
- **Before committing**: Always run `pre-commit run --all-files` to ensure all checks pass (linting, formatting, type checking)
47+
- **Quick check**: Run `ruff check causalpy/` for fast linting feedback during development
48+
- **Auto-fix**: Run `ruff check --fix causalpy/` to automatically fix many linting issues
49+
- **Format**: Run `ruff format causalpy/` to format code according to project standards
50+
- **Linting rules**: Project uses strict linting (F, B, UP, C4, SIM, I) to catch bugs and enforce modern Python patterns
51+
- **Note**: Documentation notebooks in `docs/` are excluded from strict linting rules
52+
4453
## Type Checking
4554

4655
- **Tool**: MyPy

causalpy/data/PISA18sampleScale.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,4 +98,4 @@ PV1READ,Female,ESCS,METASUM,PERFEED,JOYREAD,MASTGOAL,ADAPTIVITY,TEACHINT,SCREADD
9898
1.382758437,0,1.112755087,1.024329216,-1.363017023,0.546695917,-0.300835494,0.488206717,0.990856172,0.19103636,-0.063690149
9999
-0.180166117,0,-0.903784153,1.416120964,-0.789077206,0.3589234,-2.209705037,-1.353064842,-0.668887592,0.689393741,-0.929939088
100100
-0.138452609,0,-1.523831485,-0.171666648,0.021680557,-0.050770988,1.521829065,-0.8111262,-0.039406676,0.66554889,-0.851380406
101-
0.907727459,1,0.115773982,1.024329216,1.478217432,0.461175761,-0.873789642,0.080450276,-0.668887592,-0.265544842,-0.063690149
101+
0.907727459,1,0.115773982,1.024329216,1.478217432,0.461175761,-0.873789642,0.080450276,-0.668887592,-0.265544842,-0.063690149

causalpy/data/simulate_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def impact(x: np.ndarray) -> np.ndarray:
308308

309309
def generate_ancova_data(
310310
N: int = 200,
311-
pre_treatment_means: np.ndarray = np.array([10, 12]),
311+
pre_treatment_means: np.ndarray | None = None,
312312
treatment_effect: int = 2,
313313
sigma: int = 1,
314314
) -> pd.DataFrame:
@@ -324,6 +324,8 @@ def generate_ancova_data(
324324
... )
325325
>>> df.to_csv(pathlib.Path.cwd() / "ancova_data.csv", index=False) # doctest: +SKIP
326326
"""
327+
if pre_treatment_means is None:
328+
pre_treatment_means = np.array([10, 12])
327329
group = np.random.choice(2, size=N)
328330
pre = np.random.normal(loc=pre_treatment_means[group])
329331
post = pre + treatment_effect * group + np.random.normal(size=N) * sigma

causalpy/experiments/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""
1717

1818
from abc import abstractmethod
19-
from typing import Any, Literal, Union
19+
from typing import Any, Literal
2020

2121
import arviz as az
2222
import matplotlib.pyplot as plt
@@ -54,7 +54,7 @@ class BaseExperiment:
5454
supports_bayes: bool
5555
supports_ols: bool
5656

57-
def __init__(self, model: Union[PyMCModel, RegressorMixin] | None = None) -> None:
57+
def __init__(self, model: PyMCModel | RegressorMixin | None = None) -> None:
5858
# Ensure we've made any provided Scikit Learn model (as identified as being type
5959
# RegressorMixin) compatible with CausalPy by appending our custom methods.
6060
if isinstance(model, RegressorMixin):
@@ -141,7 +141,7 @@ def get_plot_data_ols(self, *args: Any, **kwargs: Any) -> pd.DataFrame:
141141

142142
def effect_summary(
143143
self,
144-
window: Union[Literal["post"], tuple, slice] = "post",
144+
window: Literal["post"] | tuple | slice = "post",
145145
direction: Literal["increase", "decrease", "two-sided"] = "increase",
146146
alpha: float = 0.05,
147147
cumulative: bool = True,

causalpy/experiments/diff_in_diff.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
Difference in differences
1616
"""
1717

18-
from typing import Union
19-
2018
import arviz as az
2119
import numpy as np
2220
import pandas as pd
@@ -98,7 +96,7 @@ def __init__(
9896
time_variable_name: str,
9997
group_variable_name: str,
10098
post_treatment_variable_name: str = "post_treatment",
101-
model: Union[PyMCModel, RegressorMixin] | None = None,
99+
model: PyMCModel | RegressorMixin | None = None,
102100
**kwargs: dict,
103101
) -> None:
104102
super().__init__(model=model)
@@ -234,7 +232,7 @@ def __init__(
234232
elif isinstance(self.model, RegressorMixin):
235233
# This is the coefficient on the interaction term
236234
# Store the coefficient into dictionary {intercept:value}
237-
coef_map = dict(zip(self.labels, self.model.get_coeffs()))
235+
coef_map = dict(zip(self.labels, self.model.get_coeffs(), strict=False))
238236
# Create and find the interaction term based on the values user provided
239237
interaction_term = (
240238
f"{self.group_variable_name}:{self.post_treatment_variable_name}"

causalpy/experiments/instrumental_variable.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ def get_naive_OLS_fit(self) -> None:
238238
ols_reg = sk_lin_reg().fit(self.X, self.y)
239239
beta_params = list(ols_reg.coef_[0][1:])
240240
beta_params.insert(0, ols_reg.intercept_[0])
241-
self.ols_beta_params = dict(zip(self._x_design_info.column_names, beta_params))
241+
self.ols_beta_params = dict(
242+
zip(self._x_design_info.column_names, beta_params, strict=False)
243+
)
242244
self.ols_reg = ols_reg
243245

244246
def plot(self, *args, **kwargs) -> None: # type: ignore[override]

0 commit comments

Comments
 (0)