Skip to content

Commit 158f061

Browse files
Programmer-RD-AIProgrammer-RD-AI
andcommitted
new help_funcs
Co-Authored-By: Ranuga <[email protected]>
1 parent 42682f1 commit 158f061

File tree

2 files changed

+294
-0
lines changed

2 files changed

+294
-0
lines changed
8.01 KB
Binary file not shown.

helper_functions.py

Lines changed: 294 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,294 @@
1+
"""
2+
A series of helper functions used throughout the course.
3+
4+
If a function gets defined once and could be used over and over, it'll go in here.
5+
"""
6+
import torch
7+
import matplotlib.pyplot as plt
8+
import numpy as np
9+
10+
from torch import nn
11+
12+
import os
13+
import zipfile
14+
15+
from pathlib import Path
16+
17+
import requests
18+
19+
# Walk through an image classification directory and find out how many files (images)
20+
# are in each subdirectory.
21+
import os
22+
23+
def walk_through_dir(dir_path):
    """
    Walks through dir_path and reports what each directory contains.

    Args:
        dir_path (str): target directory

    Returns:
        Nothing. For dir_path and every directory below it, prints:
          number of subdirectories it contains
          number of images (files) it contains
          its path
    """
    for current_dir, subdirs, files in os.walk(dir_path):
        n_dirs = len(subdirs)
        n_files = len(files)
        print(f"There are {n_dirs} directories and {n_files} images in '{current_dir}'.")
37+
38+
def plot_decision_boundary(model: torch.nn.Module, X: torch.Tensor, y: torch.Tensor):
    """Plots decision boundaries of model predicting on X in comparison to y.

    Source - https://madewithml.com/courses/foundations/neural-networks/ (with modifications)

    Args:
        model (torch.nn.Module): model to visualise. It is called on a float
            tensor with one row per grid point and two columns. Note that it
            is moved to the CPU and switched to eval mode as a side effect.
        X (torch.Tensor): features; column 0 is used as the horizontal axis
            and column 1 as the vertical axis.
        y (torch.Tensor): labels, used to colour the scatter points and to
            help decide between multi-class and binary decoding.

    Returns:
        Nothing. Draws a filled contour of the predicted labels plus a scatter
        of X onto the current Matplotlib axes.
    """
    # Put everything to CPU (works better with NumPy + Matplotlib)
    model.to("cpu")
    X, y = X.to("cpu"), y.to("cpu")

    # Setup prediction boundaries and grid (0.1 margin around the data)
    x_min, x_max = X[:, 0].min() - 0.1, X[:, 0].max() + 0.1
    y_min, y_max = X[:, 1].min() - 0.1, X[:, 1].max() + 0.1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 101), np.linspace(y_min, y_max, 101))

    # Make features: one (x, y) row per grid point
    X_to_pred_on = torch.from_numpy(np.column_stack((xx.ravel(), yy.ravel()))).float()

    # Make predictions
    model.eval()
    with torch.inference_mode():
        y_logits = model(X_to_pred_on)

    # Test for multi-class or binary and adjust logits to prediction labels.
    # A model that emits more than one logit per point must be decoded with
    # argmax even when y only holds two classes; otherwise the sigmoid/round
    # path yields several values per point and the reshape below fails.
    has_multiple_logits = y_logits.ndim > 1 and y_logits.shape[1] > 1
    if has_multiple_logits or len(torch.unique(y)) > 2:
        y_pred = torch.softmax(y_logits, dim=1).argmax(dim=1)  # multi-class
    else:
        y_pred = torch.round(torch.sigmoid(y_logits))  # binary

    # Reshape preds to the grid and plot
    y_pred = y_pred.reshape(xx.shape).detach().numpy()
    plt.contourf(xx, yy, y_pred, cmap=plt.cm.RdYlBu, alpha=0.7)
    plt.scatter(X[:, 0], X[:, 1], c=y, s=40, cmap=plt.cm.RdYlBu)
    plt.xlim(xx.min(), xx.max())
    plt.ylim(yy.min(), yy.max())
