generated from amazon-archives/__template_DevGuide
-
Notifications
You must be signed in to change notification settings - Fork 182
Open
Labels
Description
I am trying to train a ResNet model from scratch on Trainium and received an unhandled exception with message: allocated memory out of bound during the compilation. I have attached the Jupyter Notebook code that was running along with the log file from Trainium.
# Import Required Libraries
from typing import Optional
import os
use_trainium = True
wandb_key: Optional[str] # declare that this python variables is a string or None
wandb_key = os.environ.get('WANDB_API_KEY')
if wandb_key is not None:
print("weights and biases API key obtained from environment")
from fastai.data.all import *
from fastai.vision.all import *
import numpy as np
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
import random
from PIL import Image
from skimage import io
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, Dataset
import torch.optim as optim
import torchvision.datasets
import torchvision.transforms as T
from torchvision.io import read_image
from torchvision.datasets import DatasetFolder
from torchvision.datasets.folder import default_loader
import wandb
from sklearn.metrics import confusion_matrix
if use_trainium:
import torch_xla.core.xla_model as xm
# Trainium or Google TPU
device = xm.xla_device()
else:
import torch.backends.cudnn as cudnn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cudnn.benchmark = True
print(f"Using {device} device")
wandb_login = False
if wandb_key is not None:
wandb_login = wandb.login(key=wandb_key, relogin=True, verify=True)
else:
wandb_login = wandb.login(relogin=True, verify=True) # prompting login
if wandb_login is True:
print("Weights and Biases login successful")
else:
print("Weights and Biases login failed")
# Loading The Data
meta_df = pd.read_csv('data/traffic-signs/Meta.csv')
train_df = pd.read_csv('data/traffic-signs/Train.csv')
test_df = pd.read_csv('data/traffic-signs/Test.csv')
data_path = 'data/traffic-signs/'
train_data_path = os.path.join(data_path, 'Train')
test_data_path = os.path.join(data_path, 'Test')
meta_data_path = os.path.join(data_path, 'Meta')
# Exploring The Data
#exploring Meta.csv file
print("")
print("------------------------------------------------")
print(meta_df.head())
print("------------------------------------------------")
print("number of classes in the dataset:",meta_df.ClassId.nunique())
print("number of Shape Ids in the dataset:",meta_df.ShapeId.nunique())
print("number of Color Ids in the dataset:",meta_df.ColorId.nunique())
print("number of Sign Ids in the dataset:",meta_df.SignId.nunique())
print("")
## Plotting the original Traffic signs
signs = [os.path.join(data_path, meta_df.Path.to_list()[i]) for i in range(43)]
fig, axes = plt.subplots(11, 4, figsize=(15, 10))
for i, image_path in enumerate(signs):
image = Image.open(image_path)
row = i // 4
col = i % 4
axes[row, col].imshow(image)
axes[row, col].axis('off')
plt.show();
#exploring Train.csv file
print("")
print("------------------------------------------------")
print(train_df.head())
print("------------------------------------------------")
print("number of Training Samples in the dataset:",train_df.shape[0])
print("number of Test Samples in the dataset:",test_df.shape[0])
print("number of Classes in the dataset:",train_df["ClassId"].nunique())
print("The Maximum Width:",train_df["Width"].max())
print("The Maximum Height:",train_df["Height"].max())
## The Distribution of the Class labels in the dataset
classes = train_df["ClassId"].value_counts().head(43)
plt.figure(figsize=(12,6))
plt.title("Distribution of Class Labels in the dataset")
plt.ylabel('Counts')
plt.xlabel('Classes')
sns.barplot(y=classes.values, x=classes.index,color='g');
## Plotting samples of the Traffic signs
data_path = 'data/traffic-signs/'
train_data_path = os.path.join(data_path, 'Train')
valid_data_path = os.path.join(data_path, 'Test')
folder_names = [os.path.join(train_data_path, str(i)) for i in random.choices(range(43), k=20)]
file_names = [os.path.join(fldr, os.listdir(fldr)[0]) for fldr in folder_names]
fig, axes = plt.subplots(4, 5, figsize=(12, 8))
for i, image_path in enumerate(file_names):
image = Image.open(image_path)
row = i // 5
col = i % 5
axes[row, col].imshow(image)
axes[row, col].axis('off')
plt.show();
# Building Custom Dataset for Traffic signs
# Transforming the Data ToTensor and Normalize it
transforms = T.Compose([T.ToTensor(),T.Resize((225,225)),
T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
class TSignsDataset(Dataset):
def __init__(self, df, root_dir,transform=None):
self.df = df
self.root_dir = root_dir
self.transform = transform
def __len__(self):
return len(self.df)
def __getitem__(self,index):
image_path = os.path.join(self.root_dir,self.df.iloc[index,7]) #the column of paths in df is 7
image = Image.open(image_path)
y_class = torch.tensor(self.df.iloc[index, 6]) #the column of ClsassId in df is 6
if self.transform:
image = self.transform(image)
return (image, y_class)
training_set = TSignsDataset(train_df,data_path,transform=transforms)
validation_set = TSignsDataset(test_df,data_path,transform=transforms)
# Loading The data into DataLoaders
#Loading the data into DataLoader
train_loader = DataLoader(dataset=training_set, batch_size=32, shuffle=True)
valid_loader = DataLoader(dataset=validation_set, batch_size=32, shuffle=False)
dataloaders = {'training':train_loader,'validation':valid_loader}
dataset_sizes = {'training':len(train_loader.dataset),'validation':len(valid_loader.dataset)}
print(dataset_sizes)
# Building The ResNet Model from scratch
### Generic Residual block
class block(nn.Module):
def __init__(
self, in_channels, out_channels, identity_downsample=None, stride=1):
super().__init__()
self.expansion = 4
self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=1,stride=1,padding=0,bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=stride,padding=1,bias=False,)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(out_channels,out_channels * self.expansion,kernel_size=1,stride=1,padding=0,bias=False,)
self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
self.relu = nn.ReLU()
self.identity_downsample = identity_downsample
self.stride = stride
def forward(self, x):
identity = x.clone()
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn3(x)
if self.identity_downsample is not None:
identity = self.identity_downsample(identity)
x += identity
x = self.relu(x)
return x
### Generic implementation of ResNet Class
class ResNet(nn.Module):
def __init__(self, block, layers, image_channels, num_classes):
super(ResNet, self).__init__()
self.in_channels = 64
self.conv1 = nn.Conv2d(image_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# Essentially the entire ResNet architecture are in these 4 lines below
self.layer1 = self._make_layer(block, layers[0], out_channels=64, stride=1)
self.layer2 = self._make_layer(block, layers[1], out_channels=128, stride=2)
self.layer3 = self._make_layer(block, layers[2], out_channels=256, stride=2)
self.layer4 = self._make_layer(block, layers[3], out_channels=512, stride=2)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * 4, num_classes)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = x.reshape(x.shape[0], -1)
x = self.fc(x)
return x
def _make_layer(self, block, num_residual_blocks, out_channels, stride):
identity_downsample = None
layers = []
# Either if we half the input space for ex, 56x56 -> 28x28 (stride=2), or channels changes
# we need to adapt the Identity (skip connection) so it will be able to be added
# to the layer that's ahead
if stride != 1 or self.in_channels != out_channels * 4:
identity_downsample = nn.Sequential(nn.Conv2d(self.in_channels,out_channels * 4,kernel_size=1,stride=stride,bias=False)
,nn.BatchNorm2d(out_channels * 4))
layers.append(block(self.in_channels, out_channels, identity_downsample, stride))
# The expansion size is always 4 for ResNet 50,101,152
self.in_channels = out_channels * 4
# For example for first resnet layer: 256 will be mapped to 64 as intermediate layer,
# then finally back to 256. Hence no identity downsample is needed, since stride = 1,
# and also same amount of channels.
for i in range(num_residual_blocks - 1):
layers.append(block(self.in_channels, out_channels))
return nn.Sequential(*layers)
### The ResNet: 3 levels of depth
def ResNet50(img_channel=3, num_classes=1000):
return ResNet(block, [3, 4, 6, 3], img_channel, num_classes)
def ResNet101(img_channel=3, num_classes=1000):
return ResNet(block, [3, 4, 23, 3], img_channel, num_classes)
def ResNet152(img_channel=3, num_classes=1000):
return ResNet(block, [3, 8, 36, 3], img_channel, num_classes)
# Training The model
def Train(model,criterion,optimizer,num_epochs,batch_size,dataloaders):
best_model_weights = model.state_dict()
best_acc = 0.0
for epoch in range(num_epochs):
print("epoch {}/{}".format(epoch+1,num_epochs))
print("*" * 10)
for x in ["training","validation"]:
if x == "training" :
model.train()
else:
model.eval()
running_loss = 0.0
running_accuracy = 0
for data in dataloaders[x]:
img , y = data
img , y = img.to(device) , y.to(device)
optimizer.zero_grad()
y_pred = model(img)
loss = criterion(y_pred,y)
_, preds = torch.max(y_pred, dim=1)
if x == 'training':
loss.backward()
optimizer.step()
if use_trainium:
xm.mark_step()
#print("marked step for XM device")
running_loss += loss.item()
running_accuracy += torch.sum(preds == y.data)
epoch_loss = running_loss / dataset_sizes[x]
epoch_acc = running_accuracy / dataset_sizes[x]
print('{} Loss: {:.4f} || Accuracy: {:.4f}'.format(x, epoch_loss, epoch_acc))
if wandb_login:
if x == 'training':
wandb.log({"train_accuracy": epoch_acc, "train_loss": epoch_loss})
else:
wandb.log({"accuracy": epoch_acc, "loss": epoch_loss})
# deep copy the model
if x == 'validation' and epoch_acc > best_acc:
best_acc = epoch_acc
# load best model weights
return print('Best validation Accuracy: {:4f}'.format(best_acc))
num_epochs = 10
if wandb_login:
#torchexplorer.setup() # gives us additional visualization
# start a new wandb run to track this script
wandb.init(
# set the wandb project where this run will be logged
project="signs-from-scratch",
# track hyperparameters and run metadata
config={
"architecture": "resnet18",
"dataset": "hymenoptera_data",
"epochs": num_epochs,
}
)
model = ResNet50(img_channel=3, num_classes=43).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),lr = 0.001)
batch_size = 32
#train the model
Train(model,criterion,optimizer,num_epochs,batch_size,dataloaders)
if wandb_login:
wandb.finish(exit_code=0)
# Model evaluation
learn = vision_learner(dataloaders, model)
interp = ClassificationInterpretation.from_learner(learn)
interp.plot_confusion_matrix(figsize=(12,12), dpi=60)
# References
* Deep Residual Learning for Image Recognition: https://arxiv.org/abs/1512.03385
* ResNet Explained :https://www.analyticsvidhya.com/blog/2023/02/deep-residual-learning-for-image-recognition-resnet-explained/
* Pytorch ResNet implementation from Scratch: https://www.youtube.com/watch?v=DkNIBBBvcPs