-
Notifications
You must be signed in to change notification settings - Fork 141
Description
`import cv2
import numpy as np
import onnxruntime as ort
import torch
import argparse
import os
import sys
# Make the repository root (one level above this script's directory) importable
# so that the `utils` package resolves when the script is run directly.
sys.path.append(os.path.join(os.path.dirname(__file__), "../"))
from utils.config import Config
class UFLDv2:
    """Run a UFLDv2 lane-detection ONNX model on frames and display the lanes.

    The instance owns an ONNX Runtime session plus the geometry read from the
    training config (network input size, crop offset, anchor positions) that
    is needed to map raw network outputs back to pixel coordinates.
    """

    def __init__(self, onnx_path, config_path, ori_size):
        """Create the inference session and load geometry from the config.

        Args:
            onnx_path: Path of the exported ONNX model.
            config_path: Path of the config file read via ``Config.fromfile``;
                it must provide ``train_width``, ``train_height``,
                ``crop_ratio``, ``num_row`` and ``num_col``.
            ori_size: ``(width, height)`` used to scale predicted coordinates.
        """
        # Prefer CUDA, but keep a CPU fallback so the session can still be
        # created on machines without a usable GPU provider.
        self.session = ort.InferenceSession(
            onnx_path,
            providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
        )
        cfg = Config.fromfile(config_path)
        self.ori_img_w, self.ori_img_h = ori_size
        # Number of top rows dropped from each frame before resizing.
        self.cut_height = int(cfg.train_height * (1 - cfg.crop_ratio))
        self.input_width = cfg.train_width
        self.input_height = cfg.train_height
        self.num_row = cfg.num_row
        self.num_col = cfg.num_col
        # Anchors are stored as fractions of the image height / width.
        # Row anchors span 160..710 on a 720-pixel reference height; an earlier
        # variant used np.linspace(0.42, 1, cfg.num_row) instead.
        self.row_anchor = np.linspace(160, 710, cfg.num_row) / 720
        self.col_anchor = np.linspace(0, 1, cfg.num_col)

    def _refine_position(self, logits, peak, num_grid):
        """Return the softmax-weighted grid index around ``peak``, plus 0.5.

        Args:
            logits: 1-D tensor of location scores along the grid axis.
            peak: Index of the highest score (int or 0-dim integer tensor).
            num_grid: Length of the grid axis.
        """
        peak = int(peak)
        # NOTE(review): the half-width of the window is the network input
        # width, so in practice the window covers the whole grid - confirm
        # this is intended rather than a small local window.
        first = max(0, peak - self.input_width)
        last = min(num_grid - 1, peak + self.input_width)
        window = torch.arange(first, last + 1)
        return (logits[window].softmax(0) * window.float()).sum() + 0.5

    def pred2coords(self, pred):
        """Convert raw network outputs into per-lane pixel coordinates.

        Args:
            pred: Mapping with 4-D tensors under the keys ``'loc_row'``,
                ``'loc_col'``, ``'exist_row'`` and ``'exist_col'``. Only batch
                element 0 is decoded.

        Returns:
            A list of lanes; each lane is a list of ``(x, y)`` integer tuples.
            Lanes that do not pass the existence threshold are omitted.
        """
        loc_row, loc_col = pred['loc_row'], pred['loc_col']
        _, num_grid_row, num_cls_row, _ = loc_row.shape
        _, num_grid_col, num_cls_col, _ = loc_col.shape

        max_indices_row = loc_row.argmax(1)
        valid_row = pred['exist_row'].argmax(1)
        max_indices_col = loc_col.argmax(1)
        valid_col = pred['exist_col'].argmax(1)

        coords = []
        row_lane_idx = [1, 2]
        col_lane_idx = [0, 3]

        # Lanes decoded from row anchors: y comes from the anchor, x from the
        # refined grid position.
        for i in row_lane_idx:
            # Keep the lane only if more than half of its anchors are valid.
            if valid_row[0, :, i].sum() > num_cls_row / 2:
                lane = []
                for k in range(valid_row.shape[1]):
                    if valid_row[0, k, i]:
                        pos = self._refine_position(
                            loc_row[0, :, k, i], max_indices_row[0, k, i], num_grid_row
                        )
                        x = pos / (num_grid_row - 1) * self.ori_img_w
                        lane.append((int(x), int(self.row_anchor[k] * self.ori_img_h)))
                coords.append(lane)

        # Lanes decoded from column anchors: x comes from the anchor, y from
        # the refined grid position.
        for i in col_lane_idx:
            # Keep the lane only if more than a quarter of its anchors are valid.
            if valid_col[0, :, i].sum() > num_cls_col / 4:
                lane = []
                for k in range(valid_col.shape[1]):
                    if valid_col[0, k, i]:
                        pos = self._refine_position(
                            loc_col[0, :, k, i], max_indices_col[0, k, i], num_grid_col
                        )
                        y = pos / (num_grid_col - 1) * self.ori_img_h
                        lane.append((int(self.col_anchor[k] * self.ori_img_w), int(y)))
                coords.append(lane)

        return coords

    def forward(self, img):
        """Run the model on one frame and show the detected lane points.

        Args:
            img: Frame as an ``H x W x C`` array (as returned by OpenCV).

        The frame is cropped from ``cut_height`` downwards, resized to the
        network input size, scaled to ``[0, 1]`` and fed to the session. The
        decoded points are drawn on a copy of the input frame, which is shown
        in the ``"Result"`` window.
        """
        im0 = img.copy()
        img = img[self.cut_height:, :, :]
        # `interpolation` must be passed by keyword: the third positional
        # parameter of cv2.resize is `dst`, not the interpolation flag.
        img = cv2.resize(
            img, (self.input_width, self.input_height), interpolation=cv2.INTER_CUBIC
        )
        img = img.astype(np.float32) / 255.0
        # H x W x C -> 1 x C x H x W
        img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))[np.newaxis])

        input_name = self.session.get_inputs()[0].name
        output_names = [output.name for output in self.session.get_outputs()]
        outputs = self.session.run(output_names, {input_name: img})
        preds = {name: torch.tensor(output) for name, output in zip(output_names, outputs)}

        coords = self.pred2coords(preds)
        for lane in coords:
            for coord in lane:
                cv2.circle(im0, coord, 2, (0, 255, 0), -1)
        cv2.imshow("Result", im0)
def get_args():
    """Parse the command-line options of the demo script.

    Returns:
        An ``argparse.Namespace`` with ``config_path``, ``onnx_path``,
        ``video_path`` (strings) and ``ori_size`` (a ``(width, height)`` tuple
        of ints).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config_path', default='/home/akshat/Nidhin_wrkspace/Lane Detection/Ultra-Fast-Lane-Detection-v2/configs/tusimple_res18.py', help='path to config file', type=str)
    parser.add_argument('--onnx_path', default='/home/akshat/Nidhin_wrkspace/Lane Detection/Ultra-Fast-Lane-Detection-v2/ufldv2_tusimple_res18_320x800.onnx', help='path to ONNX file', type=str)
    parser.add_argument('--video_path', default='example.mp4', help='path to video file', type=str)
    # `type=tuple` would split the argument string into single characters, so
    # the size is taken as two separate integers: --ori_size WIDTH HEIGHT.
    parser.add_argument('--ori_size', default=(800, 320), nargs=2, type=int,
                        metavar=('WIDTH', 'HEIGHT'), help='size of original frame')
    args = parser.parse_args()
    # argparse yields a list when the option is given; normalise to a tuple so
    # the value has the same type as the default.
    args.ori_size = tuple(args.ori_size)
    return args
# Script entry point: read the video frame by frame, run lane detection on a
# fixed crop of each frame, and show the result until the video ends or the
# user presses "q".
if __name__ == "__main__":
    args = get_args()
    cap = cv2.VideoCapture(args.video_path)
    try:
        if not cap.isOpened():
            raise SystemExit(f"Could not open video: {args.video_path}")
        isnet = UFLDv2(args.onnx_path, args.config_path, args.ori_size)
        while True:
            success, img = cap.read()
            if not success:
                # Also reached at the normal end of the video.
                print("Failed to read frame")
                break
            print(f"Image shape before resize: {img.shape}")
            img = cv2.resize(img, (800, 451))
            # Keep a 160-pixel-high horizontal band of the resized frame.
            # NOTE(review): this band is 800x160 while the default --ori_size
            # is (800, 320); confirm the coordinate scaling matches the crop.
            img = img[190:350, :, :]
            isnet.forward(img)
            if (cv2.waitKey(25) & 0xFF) == ord('q'):
                break
    finally:
        # Release the capture and close windows even if inference raised.
        cap.release()
        cv2.destroyAllWindows()
`
