Skip to content

Commit 0243788

Browse files
authored
Merge pull request #50564 from ShawnKupfer/WB1796
AB#1016252: Fix graphing code in Confusion matrix and data imbalances
2 parents f813783 + a18c670 commit 0243788

File tree

3 files changed

+253
-186
lines changed

3 files changed

+253
-186
lines changed

learn-pr/azure/machine-learning-confusion-matrix/notebooks/8-3-exercise-build-matrix.ipynb

Lines changed: 74 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,19 @@
5757
"metadata": {},
5858
"outputs": [],
5959
"source": [
60-
"import graphing # custom graphing code. See our GitHub repo for details\n",
60+
"import matplotlib.pyplot as plt\n",
6161
"\n",
62-
"# Plot a histogram with counts for each label\n",
63-
"graphing.multiple_histogram(dataset, label_x=\"label\", label_group=\"label\", title=\"Label distribution\")"
62+
"# Count the number of occurrences of each label\n",
63+
"label_counts = dataset['label'].value_counts()\n",
64+
"\n",
65+
"# Plot the histogram\n",
66+
"label_counts.plot(kind='bar')\n",
67+
"\n",
68+
"plt.title(\"Label distribution\")\n",
69+
"plt.xlabel(\"Label\")\n",
70+
"plt.ylabel(\"Count\")\n",
71+
"plt.tight_layout()\n",
72+
"plt.show()"
6473
]
6574
},
6675
{
@@ -86,7 +95,16 @@
8695
"outputs": [],
8796
"source": [
8897
"# Plot a histogram with counts for each label\n",
89-
"graphing.multiple_histogram(dataset, label_x=\"color\", label_group=\"color\", title=\"Color distribution\")"
98+
"color_counts = dataset['color'].value_counts()\n",
99+
"\n",
100+
"# Plot the histogram\n",
101+
"color_counts.plot(kind='bar', color='skyblue', edgecolor='black')\n",
102+
"\n",
103+
"plt.title(\"Color distribution\")\n",
104+
"plt.xlabel(\"Color\")\n",
105+
"plt.ylabel(\"Count\")\n",
106+
"plt.tight_layout()\n",
107+
"plt.show()"
90108
]
91109
},
92110
{
@@ -110,7 +128,12 @@
110128
"metadata": {},
111129
"outputs": [],
112130
"source": [
113-
"graphing.box_and_whisker(dataset, label_y=\"size\", title='Boxplot of \"size\" feature')"
131+
"plt.boxplot(dataset['size'].dropna(), vert=True)\n",
132+
"\n",
133+
"plt.title('Boxplot of \"size\" feature')\n",
134+
"plt.ylabel('Size')\n",
135+
"plt.tight_layout()\n",
136+
"plt.show()"
114137
]
115138
},
116139
{
@@ -131,7 +154,12 @@
131154
"metadata": {},
132155
"outputs": [],
133156
"source": [
134-
"graphing.box_and_whisker(dataset, label_y=\"roughness\", title='Boxplot of \"roughness\" feature')"
157+
"plt.boxplot(dataset['roughness'].dropna(), vert=True)\n",
158+
"\n",
159+
"plt.title('Boxplot of \"roughness\" feature')\n",
160+
"plt.ylabel('Roughness')\n",
161+
"plt.tight_layout()\n",
162+
"plt.show()"
135163
]
136164
},
137165
{
@@ -151,7 +179,12 @@
151179
"metadata": {},
152180
"outputs": [],
153181
"source": [
154-
"graphing.box_and_whisker(dataset, label_y=\"motion\", title='Boxplot of \"motion\" feature')"
182+
"plt.boxplot(dataset['motion'].dropna(), vert=True)\n",
183+
"\n",
184+
"plt.title('Boxplot of \"motion\" feature')\n",
185+
"plt.ylabel('Motion')\n",
186+
"plt.tight_layout()\n",
187+
"plt.show()"
155188
]
156189
},
157190
{
@@ -325,41 +358,40 @@
325358
"metadata": {},
326359
"outputs": [],
327360
"source": [
328-
"# We use plotly to create plots and charts\n",
329-
"import plotly.figure_factory as ff\n",
330-
"\n",
331-
"# Create the list of unique labels in the test set, to use in our plot\n",
332-
"# I.e., ['animal', 'hiker', 'rock', 'tree']\n",
333-
"x = y = sorted(list(test[\"label\"].unique()))\n",
334-
"\n",
335-
"# Plot the matrix above as a heatmap with annotations (values) in its cells\n",
336-
"fig = ff.create_annotated_heatmap(cm, x, y)\n",
337-
"\n",
338-
"# Set titles and ordering\n",
339-
"fig.update_layout( title_text=\"<b>Confusion matrix</b>\", \n",
340-
" yaxis = dict(categoryorder = \"category descending\"))\n",
341-
"\n",
342-
"fig.add_annotation(dict(font=dict(color=\"black\",size=14),\n",
343-
" x=0.5,\n",
344-
" y=-0.15,\n",
345-
" showarrow=False,\n",
346-
" text=\"Predicted label\",\n",
347-
" xref=\"paper\",\n",
348-
" yref=\"paper\"))\n",
349-
"\n",
350-
"fig.add_annotation(dict(font=dict(color=\"black\",size=14),\n",
351-
" x=-0.15,\n",
352-
" y=0.5,\n",
353-
" showarrow=False,\n",
354-
" text=\"Actual label\",\n",
355-
" textangle=-90,\n",
356-
" xref=\"paper\",\n",
357-
" yref=\"paper\"))\n",
358-
"\n",
359-
"# We need margins so the titles fit\n",
360-
"fig.update_layout(margin=dict(t=80, r=20, l=100, b=50))\n",
361-
"fig['data'][0]['showscale'] = True\n",
362-
"fig.show()"
361+
"import numpy as np\n",
362+
"\n",
363+
"# Create sorted list of unique labels\n",
364+
"labels = sorted(list(test[\"label\"].unique()))\n",
365+
"\n",
366+
"fig, ax = plt.subplots(figsize=(8, 6))\n",
367+
"\n",
368+
"# Show the confusion matrix\n",
369+
"cax = ax.imshow(cm, interpolation='nearest', cmap='Blues')\n",
370+
"\n",
371+
"# Add colorbar\n",
372+
"fig.colorbar(cax)\n",
373+
"\n",
374+
"# Annotate each cell with the numeric value\n",
375+
"for i in range(len(cm)):\n",
376+
" for j in range(len(cm[i])):\n",
377+
" ax.text(j, i, format(cm[i][j], 'd'),\n",
378+
" ha=\"center\", va=\"center\",\n",
379+
" color=\"black\")\n",
380+
"\n",
381+
"# Set axis labels and ticks\n",
382+
"ax.set_xticks(np.arange(len(labels)))\n",
383+
"ax.set_yticks(np.arange(len(labels)))\n",
384+
"ax.set_xticklabels(labels)\n",
385+
"ax.set_yticklabels(labels)\n",
386+
"\n",
387+
"# Set title and axis labels\n",
388+
"plt.title(\"Confusion matrix\", fontsize=14, fontweight='bold')\n",
389+
"plt.xlabel(\"Predicted label\", fontsize=12)\n",
390+
"plt.ylabel(\"Actual label\", fontsize=12)\n",
391+
"\n",
392+
"# Adjust layout to fit labels\n",
393+
"plt.tight_layout()\n",
394+
"plt.show()"
363395
]
364396
},
365397
{

0 commit comments

Comments
 (0)