72+
73+
74+
# Plot linear training data, test data and (optionally) predictions
75+
def plot_predictions(
    train_data, train_labels, test_data, test_labels, predictions=None
):
    """
    Scatter-plots training data and test data on a new figure and, when
    predictions are given, overlays them against the test data.
    """
    plt.figure(figsize=(10, 7))

    # (x values, y values, colour, legend label) for each series, in draw order:
    # training data in blue, test data in green
    series = [
        (train_data, train_labels, "b", "Training data"),
        (test_data, test_labels, "g", "Testing data"),
    ]
    if predictions is not None:
        # Predictions in red (they were made on the test data)
        series.append((test_data, predictions, "r", "Predictions"))

    for xs, ys, colour, name in series:
        plt.scatter(xs, ys, c=colour, s=4, label=name)

    # Show the legend
    plt.legend(prop={"size": 14})
95+
96+
97+
# Calculate accuracy (a classification metric)
98+
def accuracy_fn(y_true, y_pred):
    """Calculates accuracy between truth labels and predictions.

    Args:
        y_true (torch.Tensor): Truth labels for predictions.
        y_pred (torch.Tensor): Predictions to be compared to the truth labels.

    Returns:
        float: Accuracy as a percentage between y_true and y_pred, e.g. 78.45.
            Returns 0.0 when y_pred is empty instead of dividing by zero.
    """
    total = len(y_pred)
    if total == 0:
        # Nothing to score; avoid ZeroDivisionError on an empty batch.
        return 0.0
    correct = torch.eq(y_true, y_pred).sum().item()
    return (correct / total) * 100
111+
112+
113+
def print_train_time(start, end, device=None):
    """Prints and returns the elapsed time between start and end.

    Args:
        start (float): Start time of computation (preferred in timeit format).
        end (float): End time of computation.
        device ([type], optional): Device that compute is running on; only
            used in the printed message. Defaults to None.

    Returns:
        float: time between start and end in seconds (higher is longer).
    """
    elapsed = end - start
    message = f"\nTrain time on {device}: {elapsed:.3f} seconds"
    print(message)
    return elapsed
127+
128+
129+
# Plot loss curves of a model
130+
def plot_loss_curves(results):
    """Plots training curves of a results dictionary.

    Draws two side-by-side panels on a new figure: loss on the left and
    accuracy on the right, each with a train curve and a test curve.

    Args:
        results (dict): dictionary containing list of values, e.g.
            {"train_loss": [...],
             "train_acc": [...],
             "test_loss": [...],
             "test_acc": [...]}
    """
    train_loss, test_loss = results["train_loss"], results["test_loss"]
    train_acc, test_acc = results["train_acc"], results["test_acc"]

    # One x position per recorded training-loss value
    epochs = range(len(results["train_loss"]))

    plt.figure(figsize=(15, 7))

    # (panel title, (train values, label), (test values, label)), left to right
    panels = (
        ("Loss", (train_loss, "train_loss"), (test_loss, "test_loss")),
        ("Accuracy", (train_acc, "train_accuracy"), (test_acc, "test_accuracy")),
    )
    for position, (title, train_curve, test_curve) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for values, label in (train_curve, test_curve):
            plt.plot(epochs, values, label=label)
        plt.title(title)
        plt.xlabel("Epochs")
        plt.legend()
