Commit 35df4eb

Jaime Céspedes Sisniega authored
Merge pull request #317 from IFCA-Advanced-Computing/feature-save-load
Add save and load utils functions
2 parents cfb5e4a + 185ddc8 commit 35df4eb

9 files changed: +577 -0 lines changed

docs/source/api_reference/utils.md

Lines changed: 1 addition & 0 deletions
@@ -8,5 +8,6 @@ The {mod}`frouros.utils` module contains auxiliary classes, functions or excepti
 utils/checks
 utils/data_structures
 utils/kernels
+utils/persistence
 utils/stats
 ```
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+# Persistence
+
+The {mod}`frouros.utils.persistence` module contains auxiliary functions to persist objects.
+
+```{eval-rst}
+.. automodule:: frouros.utils.persistence
+    :members:
+    :no-inherited-members:
+```
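
The source of `frouros.utils.persistence` is not part of the hunks shown here, so the following is only a minimal sketch of what pickle-based persistence helpers generally look like, not the module's actual code. `save_object` and `load_object` are hypothetical names. The sketch assumes the standard-library `pickle` module and uses `math.dist` wrapped in `functools.partial` as a stand-in for a kernel with a bound parameter.

```python
import math
import pickle
import tempfile
from functools import partial
from pathlib import Path


def save_object(obj, filename, protocol=pickle.DEFAULT_PROTOCOL):
    """Hypothetical helper: serialize `obj` to `filename` with pickle."""
    with open(filename, "wb") as file:
        pickle.dump(obj, file, protocol=protocol)


def load_object(filename):
    """Hypothetical helper: read back an object written by `save_object`."""
    with open(filename, "rb") as file:
        return pickle.load(file)


# Stand-in for fitted state: a function with a bound argument plus stored data.
state = {
    "kernel": partial(math.dist, (0.0, 0.0)),
    "reference": [1.0, 2.0, 3.0],
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "state.pkl"
    save_object(state, path)
    restored = load_object(path)

# The restored partial still carries its bound argument.
restored_value = restored["kernel"]((3.0, 4.0))
print(restored_value)  # 5.0
```

Unpickling can execute arbitrary code, so a file like this should only be loaded when it comes from a trusted source.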

docs/source/examples.md

Lines changed: 1 addition & 0 deletions
@@ -5,4 +5,5 @@
 
 examples/concept_drift
 examples/data_drift
+examples/utils
 ```

docs/source/examples/utils.md

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+# Utils
+
+```{toctree}
+:maxdepth: 1
+
+utils/save_load
+```
Lines changed: 359 additions & 0 deletions
@@ -0,0 +1,359 @@
+{
+"cells": [
+{
+"cell_type": "code",
+"execution_count": 1,
+"id": "initial_id",
+"metadata": {
+"collapsed": true,
+"ExecuteTime": {
+"end_time": "2024-03-02T20:08:36.559538Z",
+"start_time": "2024-03-02T20:08:35.785936Z"
+}
+},
+"outputs": [],
+"source": [
+"from functools import partial\n",
+"import numpy as np\n",
+"from scipy.spatial.distance import pdist\n",
+"\n",
+"from frouros.callbacks import PermutationTestDistanceBased\n",
+"from frouros.detectors.data_drift import MMD\n",
+"from frouros.utils import load, save\n",
+"from frouros.utils.kernels import rbf_kernel"
+]
+},
+{
+"cell_type": "markdown",
+"source": [
+"# Save and Load detector\n",
+"\n",
+"In this example, we will demonstrate how to save and load a detector. We will use the MMD detector and the permutation test callback. We will first fit the detector and then compare two datasets. We will then save the detector to a file and load it back. We will then compare the same two datasets and assert that the distance and p-value are the same before and after saving and loading the detector."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "e3f1ddf0540a9259"
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Set random seed\n",
+"\n",
+"We will set the random seed to ensure reproducibility."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "4df73e55d7d353bb"
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"seed = 31\n",
+"np.random.seed(seed)"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-03-02T20:08:36.567956Z",
+"start_time": "2024-03-02T20:08:36.561066Z"
+}
+},
+"id": "f913c4fc44d511f7",
+"execution_count": 2
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Generate data\n",
+"\n",
+"We will generate two datasets. The first dataset will be generated from a multivariate normal distribution with mean [0, 0] and covariance matrix [[1, 0], [0, 1]]. The second dataset will be generated from a multivariate normal distribution with mean [1, 0] and covariance matrix [[1, 0], [0, 2]]."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "b08089f5ccf0f4d1"
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"num_samples = 100\n",
+"\n",
+"x_mean = [0, 0]\n",
+"x_cov = [\n",
+"    [1, 0],\n",
+"    [0, 1],\n",
+"]\n",
+"\n",
+"y_mean = [1, 0]\n",
+"y_cov = [\n",
+"    [1, 0],\n",
+"    [0, 2],\n",
+"]\n",
+"\n",
+"X_ref = np.random.multivariate_normal(\n",
+"    mean=x_mean,\n",
+"    cov=x_cov,\n",
+"    size=num_samples,\n",
+")\n",
+"X_test = np.random.multivariate_normal(\n",
+"    mean=y_mean,\n",
+"    cov=y_cov,\n",
+"    size=num_samples,\n",
+")"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-03-02T20:08:36.583840Z",
+"start_time": "2024-03-02T20:08:36.570122Z"
+}
+},
+"id": "188b82ee45c1a092",
+"execution_count": 3
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Fit detector\n",
+"\n",
+"We will fit the detector using the reference dataset."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "dd7dd35a96e1651a"
+},
+{
+"cell_type": "code",
+"outputs": [
+{
+"data": {
+"text/plain": "1.5941478725484344"
+},
+"execution_count": 4,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"sigma = np.median(\n",
+"    pdist(\n",
+"        X=X_ref,\n",
+"        metric=\"euclidean\",\n",
+"    ),\n",
+"  )\n",
+"sigma"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-03-02T20:08:36.599907Z",
+"start_time": "2024-03-02T20:08:36.584853Z"
+}
+},
+"id": "23fac866bcd656ee",
+"execution_count": 4
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"detector = MMD(\n",
+"    kernel=partial(\n",
+"        rbf_kernel,\n",
+"        sigma=sigma,\n",
+"    ),\n",
+"    callbacks=PermutationTestDistanceBased(\n",
+"        num_permutations=100,\n",
+"        num_jobs=-1,\n",
+"        method=\"exact\",\n",
+"        random_state=seed,\n",
+"        name=\"permutation_test\",\n",
+"    ),\n",
+")\n",
+"\n",
+"_ = detector.fit(\n",
+"    X=X_ref,\n",
+")"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-03-02T20:08:36.615923Z",
+"start_time": "2024-03-02T20:08:36.603076Z"
+}
+},
+"id": "3bf7b070454ba708",
+"execution_count": 5
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Compare datasets before saving\n",
+"\n",
+"We will compare the reference and test datasets."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "ca0bee617c055e14"
+},
+{
+"cell_type": "code",
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Distance: 0.14644993, p-value: 0.00990049\n"
+]
+}
+],
+"source": [
+"distance, callback_logs = detector.compare(\n",
+"    X=X_test,\n",
+")\n",
+"before_save_distance = distance.distance\n",
+"before_save_p_value = callback_logs['permutation_test']['p_value']\n",
+"print(f\"Distance: {before_save_distance:.8f}, p-value: {before_save_p_value:.8f}\")"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-03-02T20:08:39.021802Z",
+"start_time": "2024-03-02T20:08:36.616944Z"
+}
+},
+"id": "c1f670b30658a751",
+"execution_count": 6
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Save and Load detector\n",
+"\n",
+"We will save the detector to a file and load it back."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "4dad43da2f94c1ec"
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"save(\n",
+"    obj=detector,\n",
+"    filename=\"detector.pkl\",\n",
+")\n",
+"\n",
+"detector = load(\n",
+"    filename=\"detector.pkl\",\n",
+")"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-03-02T20:08:39.037744Z",
+"start_time": "2024-03-02T20:08:39.024229Z"
+}
+},
+"id": "d0aa212a9e91de5c",
+"execution_count": 7
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Compare datasets after loading\n",
+"\n",
+"We will compare the reference and test datasets again."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "97d354f3aaf7f555"
+},
+{
+"cell_type": "code",
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Distance: 0.14644993, p-value: 0.00990049\n"
+]
+}
+],
+"source": [
+"distance, callback_logs = detector.compare(\n",
+"    X=X_test,\n",
+")\n",
+"after_save_distance = distance.distance\n",
+"after_save_p_value = callback_logs['permutation_test']['p_value']\n",
+"print(f\"Distance: {after_save_distance:.8f}, p-value: {after_save_p_value:.8f}\")"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-03-02T20:08:41.628646Z",
+"start_time": "2024-03-02T20:08:39.038798Z"
+}
+},
+"id": "a681537ba868af6b",
+"execution_count": 8
+},
+{
+"cell_type": "markdown",
+"source": [
+"Assert that the distance and p-value are the same before and after saving and loading the detector."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "3a81841ec13cc881"
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"assert before_save_distance == after_save_distance\n",
+"assert before_save_p_value == after_save_p_value"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-03-02T20:08:41.644471Z",
+"start_time": "2024-03-02T20:08:41.629678Z"
+}
+},
+"id": "1a7e98cb985f2e5b",
+"execution_count": 9
+}
+],
+"metadata": {
+"kernelspec": {
+"display_name": "Python 3",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 2
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython2",
+"version": "2.7.6"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 5
+}
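
The notebook's `sigma` cell applies the median heuristic: the kernel bandwidth is set to the median of all pairwise Euclidean distances in the reference data. A dependency-free sketch of that computation on three hand-picked points follows. `median_pairwise_distance` and `gaussian_rbf` are illustrative names, and the kernel uses the common Gaussian form exp(-||x - y||^2 / (2 sigma^2)), which is assumed here rather than taken from the `rbf_kernel` source.

```python
import math
from itertools import combinations
from statistics import median


def median_pairwise_distance(points):
    """Median of the Euclidean distances over all unordered pairs of points."""
    distances = [math.dist(p, q) for p, q in combinations(points, 2)]
    return median(distances)


def gaussian_rbf(x, y, sigma):
    """Gaussian RBF kernel value for two points (assumed form, see lead-in)."""
    squared_distance = math.dist(x, y) ** 2
    return math.exp(-squared_distance / (2.0 * sigma**2))


# Pairwise distances of these three points are 5, 4 and 3.
points = [(0.0, 0.0), (3.0, 4.0), (0.0, 4.0)]

sigma = median_pairwise_distance(points)
similarity = gaussian_rbf(points[0], points[2], sigma)

print(sigma)  # 4.0
print(similarity)
```

Tying the bandwidth to the data scale in this way keeps kernel values from saturating near 0 or 1 for typical pairs, which is the usual motivation for the heuristic.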
