Skip to content

Commit 967b1a6

Browse files
authored
Merge branch 'main' into kgao/drscore
2 parents 15f4c36 + 0be1625 commit 967b1a6

File tree

8 files changed

+18
-11
lines changed

8 files changed

+18
-11
lines changed

.github/workflows/ci.yml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,8 @@ jobs:
124124
- run: sudo apt-get -yq install graphviz
125125
name: Install graphviz
126126
if: ${{ matrix.install_graphviz }}
127-
- run: pip install -e .${{ matrix.extras }}
127+
# Add verbose flag to pip installation if in debug mode
128+
- run: pip install -e .${{ matrix.extras }} ${{ fromJSON('["","-v"]')[runner.debug] }}
128129
name: Install econml
129130
- run: pip install pytest pytest-runner jupyter jupyter-client nbconvert nbformat seaborn xgboost tqdm
130131
name: Install test and notebook requirements
@@ -183,7 +184,8 @@ jobs:
183184
python-version: ${{ matrix.python-version }}
184185
- run: python -m pip install --upgrade pip && pip install --upgrade setuptools
185186
name: Ensure latest pip and setuptools
186-
- run: pip install -e .${{ matrix.extras }}
187+
# Add verbose flag to pip installation if in debug mode
188+
- run: pip install -e .${{ matrix.extras }} ${{ fromJSON('["","-v"]')[runner.debug] }}
187189
name: Install econml
188190
- run: pip install pytest pytest-runner coverage
189191
name: Install pytest
@@ -224,7 +226,7 @@ jobs:
224226
steps:
225227
- run: exit 1
226228
name: At least one check failed or was cancelled
227-
if: ${{ !(success()) }}
229+
if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') }}
228230
- run: exit 0
229231
name: All checks passed
230-
if: ${{ success() }}
232+
if: ${{ !(contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled')) }}

econml/_ensemble/_ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -158,8 +158,8 @@ def _partition_estimators(n_estimators, n_jobs):
158158

159159
# Partition estimators between jobs
160160
n_estimators_per_job = np.full(n_jobs, n_estimators // n_jobs,
161-
dtype=np.int)
162-
n_estimators_per_job[:n_estimators % n_jobs] += 1
161+
dtype=int)
162+
n_estimators_per_job[: n_estimators % n_jobs] += 1
163163
starts = np.cumsum(n_estimators_per_job)
164164

165165
return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()

econml/_ortho_learner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -899,7 +899,7 @@ def score(self, Y, T, X=None, W=None, Z=None, sample_weight=None, groups=None):
899899
nuisances = [np.zeros((n_iters * n_splits,) + nuis.shape) for nuis in nuisance_temp]
900900

901901
for it, nuis in enumerate(nuisance_temp):
902-
nuisances[it][i * n_iters + j] = nuis
902+
nuisances[it][j * n_iters + i] = nuis
903903

904904
for it in range(len(nuisances)):
905905
nuisances[it] = np.mean(nuisances[it], axis=0)

econml/data/dynamic_panel_dgp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def simulate_residuals(ind):
135135

136136

137137
def simulate_residuals_all(res_df):
138-
res_df_new = res_df.copy(deep=True)
138+
res_df_new = res_df.astype(dtype='float64', copy=True, errors='raise')
139139
for i in range(res_df.shape[1]):
140140
res_df_new.iloc[:, i] = simulate_residuals(i)
141141
# demean the new residual again

econml/tests/test_dml.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,7 @@ def test_nuisance_scores(self):
10951095
est.fit(y, T, X=X, W=W)
10961096
assert len(est.nuisance_scores_t) == len(est.nuisance_scores_y) == mc_iters
10971097
assert len(est.nuisance_scores_t[0]) == len(est.nuisance_scores_y[0]) == cv
1098+
est.score(y, T, X=X, W=W)
10981099

10991100
def test_categories(self):
11001101
dmls = [LinearDML, SparseLinearDML]

notebooks/Solutions/Causal Interpretation for Ames Housing Price.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,7 @@
598598
"X = Xy.drop(columns = 'SalePrice')\n",
599599
"X_ohe = (\n",
600600
" X\n",
601-
" .pipe(pd.get_dummies, prefix_sep = '_OHE_', columns = categorical)\n",
601+
" .pipe(pd.get_dummies, prefix_sep = '_OHE_', columns = categorical, dtype='uint8')\n",
602602
")\n",
603603
"y = Xy['SalePrice']"
604604
]

notebooks/Solutions/Causal Interpretation for Employee Attrition Dataset.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@
432432
"outputs": [],
433433
"source": [
434434
"categorical = []\n",
435-
"for col, value in attritionXData.iteritems():\n",
435+
"for col, value in attritionXData.items():\n",
436436
" if value.dtype == \"object\":\n",
437437
" categorical.append(col)\n",
438438
"\n",

setup.cfg

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ install_requires =
3939
joblib >= 0.13.0
4040
statsmodels >= 0.10
4141
pandas
42-
shap >= 0.38.1, < 0.41.0
42+
shap >= 0.38.1, < 0.42.0
4343
lightgbm
4444
test_suite = econml.tests
4545
tests_require =
@@ -58,6 +58,8 @@ tf =
5858
tensorflow > 1.10, < 2.3;python_version < '3.9'
5959
; Version capped due to tensorflow incompatibility
6060
protobuf < 4
61+
; Version capped due to tensorflow incompatibility
62+
numpy < 1.24
6163
plt =
6264
graphviz
6365
; Version capped due to shap incompatibility
@@ -70,6 +72,8 @@ all =
7072
tensorflow > 1.10, < 2.3
7173
; Version capped due to tensorflow incompatibility
7274
protobuf < 4
75+
; Version capped due to tensorflow incompatibility
76+
numpy < 1.24
7377
; Version capped due to shap incompatibility
7478
matplotlib < 3.6.0
7579
dowhy < 0.9

0 commit comments

Comments
 (0)