Skip to content

Commit bc1ca87

Browse files
committed
calm init commit
1 parent 1c959cd commit bc1ca87

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+7147
-7
lines changed
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import numpy as np\n",
10+
"import matplotlib.pyplot as plt\n",
11+
"from scipy.special import expit # sigmoid for smooth saturation"
12+
]
13+
},
14+
{
15+
"cell_type": "markdown",
16+
"metadata": {},
17+
"source": [
18+
"# Concept Image - Figure 1"
19+
]
20+
},
21+
{
22+
"cell_type": "code",
23+
"execution_count": null,
24+
"metadata": {},
25+
"outputs": [],
26+
"source": [
27+
"\n",
28+
"# f1(x1 | x3)\n",
29+
"def f11(x1):\n",
30+
" return 1.2 * x1**3 + 0.1 * x1 # steeper growth early\n",
31+
"\n",
32+
"def f12(x1):\n",
33+
" return 0.9 * np.tanh(2 * x1) # saturates later\n",
34+
"\n",
35+
"def f1(x1, x3):\n",
36+
" y = np.zeros_like(x1)\n",
37+
"\n",
38+
" cond = x3 > 0\n",
39+
" y[cond] = f11(x1[cond])\n",
40+
" y[~cond] = f11(x1[~cond])\n",
41+
" return y\n",
42+
"\n",
43+
"def f21(x2):\n",
44+
" return 0.4 * np.sin(np.pi * x2 / 2 - 0.2) - 0.4 # same shape, lower offset\n",
45+
"\n",
46+
"def f22(x2):\n",
47+
" return 0.6 * np.sin(np.pi * x2 / 2 - 0.1) - 0.2 # same shape, lower offset\n",
48+
"\n",
49+
"def f23(x2):\n",
50+
" return 0.8 * np.sin(np.pi * x2 / 2) # natural bell-like response\n",
51+
"\n",
52+
"def f2(x2, x1):\n",
53+
" y = np.zeros_like(x2)\n",
54+
"\n",
55+
" cond1 = x1 < - 0.4\n",
56+
" cond2 = np.logical_and(x1 >= -0.4, x1 < 0.4 )\n",
57+
" cond3 = x1 >= 0.4\n",
58+
" y[cond1] = f21(x2[cond1])\n",
59+
" y[cond2] = f22(x2[cond2])\n",
60+
" y[cond3] = f23(x2[cond3])\n",
61+
" return y\n",
62+
"\n",
63+
"\n",
64+
"def f3 (x3):\n",
65+
" return 1.5 * (expit(2 * x3) - 0.5) # maps x3 in [-1,1] to roughly [-0.75, 0.75]\n",
66+
"\n",
67+
"x = np.linspace(-1, 1, 200)\n",
68+
"folder = \"test_synth_user_study/\"\n",
69+
"import os \n",
70+
"os.makedirs(folder, exist_ok=True)\n",
71+
"import matplotlib.ticker as ticker\n",
72+
"fontsize=17\n",
73+
"# --- Figure 1: f1(x1 | x3) ---\n",
74+
"plt.figure()\n",
75+
"plt.plot(x, f11(x), label=r\"$f(x_1 \\mid x_3 > 0)$\")\n",
76+
"plt.plot(x, f12(x), label=r\"$f(x_1 \\mid x_3 \\leq 0)$\")\n",
77+
"\n",
78+
"plt.axvline(x=-0.4, color='gray', linestyle=':')\n",
79+
"plt.axvline(x=0.4, color='gray', linestyle=':')\n",
80+
"plt.text(-0.4, -0.9, r'$x_2\\ (\\uparrow):[0,0.37]$', ha='center', va='top', fontsize=fontsize)\n",
81+
"plt.text(0.4, -0.9, r'$x_2\\ (\\uparrow):[0,0.37]$', ha='center', va='top', fontsize=fontsize)\n",
82+
"\n",
83+
"plt.xticks([])\n",
84+
"plt.yticks([])\n",
85+
"# plt.ylim(-1.5, 1.5)\n",
86+
"plt.legend(fontsize=fontsize,)\n",
87+
"plt.tight_layout()\n",
88+
"ax = plt.gca()\n",
89+
"for spine in ax.spines.values():\n",
90+
" spine.set_visible(False)\n",
91+
"\n",
92+
"plt.show()\n",
93+
"\n",
94+
"# --- Figure 2: f2(x2 | x1) ---\n",
95+
"plt.figure()\n",
96+
"plt.plot(x, f21(x), label=r\"$f(x_2 \\mid x_1 < -0.4)$\")\n",
97+
"plt.plot(x, f22(x), label=r\"$f(x_2 \\mid x_1 \\in [-0.4, 0.4])$\")\n",
98+
"plt.plot(x, f23(x), label=r\"$f(x_2 \\mid x_1 \\geq 0.4)$\")\n",
99+
"\n",
100+
"plt.xticks([])\n",
101+
"plt.yticks([])\n",
102+
"plt.legend(fontsize=fontsize,)\n",
103+
"plt.tight_layout()\n",
104+
"ax = plt.gca()\n",
105+
"for spine in ax.spines.values():\n",
106+
" spine.set_visible(False)\n",
107+
"\n",
108+
"ax = plt.gca()\n",
109+
"\n",
110+
"\n",
111+
"plt.show()\n",
112+
"\n",
113+
"# --- Figure 3: f3(x3) ---\n",
114+
"plt.figure()\n",
115+
"plt.plot(x, f3(x), label=r\"$f_d(x_d)$\")\n",
116+
"plt.axvline(x=0, color='gray', linestyle=':')\n",
117+
"plt.text(0, -0.4, r'$x_1\\ (\\updownarrow)[-0.42,0,42]$', ha='center', va='top', fontsize=fontsize)\n",
118+
"\n",
119+
"plt.xticks([])\n",
120+
"plt.yticks([])\n",
121+
"# plt.ylim(-1.5, 1.5)\n",
122+
"plt.legend(fontsize=fontsize,)\n",
123+
"plt.tight_layout()\n",
124+
"ax = plt.gca()\n",
125+
"for spine in ax.spines.values():\n",
126+
" spine.set_visible(False)\n",
127+
"\n",
128+
"plt.show()"
129+
]
130+
},
131+
{
132+
"cell_type": "markdown",
133+
"metadata": {},
134+
"source": [
135+
"The regions are defined conditioning on the interacting features:\n",
136+
"* Τhe effect of x1 conditions on x3 (Cx3 )\n",
137+
"* he effect of x2 conditions on x1 (Cx1 )\n",
138+
"* xd does not interact with any other feature and thus has a single plot"
139+
]
140+
},
141+
{
142+
"cell_type": "markdown",
143+
"metadata": {},
144+
"source": [
145+
"# Figure 2: CALM plot for x1"
146+
]
147+
},
148+
{
149+
"cell_type": "code",
150+
"execution_count": null,
151+
"metadata": {},
152+
"outputs": [],
153+
"source": [
154+
"fontsize = 15\n",
155+
"# --- Figure 1: f1(x1 | x3) ---\n",
156+
"plt.figure()\n",
157+
"plt.plot(x, f11(x), label=r\"$f(x_1 \\mid x_3 > 0)$\")\n",
158+
"plt.plot(x, f12(x), label=r\"$f(x_1 \\mid x_3 \\leq 0)$\")\n",
159+
"\n",
160+
"plt.axvline(x=-0.4, color='gray', linestyle=':')\n",
161+
"plt.axvline(x=0.4, color='gray', linestyle=':')\n",
162+
"plt.text(-0.4, -0.9, r'$x_2\\ (\\uparrow):[0,0.37]$', ha='center', va='top', fontsize=fontsize)\n",
163+
"plt.text(0.4, -0.9, r'$x_2\\ (\\uparrow):[0,0.37]$', ha='center', va='top', fontsize=fontsize)\n",
164+
"\n",
165+
"plt.xlabel(r\"$x_1$\", fontsize=fontsize)\n",
166+
"plt.ylabel(r\"$y$\", fontsize=fontsize)\n",
167+
"plt.xticks([-1, -.5, 0, 0.5, 1])\n",
168+
"plt.yticks([-1.5, -.75, 0, .75, 1.5])\n",
169+
"plt.legend(fontsize=fontsize)\n",
170+
"plt.tight_layout()\n",
171+
"\n",
172+
"ax = plt.gca()\n",
173+
"ax.tick_params(axis='both', labelsize=fontsize-1)\n",
174+
"\n",
175+
"plt.show()\n"
176+
]
177+
},
178+
{
179+
"cell_type": "markdown",
180+
"metadata": {},
181+
"source": [
182+
"Each curve gives the contribution of $ x_1 $ to $ y $ in a specific region.\n",
183+
"* The blue curve when $ x_3 > 0 $ \n",
184+
"* The orange curve when $ x_3 \\leq 0 $\n",
185+
"\n",
186+
"For example, at $ x_1 = -0.5 $, the contribution is approximately $-0.2$ (blue) or $-0.75$ (orange), depending on $ x_3 $.\n",
187+
"The plots also illustrate how altering $ x_1 $ to $x_1 \\rightarrow x_1 + \\delta$ impacts the prediction.\n",
188+
"Vertical dotted lines mark points of a hidden discontinuity which is due to $x_1$ participating as an interaction term for feature $x_2$. As shown in previous image, the effect of $ x_2 $ is conditioned by $x_1 \\leq - 0.4 $, $-0.4 \\leq x_1 \\leq 0.4 $ and $x_1 > 0.4$, therefore in this figurer we observe vertical lines in $x_1 \\pm 0.4$.\n",
189+
"If a change in $x_1$ does not cross a vertical line, the change in the output $ (\\Delta y) $ equals the curve difference $ (\\Delta f_i )$.\n",
190+
"Crossing a line signifies a hidden jump, in the range $[\\alpha, \\beta]$, so \n",
191+
"$ \\Delta f_i + a \\leq \\Delta y \\leq \\Delta f_i + \\beta$.\n",
192+
"\n",
193+
"Arrows provide a fast understanding of the jump:\n",
194+
"* $\\uparrow$ means $\\Delta y > \\Delta f_i$,\n",
195+
"* $\\downarrow$ means $\\Delta y < \\Delta f_i$,\n",
196+
"* $\\updownarrow$ means it depends."
197+
]
198+
}
199+
],
200+
"metadata": {
201+
"kernelspec": {
202+
"display_name": "CALM-ENV",
203+
"language": "python",
204+
"name": "calm-env"
205+
},
206+
"language_info": {
207+
"codemirror_mode": {
208+
"name": "ipython",
209+
"version": 3
210+
},
211+
"file_extension": ".py",
212+
"mimetype": "text/x-python",
213+
"name": "python",
214+
"nbconvert_exporter": "python",
215+
"pygments_lexer": "ipython3",
216+
"version": "3.10.16"
217+
}
218+
},
219+
"nbformat": 4,
220+
"nbformat_minor": 2
221+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
task: classification
2+
stop_on_error: false
3+
random_seed: 42
4+
test_size: 0.2
5+
output_file: results_classification.csv
6+
append_to_file: true
7+
kfold_n_splits: 5
8+
9+
datasets:
10+
- name: Adult
11+
module: effector.calm.datasets.adult
12+
- name: COMPAS
13+
module: effector.calm.datasets.compas
14+
- name: HELOC
15+
module: effector.calm.datasets.heloc
16+
- name: MIMIC2
17+
module: effector.calm.datasets.mimic2
18+
- name: PMLB_APPENDICITIS
19+
module: effector.calm.datasets.pmlb
20+
- name: PMLB_PHONEME
21+
module: effector.calm.datasets.pmlb
22+
- name: PMLB_SPECTF
23+
module: effector.calm.datasets.pmlb
24+
- name: Magic
25+
module: effector.calm.datasets.magic
26+
- name: Bank
27+
module: effector.calm.datasets.bank
28+
- name: PMLB_CHURN
29+
module: effector.calm.datasets.pmlb
30+
31+
methods:
32+
- name: DNNClassifier
33+
type: blackbox
34+
- name: RFClassifier
35+
type: blackbox
36+
- name: XGBClassifier
37+
type: blackbox
38+
- name: MaskedNAMClassifier
39+
type: maskedgam
40+
parameters: {}
41+
- name: PyGAMClassifier
42+
type: maskedgam
43+
parameters: {}
44+
- name: NoInteractionsEBMClassifier
45+
type: maskedgam
46+
parameters: {}
47+
- name: CALMClassifier
48+
type: calm
49+
parameters:
50+
region_detector: ["RegionalPDP", "RegionalRHALE"]
51+
masked_gam_name: ["NoInteractionsEBMClassifier", "PyGAMClassifier", "MaskedNAMClassifier"]
52+
blackbox_model: ["DNNClassifier", "RFClassifier", "XGBClassifier"]
53+
- name: EBM2Classifier
54+
type: competitor
55+
- name: NodeGAM2Classifier
56+
type: competitor
57+
parameters:
58+
max_time: [300]
59+
- name: GAMINetClassifier
60+
type: competitor
61+
62+
metrics: [accuracy, balanced_accuracy, f1]
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
task: classification
2+
stop_on_error: false
3+
random_seed: 42
4+
test_size: 0.2
5+
output_file: results_classification_main.csv
6+
append_to_file: true
7+
kfold_n_splits: 1
8+
9+
datasets:
10+
- name: Adult
11+
module: effector.calm.datasets.adult
12+
- name: COMPAS
13+
module: effector.calm.datasets.compas
14+
- name: HELOC
15+
module: effector.calm.datasets.heloc
16+
- name: MIMIC2
17+
module: effector.calm.datasets.mimic2
18+
- name: PMLB_APPENDICITIS
19+
module: effector.calm.datasets.pmlb
20+
- name: PMLB_PHONEME
21+
module: effector.calm.datasets.pmlb
22+
- name: PMLB_SPECTF
23+
module: effector.calm.datasets.pmlb
24+
- name: Magic
25+
module: effector.calm.datasets.magic
26+
- name: Bank
27+
module: effector.calm.datasets.bank
28+
- name: PMLB_CHURN
29+
module: effector.calm.datasets.pmlb
30+
31+
methods:
32+
- name: XGBClassifier
33+
type: blackbox
34+
- name: MaskedNAMClassifier
35+
type: maskedgam
36+
parameters: {}
37+
- name: NoInteractionsEBMClassifier
38+
type: maskedgam
39+
parameters: {}
40+
- name: CALMClassifier
41+
type: calm
42+
- name: EBM2Classifier
43+
type: competitor
44+
- name: NodeGAM2Classifier
45+
type: competitor
46+
parameters:
47+
max_time: [300]
48+
- name: GAMINetClassifier
49+
type: competitor
50+
51+
metrics: [accuracy, balanced_accuracy, f1]

0 commit comments

Comments
 (0)