|
| 1 | +# Python data loader to generate CSV |
| 2 | + |
| 3 | +Here’s a Python data loader that uses logistic regression to classify penguin species from bill and body size measurements, then writes the results as CSV to standard output. |
| 4 | + |
| 5 | +```python |
| 6 | +import pandas as pd |
| 7 | +from sklearn.linear_model import LogisticRegression |
| 8 | +import sys |
| 9 | + |
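| 10 | +# Read the CSV and drop any rows with missing feature values, which the classifier cannot fit |
| 11 | +df = pd.read_csv("src/data/penguins.csv") |
| 12 | +df = df.dropna(subset=df.columns[2:6]) |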
| 13 | +# Select columns to train the model |
| 14 | +X = df.iloc[:, [2, 3, 4, 5]] |
| 15 | +Y = df.iloc[:, 0] |
| 16 | + |
| 17 | +# Create an instance of Logistic Regression Classifier and fit the data. |
| 18 | +logreg = LogisticRegression() |
| 19 | +logreg.fit(X, Y) |
| 20 | + |
| 21 | +results = df.copy() |
| 22 | +# Add predicted values |
| 23 | +results['species_predicted'] = logreg.predict(X) |
| 24 | + |
| 25 | +# Write CSV to standard output, omitting the pandas row index |
| 26 | +results.to_csv(sys.stdout, index=False) |
| 27 | +``` |
| 28 | + |
| 29 | +<div class="note"> |
| 30 | + |
| 31 | +To run this data loader, you’ll need python3 available on your `$PATH`, with the pandas and scikit-learn modules installed. We recommend setting up a virtual environment. |
| 32 | + |
| 33 | +</div> |
| 34 | + |
| 35 | +To create and activate a Python virtual environment, run the following commands: |
| 36 | + |
| 37 | +``` |
| 38 | +$ python3 -m venv .venv |
| 39 | +$ source .venv/bin/activate |
| 40 | +``` |
| 41 | + |
| 42 | +Then install the required modules from `requirements.txt` using: |
| 43 | + |
| 44 | +``` |
| 45 | +$ pip install -r requirements.txt |
| 46 | +``` |
| 47 | + |
| 48 | +The above data loader lives in `data/predictions.csv.py`, so we can load the data using `data/predictions.csv` with `FileAttachment`: |
| 49 | + |
| 50 | +```js echo |
| 51 | +const predictions = FileAttachment("data/predictions.csv").csv({typed: true}); |
| 52 | +``` |
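The naming convention above means that any script printing CSV to standard output can serve as the loader behind `data/predictions.csv`. As a minimal sketch using only the standard library (the `write_csv` helper and the sample rows are hypothetical and not part of the loader shown earlier):

```python
import csv
import sys


def write_csv(rows, fieldnames, out=None):
    """Write dict rows as CSV to a text stream (standard output by default)."""
    stream = out if out is not None else sys.stdout
    writer = csv.DictWriter(stream, fieldnames=fieldnames, lineterminator="\n")
    writer.writeheader()
    writer.writerows(rows)


# Hypothetical rows standing in for real model output.
sample_rows = [
    {"species": "Adelie", "species_predicted": "Adelie"},
    {"species": "Gentoo", "species_predicted": "Chinstrap"},
]

if __name__ == "__main__":
    write_csv(sample_rows, ["species", "species_predicted"])
```

Passing an explicit stream, as the optional `out` parameter allows, makes the output easy to inspect without running the whole site build.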
| 53 | + |
| 54 | +We can create a quick chart of predicted species, highlighting cases where penguins are misclassified, using Observable Plot: |
| 55 | + |
| 56 | +```js echo |
| 57 | +Plot.plot({ |
| 58 | + grid: true, |
| 59 | + height: 400, |
| 60 | + caption: "Incorrect predictions highlighted with diamonds. Actual species encoded with color and predicted species encoded with symbols.", |
| 61 | + color: { |
| 62 | + legend: true, |
| 63 | + }, |
| 64 | + x: {label: "Culmen length (mm)"}, |
| 65 | + y: {label: "Culmen depth (mm)"}, |
| 66 | + marks: [ |
| 67 | + Plot.dot(predictions, { |
| 68 | + x: "culmen_length_mm", |
| 69 | + y: "culmen_depth_mm", |
| 70 | + stroke: "species", |
| 71 | + symbol: "species_predicted", |
| 72 | + r: 3, |
| 73 | + tip: {channels: {"mass": "body_mass_g"}} |
| 74 | + }), |
| 75 | + Plot.dot(predictions, { |
| 76 | + filter: (d) => d.species !== d.species_predicted, |
| 77 | + x: "culmen_length_mm", |
| 78 | + y: "culmen_depth_mm", |
| 79 | + r: 7, |
| 80 | + symbol: "diamond", |
| 81 | + stroke: "currentColor" |
| 82 | + }) |
| 83 | + ], |
| 84 | +}) |
| 85 | +``` |
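The second `Plot.dot` mark keeps only the rows where `species` differs from `species_predicted`. The same check can be run on the loader's output outside the browser. The sketch below uses a few made-up rows in place of the real `predictions.csv`, and `misclassified` is a hypothetical helper name:

```python
import csv
import io


def misclassified(csv_text):
    """Return rows whose predicted species differs from the actual species."""
    reader = csv.DictReader(io.StringIO(csv_text))
    return [row for row in reader if row["species"] != row["species_predicted"]]


# Made-up sample in the shape the loader emits; not real model output.
sample = """species,culmen_length_mm,culmen_depth_mm,species_predicted
Adelie,39.1,18.7,Adelie
Chinstrap,46.5,17.9,Adelie
Gentoo,46.1,13.2,Gentoo
"""

wrong = misclassified(sample)
print(len(wrong), "of 3 sample rows misclassified")
```

To check real output, the text passed to `misclassified` could instead come from running the loader script and capturing its standard output.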