165+
166+
167+
# Pred and plot image function from notebook 04
168+
# See creation: https://www.learnpytorch.io/04_pytorch_custom_datasets/#113-putting-custom-image-prediction-together-building-a-function
169+
from typing import List
170+
import torchvision
171+
172+
173+
def pred_and_plot_image(
    model: torch.nn.Module,
    image_path: str,
    class_names: List[str] = None,
    transform=None,
    device: torch.device = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Makes a prediction on a target image with a trained model and plots the image.

    Args:
        model (torch.nn.Module): trained PyTorch image classification model.
            It is moved to `device` and switched to eval mode as a side effect.
        image_path (str): filepath to target image.
        class_names (List[str], optional): different class names for target image.
            When omitted (or empty) the predicted class index is shown instead.
            Defaults to None.
        transform (_type_, optional): transform of target image, applied after the
            pixel values have been scaled by 1/255. Defaults to None.
        device (torch.device, optional): target device to compute on.
            Defaults to "cuda" if torch.cuda.is_available() else "cpu"
            (evaluated once, when this module is imported).

    Returns:
        Nothing. Draws a Matplotlib plot of the target image with the model
        prediction and its probability as the title.

    Example usage:
        pred_and_plot_image(model=model,
                            image_path="some_image.jpeg",
                            class_names=["class_1", "class_2", "class_3"],
                            transform=torchvision.transforms.ToTensor(),
                            device=device)
    """

    # 1. Load in image and convert the tensor values to float32
    target_image = torchvision.io.read_image(str(image_path)).type(torch.float32)

    # 2. Divide the image pixel values by 255 to get them between [0, 1]
    #    (assumes 8-bit pixel values - TODO confirm for other image depths)
    target_image = target_image / 255.0

    # 3. Transform if necessary
    if transform:
        target_image = transform(target_image)

    # 4. Make sure the model is on the target device
    model.to(device)

    # 5. Turn on model evaluation mode and inference mode
    model.eval()
    with torch.inference_mode():
        # Add an extra (batch) dimension to the image
        target_image = target_image.unsqueeze(dim=0)

        # Make a prediction on image with an extra dimension and send it to the target device
        target_image_pred = model(target_image.to(device))

    # 6. Convert logits -> prediction probabilities (using torch.softmax() for multi-class classification)
    target_image_pred_probs = torch.softmax(target_image_pred, dim=1)

    # 7. Convert prediction probabilities -> prediction labels
    target_image_pred_label = torch.argmax(target_image_pred_probs, dim=1)

    # Plain Python values for indexing and formatting (the batch holds one image)
    pred_index = target_image_pred_label.item()
    pred_prob = target_image_pred_probs.max().item()

    # 8. Plot the image alongside the prediction and prediction probability.
    # Only the batch axis added above is squeezed: a bare squeeze() would also
    # drop a size-1 channel/height/width axis and make the permute below fail.
    plt.imshow(
        target_image.squeeze(dim=0).permute(1, 2, 0)
    )  # channels-last layout for matplotlib
    if class_names:
        title = f"Pred: {class_names[pred_index]} | Prob: {pred_prob:.3f}"
    else:
        # Show the class index itself rather than a tensor repr
        title = f"Pred: {pred_index} | Prob: {pred_prob:.3f}"
    plt.title(title)
    plt.axis(False)
238+
239+
def set_seeds(seed: int = 42):
    """Sets random seeds for torch operations.

    Args:
        seed (int, optional): Random seed to set. Defaults to 42.
    """
    # Seed general torch operations first, then CUDA torch operations
    # (ones that happen on the GPU).
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
249+
250+
def download_data(source: str,
                  destination: str,
                  remove_source: bool = True) -> Path:
    """Downloads a zipped dataset from source and unzips to destination.

    The archive is saved under "data/" and extracted into "data/<destination>".
    If "data/<destination>" already exists nothing is downloaded.

    Args:
        source (str): A link to a zipped file containing data.
        destination (str): A target directory (inside "data/") to unzip data to.
        remove_source (bool): Whether to remove the downloaded archive after
            extracting it.

    Returns:
        pathlib.Path to downloaded data.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.

    Example usage:
        download_data(source="https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip",
                      destination="pizza_steak_sushi")
    """
    # Setup path to data folder
    data_path = Path("data/")
    image_path = data_path / destination

    # If the image folder already exists there is nothing to do
    if image_path.is_dir():
        print(f"[INFO] {image_path} directory exists, skipping download.")
        return image_path

    print(f"[INFO] Did not find {image_path} directory, creating one...")

    target_file = Path(source).name
    archive_path = data_path / target_file

    # Download first and check the HTTP status before touching the disk, so an
    # error page is never saved as the archive.
    print(f"[INFO] Downloading {target_file} from {source}...")
    response = requests.get(source)
    response.raise_for_status()

    data_path.mkdir(parents=True, exist_ok=True)
    with open(archive_path, "wb") as f:
        f.write(response.content)

    # Unzip the data. The destination directory is only created once the
    # archive has been opened successfully; creating it earlier would make a
    # failed download look like a finished one on the next call.
    with zipfile.ZipFile(archive_path, "r") as zip_ref:
        print(f"[INFO] Unzipping {target_file} data...")
        image_path.mkdir(parents=True, exist_ok=True)
        zip_ref.extractall(image_path)

    # Remove .zip file
    if remove_source:
        os.remove(archive_path)

    return image_path

0 commit comments

Comments
 (0)