Description
After retraining the model with the code in this repository, I tried to use it to segment one of the images in the CrackForest dataset. The code is adapted from the DeepLabV3 example on PyTorch Hub:
```python
import time

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

# weights.pt holds a pickled model object (it is used directly as `model`), so
# recent PyTorch versions (2.6+) need weights_only=False to unpickle it.
model = torch.load('output/weights.pt', map_location='cpu', weights_only=False)
model.eval()

# The image is a local file, so open it directly instead of "downloading" a file:// URL
filename = "/home/rhobincu/gitroot/DeepLabv3FineTuning/CrackForest/Images/092.jpg"
input_image = Image.open(filename).convert('RGB')
print(input_image)

# NOTE: these transforms should match the ones used during training
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
input_tensor = preprocess(input_image)
input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

# move the input and model to GPU for speed if available
if torch.cuda.is_available():
    print('Using GPU!')
    input_batch = input_batch.to('cuda')
    model.to('cuda')

start = time.perf_counter()  # time.clock() was removed in Python 3.8
with torch.no_grad():
    output = model(input_batch)['out'][0]
if torch.cuda.is_available():
    torch.cuda.synchronize()  # wait for the GPU to finish before stopping the timer
end = time.perf_counter()
print('Inference duration (s): ', end - start)

output_predictions = output.argmax(0)

# create a color palette, selecting a color for each class
palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2])
colors = torch.as_tensor([i for i in range(2)])[:, None] * palette
colors = (colors % 255).numpy().astype("uint8")
print(colors)

# plot the semantic segmentation predictions of the 2 classes in each color
img_size = input_image.size
data = output_predictions.byte().cpu().numpy()
print(data)
print(data.sum())
# nearest-neighbour resampling keeps the mask values as class indices
r = Image.fromarray(data).resize(img_size, resample=Image.NEAREST)
r.putpalette(colors)
#cv2.imshow('image',input_image)
plt.imshow(r)
plt.show()
```

The problem is that the network's output (`data`) is all zeros. Any ideas?
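One cause worth ruling out: if the retrained classifier head has a single output channel (an assumption here; print `output.shape` to check), then `output.argmax(0)` runs over a dimension of size 1 and returns 0 at every pixel, no matter what the network predicts. The minimal sketch below assumes `output` is the `(C, H, W)` tensor from `model(input_batch)['out'][0]` and thresholds the single channel instead of taking the argmax. The helper name `logits_to_mask` and the 0.5 cutoff are hypothetical; the right threshold depends on the loss and target encoding used in training.

```python
import torch


def logits_to_mask(output: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
    """Turn a (C, H, W) segmentation output into an (H, W) uint8 class mask.

    With C >= 2 channels, argmax over the channel dimension picks the class.
    With C == 1, argmax over a size-1 dimension is always 0, so we threshold
    the single channel instead (threshold is an assumed, tunable value).
    """
    if output.dim() != 3:
        raise ValueError(f"expected (C, H, W), got {tuple(output.shape)}")
    if output.shape[0] == 1:
        return (output[0] > threshold).to(torch.uint8)
    return output.argmax(0).to(torch.uint8)


# Synthetic single-channel "prediction" for illustration (not real model output)
fake = torch.tensor([[[0.1, 0.9], [0.7, 0.2]]])  # shape (1, 2, 2)
print(fake.argmax(0))        # all zeros: argmax over a size-1 dimension
print(logits_to_mask(fake))  # 1 where the score exceeds the threshold, else 0
```

If this is the cause, replacing `output.argmax(0)` with `logits_to_mask(output)` in the script above should give a non-empty mask, and the existing two-color palette still applies because the mask values are 0 and 1.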