
Commit 02f0745

Update IOW notebook with basic model training
1 parent 16dbc6b commit 02f0745

File tree: 04_internet_of_wands/internet_of_wands.ipynb

1 file changed: 270 additions & 13 deletions
@@ -198,7 +198,7 @@
 },
 "outputs": [],
 "source": [
-"!wget \"https://github.com/pablodecm/datalab_ml_iot/raw/master/04_internet_of_wands/iow_data.zip\"; unzip -o iow_data.zip"
+"!wget \"https://github.com/pablodecm/datalab_ml_iot/raw/master/04_internet_of_wands/iow_data.zip\"; unzip -qq -o iow_data.zip"
 ]
 },
 {
@@ -235,7 +235,7 @@
 "import matplotlib.pyplot as plt\n",
 "import warnings\n",
 "from sklearn.model_selection import train_test_split\n",
-"#warnings.filterwarnings(\"ignore\")\n",
+"warnings.filterwarnings(\"ignore\")\n",
 "\n",
 "example_file = \"iow_data/wingardium-leviosa/Peppapig_9b2bd7a9.0696f8.json\"\n",
 "\n",
@@ -331,6 +331,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"# we can now plot the resampled sensor data\n",
 "fig, axs = plt.subplots(2, figsize=(12,12))\n",
 "\n",
 "merged_df.filter(regex=\"._accel\").plot(ax=axs[0])\n",
@@ -401,6 +402,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"# and apply it to all the data\n",
 "md_fields = [\"spell_select\",\"device_select\",\"wizard_name\"]\n",
 "data_path = Path(\"./iow_data\")\n",
 "\n",
@@ -424,6 +426,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
+"# this is a dataframe with a multi-index (spell_select, device_select, wizard_name, spell_id, timestamp)\n",
 "all_df = pd.concat(merged_df_dict, names = (md_fields+ [\"spell_id\"]))\n",
 "all_df"
 ]
@@ -445,7 +448,8 @@
445448
"\n",
446449
"var_name = \"y_accel\"\n",
447450
"wizard_name = \"pablodecm\"\n",
448-
"spell_name = \"reparo\"\n",
451+
"spell_name = \"alohomora\"\n",
452+
"\n",
449453
"subset_df = all_df.loc[(spell_name,slice(None),wizard_name),var_name]\n",
450454
"\n",
451455
"\n",
@@ -518,7 +522,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# further exploratory data analysis"
+"# feel free to carry out further exploratory data analysis"
 ]
 },
 {
@@ -534,7 +538,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"\n",
+"# clean up the data a little\n",
 "# get only spells that last more than 400 ms\n",
 "long_spells = all_df.groupby(\"spell_id\").count() > 20\n",
 "long_spells = long_spells[long_spells].dropna().index"
@@ -552,6 +556,24 @@
 "valid_df = all_df.loc[(slice(None),slice(None),slice(None),list(valid_subset)),:]"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"train_df"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"valid_df"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -577,8 +599,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# use the training dataframe to create a model\n",
-"train_df"
+"# a single row (spell) is characterized by 6 sensor recordings and a variable number of timesteps\n",
+"valid_df.loc[(slice(None),slice(None),slice(None),\"lumos/Voldemort_525eda85.587874\")].reset_index(drop=True).plot()"
 ]
 },
 {
@@ -587,16 +609,251 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# evaluate the model in the validation dataset\n"
+"# these are the raw features; however, the number of timesteps per spell is variable\n",
+"features = [ f\"{i}_accel\" for i in [\"x\", \"y\",\"z\"]] + [ f\"{i}_gyro\" for i in [\"x\", \"y\",\"z\"]]\n",
+"features"
 ]
 },
 {
-"cell_type": "markdown",
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# as a first fixed-length representation, we can use the mean of each sensor channel as a feature\n",
+"train_df_mean = train_df.loc[(slice(None),slice(None),slice(None),slice(None)),:].groupby(\"spell_id\").mean().loc[:, features]\n",
+"train_df_mean.columns = [f\"{f}_mean\" for f in features]\n",
+"valid_df_mean = valid_df.loc[(slice(None),slice(None),slice(None),slice(None)),:].groupby(\"spell_id\").mean().loc[:, features]\n",
+"valid_df_mean.columns = [f\"{f}_mean\" for f in features]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# we can optionally complement them with the standard deviation of each channel\n",
+"train_df_std = train_df.loc[(slice(None),slice(None),slice(None),slice(None)),:].groupby(\"spell_id\").std().loc[:, features]\n",
+"train_df_std.columns = [f\"{f}_std\" for f in features]\n",
+"valid_df_std = valid_df.loc[(slice(None),slice(None),slice(None),slice(None)),:].groupby(\"spell_id\").std().loc[:, features]\n",
+"valid_df_std.columns = [f\"{f}_std\" for f in features]\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# example of features\n",
+"train_df_std.head()"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# we can create a simplified training dataset by concatenating both\n",
+"train_df_extra = pd.concat([train_df_mean,train_df_std], axis=1)\n",
+"valid_df_extra = pd.concat([valid_df_mean,valid_df_std], axis=1)"
+]
+},
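+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# as a quick sanity check, we could confirm the shapes: one row per spell, 12 summary features\n",
+"train_df_extra.shape, valid_df_extra.shape"
+]
+},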
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# the training data will look like this\n",
+"train_df_extra.head()"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# we can also obtain the category label from the spell_id index (its prefix is the spell name)\n",
+"label_assign = { \"alohomora\" : 0, \"lumos\" : 1, \"wingardium-leviosa\" : 2, \"reparo\" : 3}\n",
+"train_y = train_df_extra.reset_index().spell_id.str.split(\"/\").apply(lambda x: label_assign[x[0]])\n",
+"valid_y = valid_df_extra.reset_index().spell_id.str.split(\"/\").apply(lambda x: label_assign[x[0]])"
+]
+},
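+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# we could also check how balanced the classes are before training\n",
+"train_y.value_counts()"
+]
+},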
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from sklearn.model_selection import KFold, GridSearchCV\n",
+"from sklearn.ensemble import GradientBoostingClassifier\n",
+"\n",
+"# a gradient boosting classifier with default hyper-parameters\n",
+"gb_clf = GradientBoostingClassifier()\n"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# we can train the classifier\n",
+"gb_clf.fit(train_df_extra, train_y)"
+]
+},
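+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# once fitted, we could inspect which summary features the model relies on most\n",
+"pd.Series(gb_clf.feature_importances_, index=train_df_extra.columns).sort_values(ascending=False)"
+]
+},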
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from sklearn.metrics import roc_auc_score, accuracy_score, precision_score, recall_score, confusion_matrix\n",
+"from sklearn.metrics import classification_report"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# we can get probability predictions (one column per class)\n",
+"gb_clf.predict_proba(train_df_extra)"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# and compute some metrics on the training dataset (not to be trusted)\n",
+"y_train_clf_proba = gb_clf.predict_proba(train_df_extra)\n",
+"y_train_clf_pred = gb_clf.predict(train_df_extra)\n",
+"\n",
+"print(\"Confusion Matrix:\")\n",
+"print(confusion_matrix(train_y,y_train_clf_pred))\n",
+"print(\"Gradient Boosting Classifier Accuracy: \"+\"{:.1%}\".format(accuracy_score(train_y,y_train_clf_pred)));\n",
+"print(\"Classification Report:\")\n",
+"print(classification_report(train_y,y_train_clf_pred))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# the validation set gives a more honest estimate of performance\n",
+"y_valid_clf_proba = gb_clf.predict_proba(valid_df_extra)\n",
+"y_valid_clf_pred = gb_clf.predict(valid_df_extra)\n",
+"\n",
+"print(\"Confusion Matrix:\")\n",
+"print(confusion_matrix(valid_y,y_valid_clf_pred))\n",
+"print(\"Gradient Boosting Classifier Accuracy: \"+\"{:.1%}\".format(accuracy_score(valid_y,y_valid_clf_pred)));\n",
+"print(\"Classification Report:\")\n",
+"print(classification_report(valid_y,y_valid_clf_pred))"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# we can run a grid search to look for a better hyper-parameter combination\n",
+"from sklearn.model_selection import KFold, GridSearchCV\n",
+"from sklearn.ensemble import GradientBoostingClassifier\n",
+"\n",
+"gb_clf = GradientBoostingClassifier()\n",
+"\n",
+"# simple 3-fold cross-validation over the training set\n",
+"cv = KFold(3)\n",
+"\n",
+"param_grid = { \"n_estimators\" : [100, 130, 150, 180, 200],\n",
+"               \"learning_rate\" : [0.05, 0.07, 0.1]\n",
+"             }\n",
+"\n",
+"optimized_gb_clf = GridSearchCV(estimator=gb_clf,\n",
+"                                cv = cv,\n",
+"                                param_grid=param_grid,\n",
+"                                verbose = 1,\n",
+"                                n_jobs = -1)\n",
+"\n",
+"# fit the grid search; the best model is then refit on the full training set\n",
+"optimized_gb_clf.fit(train_df_extra, train_y)"
+]
+},
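+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# we could inspect the best hyper-parameter combination and its cross-validated score\n",
+"optimized_gb_clf.best_params_, optimized_gb_clf.best_score_"
+]
+},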
+{
+"cell_type": "code",
+"execution_count": null,
 "metadata": {},
+"outputs": [],
 "source": [
-"## Evaluate the Performance in the Test Dataset\n",
+"# compute the metrics on the validation set again, now for the optimized model\n",
+"y_valid_clf_proba = optimized_gb_clf.predict_proba(valid_df_extra)\n",
+"y_valid_clf_pred = optimized_gb_clf.predict(valid_df_extra)\n",
 "\n",
-"Finally, we can evaluate the final performance on our holdout test dataset."
+"print(\"Confusion Matrix:\")\n",
+"print(confusion_matrix(valid_y,y_valid_clf_pred))\n",
+"print(\"Gradient Boosting Classifier Accuracy: \"+\"{:.1%}\".format(accuracy_score(valid_y,y_valid_clf_pred)));\n",
+"print(\"Classification Report:\")\n",
+"print(classification_report(valid_y,y_valid_clf_pred))"
 ]
 },
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"y_valid_clf_pred"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# boolean mask of the misclassified validation spells\n",
+"valid_y != y_valid_clf_pred"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# it is possible that some of the mistakes are due to\n",
+"# incorrect data, so we could explore some of the misclassified spells\n",
+"wrongly_classified_ids = (valid_df_extra.loc[(valid_y != y_valid_clf_pred).values]).index\n",
+"wrongly_classified_ids"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# we could check a few of the incorrectly classified spells to see what their recordings look like\n",
+"spell_id = 'reparo/Serious_4d7e9d0d.fa4574'\n",
+"valid_df.loc[(slice(None),slice(None),slice(None), spell_id)].reset_index(drop=True).plot()"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"# feel free to carry out additional work and/or train additional models"
 ]
 },
 {
@@ -638,7 +895,7 @@
 "metadata": {
 "celltoolbar": "Slideshow",
 "kernelspec": {
-"display_name": "Python 3",
+"display_name": "Python 3 (ipykernel)",
 "language": "python",
 "name": "python3"
 },
@@ -652,7 +909,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.8.8"
+"version": "3.8.12"
 },
 "rise": {
 "theme": "black"
