-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathroidetect.py
More file actions
executable file
·457 lines (368 loc) · 16 KB
/
roidetect.py
File metadata and controls
executable file
·457 lines (368 loc) · 16 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
from networkx.algorithms.link_prediction import cn_soundarajan_hopcroft
import os
# os.environ['OPENBLAS_NUM_THREADS'] = '17' # PLEASE CHANGE ME!
import numpy as np
import cv2
from numpy.lib.function_base import append
import sknw
from skimage.morphology import skeletonize
import matplotlib.pyplot as plt
import matplotlib as mpl
import networkx as nx
from sklearn.cluster import KMeans
import copy
from cv2 import aruco
import csv
import argparse
from PIL import Image
import bbox
def create_aruco_coords(infile, outfile):
    """
    Inputs
        infile -- a picture or video file
        outfile -- path of a text file to (over)write
    Outputs
        outfile -- the path that was written. The file contains ARTag center
                   coordinates for the reference image, one "x y" row per tag,
                   sorted by marker id; used as points of reference to apply
                   homography.
    Raises
        ValueError -- if no frame can be read or no markers are detected.
    """
    # Read a single frame: try the path as an image first, then as a video.
    try:
        frame = np.array(Image.open(infile))
    except IOError:
        print("Not an image, trying as a video file")
        video = cv2.VideoCapture(infile)
        ret, frame = video.read()
        video.release()  # don't leak the capture handle
        if not ret:
            raise ValueError('Frame not successfully read.')
    # Initialize parameters for ARTag detection (legacy cv2.aruco API).
    # aruco_dict = cv2.aruco.getPredefinedDictionary(cv2.aruco.DICT_4X4_100) <= newer cv2 version syntax
    aruco_dict = aruco.Dictionary_get(aruco.DICT_4X4_100)
    # parameters = cv2.aruco.DetectorParameters() <= newer cv2 version syntax
    parameters = aruco.DetectorParameters_create()
    parameters.adaptiveThreshConstant = 20
    parameters.adaptiveThreshWinSizeMax = 20
    parameters.adaptiveThreshWinSizeStep = 6
    parameters.minMarkerPerimeterRate = .02
    parameters.polygonalApproxAccuracyRate = .15
    parameters.perspectiveRemovePixelPerCell = 10
    parameters.perspectiveRemoveIgnoredMarginPerCell = .3
    parameters.minDistanceToBorder = 0
    # NOTE: the original also built an unused cv2.aruco.ArucoDetector here,
    # which belongs to the newer API and crashes under the legacy API this
    # file otherwise uses; it has been removed as dead code.
    # Find ARTag coordinates in the image.
    # corners, ids, rejectedImgPoints = detector.detectMarkers(frame) <= newer cv2 version syntax
    corners, ids, rejectedImgPoints = aruco.detectMarkers(frame, aruco_dict, parameters=parameters)
    if ids is None:
        # Fail with a clear message instead of a TypeError while iterating None
        raise ValueError('No ArUco markers detected.')
    # Each tag center is the mean of its 4 corner points.
    avg = [np.average(x, axis=1) for x in corners]
    flat_corners = [item for sublist in avg for item in sublist]
    flat_ids = [item for sublist in ids for item in sublist]
    # Sort centers by marker id so ordering is stable across images.
    pair = sorted(zip(flat_ids, flat_corners))
    coords = np.array([x[1] for x in pair]).astype(int)
    coords_str = '\n'.join(' '.join(map(str, row)) for row in coords)
    # Context manager closes the handle (the original leaked an open file
    # object and returned it; we return the path instead).
    with open(outfile, "w") as fh:
        fh.write(coords_str)
    return outfile
def warp(frame, coord1):
    """
    Inputs
        frame -- query image (grayscale, 2-D) to label
        coord1 -- ARTag coordinates for the reference image, used as points of
                  reference to apply homography
    Outputs
        M -- 3x3 transformation matrix, used to warp the query image to
             closely match the reference image
        result -- warped query image. We do this warping to consistently
                  label ROIs.
    Raises
        ValueError -- if no ArUco markers are detected in the frame.
    """
    h, w = frame.shape
    # Initialize parameters for ARTag detection (legacy cv2.aruco API)
    aruco_dict = aruco.Dictionary_get(aruco.DICT_4X4_100)
    parameters = aruco.DetectorParameters_create()
    parameters.adaptiveThreshConstant = 20
    parameters.adaptiveThreshWinSizeMax = 20
    parameters.adaptiveThreshWinSizeStep = 6
    parameters.minMarkerPerimeterRate = .02
    parameters.polygonalApproxAccuracyRate = .15
    parameters.perspectiveRemovePixelPerCell = 10
    parameters.perspectiveRemoveIgnoredMarginPerCell = .3
    parameters.minDistanceToBorder = 0
    # Find ARTag coordinates in the query image and reformat the data to
    # match the reference coordinates.
    corners, ids, rejectedImgPoints = aruco.detectMarkers(frame, aruco_dict, parameters=parameters)
    if ids is None:
        # Fail loudly instead of crashing later with an opaque TypeError
        raise ValueError('No ArUco markers detected in query frame.')
    # Center of each tag = mean of its 4 corners; sort by marker id so the
    # ordering matches coord1 (which is also sorted by id).
    avg = [np.average(x, axis=1) for x in corners]
    flat_corners = [item for sublist in avg for item in sublist]
    flat_ids = [item for sublist in ids for item in sublist]
    pair = sorted(zip(flat_ids, flat_corners))
    # Warp query image using ARTag coordinates.
    # (Removed: an unused drawDetectedMarkers debug image and a leftover
    # debug print of coord2.)
    coord2 = np.array([x[1] for x in pair]).astype(int)
    # WARNING
    # It is possible the number of detected ArUco tags is not 7 (the usual).
    # We can still run the pipeline with fewer than 7, as long as we correctly
    # identify which ArUco tag is missing. Please write software able to do so.
    # findHomography needs at least 4 id-matched point pairs.
    M, _ = cv2.findHomography(coord2, coord1)
    result = cv2.warpPerspective(frame, M, (w, h))
    return M, result
def mask(frame, thresh=160):
    """
    Inputs
        frame -- warped query image (grayscale)
        thresh -- binarization threshold (default 160, the previously
                  hard-coded value). Lower it for dimmer arenas.
    Outputs
        mask -- thresholded query image keeping only bright sections, to
                isolate the tree structure from the background
    """
    _, th1 = cv2.threshold(frame, thresh, 255, cv2.THRESH_BINARY)
    # Clean up the mask with morphological operations:
    # open removes small bright speckles, dilate slightly thickens the tree,
    # close fills small dark holes inside the structure.
    open_kernel = np.ones((8, 8), np.uint8)
    dilate_kernel = np.ones((6, 6), np.uint8)
    close_kernel = np.ones((6, 6), np.uint8)
    # Use a name distinct from the function itself (the original shadowed it)
    cleaned = cv2.morphologyEx(th1, cv2.MORPH_OPEN, open_kernel)
    cleaned = cv2.dilate(cleaned, dilate_kernel, iterations=1)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, close_kernel)
    return cleaned
def nodes(mask):
    """
    Inputs
        mask -- thresholded query image (0/255)
    Outputs
        new_node_centers -- (row, col) coordinates of each branching point,
                            which is the center of each ROI
    """
    # Skeletonize the tree and build a graph network from the skeleton
    skeleton = skeletonize(mask // 255).astype(np.uint16)
    graph = sknw.build_sknw(skeleton)
    all_nodes = graph.nodes()
    node_centers = np.array([all_nodes[i]['o'] for i in all_nodes])
    # Filter out nodes at the tips of branches (degree < 3), keeping only
    # junction nodes that define the centers of each ROI.
    # NOTE: the original bound this copy to a local named `copy`, shadowing
    # the stdlib `copy` module imported at the top of the file.
    pruned = graph.copy()
    for i in range(len(node_centers)):
        if len(list(graph.neighbors(i))) < 3:
            pruned.remove_node(i)
    junctions = pruned.nodes()
    new_node_centers = np.array([junctions[i]['o'] for i in junctions]).astype(int)
    return new_node_centers
def centers(reference, ps):
    """
    Inputs
        reference -- coordinates of each ROI center in the reference image
        ps -- detected coordinates of each ROI center in the query image
    Outputs
        newpoints -- query coordinates reordered to match the ordering of
                     `reference` (each reference center is paired with its
                     nearest query center). This allows consistent labelling.
    """
    ps = np.asarray(ps)
    newpoints = []
    for ref in np.asarray(reference):
        # Euclidean distance from this reference center to every query center.
        # (The original scanned with a magic 10000 cap: if every distance
        # exceeded it, `index` stayed None and ps[None] crashed. argmin has
        # no cap and, like the original strict '<' scan, picks the first
        # minimum on ties.)
        dists = np.linalg.norm(ps - ref, axis=1)
        newpoints.append(ps[int(np.argmin(dists))])
    return np.array(newpoints)
# DILATION! CHANGE ME (see the commented-out dilation kernel inside contour())
def contour(mask):
    """
    Inputs
        mask -- thresholded query image
    Outputs
        cont -- rather than a skeletonized representation, return a contour
                image of the tree. We will use this to find the vertices of
                the ROIs.
    """
    # Find and draw only the largest contour in the image (the tree structure)
    cont = np.zeros_like(mask)
    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_NONE)
    largest = max(contours, key=cv2.contourArea)
    # BUG FIX: drawContours expects a *list* of contours. Passing the single
    # contour array directly made OpenCV treat every point as its own
    # one-point contour (it only looked connected because CHAIN_APPROX_NONE
    # yields dense points). Wrapping in a list draws a proper connected curve.
    cv2.drawContours(cont, [largest], -1, [255, 255, 255])
    # Dilate the contours to make the lines thicker
    # kernel = np.ones((3, 3), np.uint8)
    # cont = cv2.dilate(cont, kernel, iterations=1)
    return cont
def drawCircles(cont, newpoints):
    """Return a copy of *cont* with a red radius-40 circle around each ROI center.

    Inputs
        cont -- contour image (grayscale or BGR)
        newpoints -- array of ROI centers as (row, col) integer pairs
    Outputs
        the annotated BGR image
    """
    canvas = cont.copy()  # never draw on the caller's image
    if canvas.ndim == 2:
        # Grayscale input: promote to BGR so the circles can be red
        canvas = cv2.cvtColor(canvas, cv2.COLOR_GRAY2BGR)
    h, w = canvas.shape[:2]
    # Warn about any center that falls outside the image bounds
    for point in newpoints:
        inside = 0 <= point[0] < h and 0 <= point[1] < w
        if not inside:
            print(f"Out-of-bounds point: {point}")  # Debugging
    # cv2 wants (x, y) = (col, row), hence the swap
    for row, col in newpoints:
        cv2.circle(canvas, (col, row), 40, (0, 0, 255), 2)  # Red circle
    return canvas
def vertices(cont, newpoints, Dict, Orientation):
    """
    Inputs
        cont -- contour image of tree structure
        newpoints -- reordered centers of query rois, as (row, col) int pairs
        Dict -- dictionary mapping current labelling to match lab's labelling
        Orientation -- iterable of (tag, orientation) pairs; orientation "R"
            marks right-facing junctions whose vertex ordering is rotated
            (assumes exactly two fields per row -- TODO confirm against the
            dictionary CSV format)
    Outputs
        verts -- vertices of each roi, consistently ordered
    """
    conn = []
    for i in range(len(newpoints)):
        cont_test = cont.copy()
        # Points of intersection between circle and contour image represent vertices of an roi
        # (radius 40 matches the circles drawn by drawCircles)
        circle = np.zeros_like(cont)
        # cv2 expects (x, y) = (col, row), hence the index swap
        cv2.circle(circle, (newpoints[i, 1], newpoints[i, 0]), 40, [255,255,255], 2)
        # # test circle on contour
        # current_directory = os.getcwd()
        # save_location = os.path.join(os.path.join(current_directory, "contour_and_circle"), "the_circles_new_ROI_" + str(Dict[i]) + '.png')
        # print(save_location)
        # cv2.circle(cont_test, (newpoints[i, 1], newpoints[i, 0]), 40, [255,255,255], 2)
        # if not cv2.imwrite(save_location, cont_test):
        #     print("ERROR: Contour and Circle image was not saved")
        # Keep only pixels where the circle crosses the tree contour
        inter = cv2.bitwise_and(cont, circle)
        # cv2.imwrite('detect_images/bitwise_' + str(i) + '.png', inter)
        # if cv2.findNonZero(inter) is None:
        #     print(f"[WARNING] No intersection for ROI {i} (label {Dict[i]}), center = {newpoints[i]}")
        #     continue
        # NOTE(review): findNonZero returns None when there is no intersection,
        # which would make the next line raise -- see the commented guard above.
        index = np.array(cv2.findNonZero(inter))
        # flatten the (N, 1, 2) point array to (N, 2); the comprehension's `i`
        # is scoped to the comprehension and does not clobber the outer loop's `i`
        index = np.array([index[i][0] for i in range(len(index))])
        # At times one point of intersection will be detected as two closely positioned points.
        # Use k means to ensure we get the correct number of vertices.
        # (n_clusters=6 -- presumably each junction is a hexagonal ROI; confirm)
        kmeans = KMeans(n_clusters=6).fit(index)
        centers = np.array(kmeans.cluster_centers_).astype(int)
        # Connect vertices to form a convex polygon
        poly = cv2.convexHull(centers)
        poly = np.array([x[0] for x in poly])
        # Find largest edge defined by the vertices, and reorder vertices so that edge is first
        # (np.diff with append wraps around so the last->first edge is included)
        d = np.diff(poly, axis=0, append=poly[0:1])
        segdists = np.sqrt((d ** 2).sum(axis=1))
        index = np.argmax(segdists)
        roll = np.roll(poly, -index, axis = 0)
        # Reorder right junctions so they have the same labelling as left junctions
        right_set = set()
        for tag, ori in Orientation:
            if ori == "R":
                right_set.add(int(tag))
        # Right-facing junctions: rotate the vertex order by 2 positions
        if Dict[i] in right_set:
            roll = np.roll(roll, 2, axis = 0)
        conn.append(roll)
        # Progress/debug output: roi index, lab label, vertex count
        print(i, Dict[i], len(poly))
    # print(conn)
    conn = np.array(conn)
    return conn
"""
prints the (x, y) coordinates of each mouseclick
to find center coordinates: click the center of each junction and then invert the x and y to get (y,x)
Input:
event -- left mouse click
x -- x coordinate of where the mouse was placed on image
y -- y coordinate of where the mouse was placed on image
flag -- additional flags
param -- additional parameters
Print:
"Coordinates: (x, y)"
"""
def get_coordinates(event, x, y, flags, param):
"""Callback function to capture mouse click coordinates."""
if event == cv2.EVENT_LBUTTONDOWN: # Correct function signature
print(f"Coordinates: ({x}, {y})")
def main():
    """Detect ROIs in one frame of a video and save them to a file.

    Pipeline: read a frame -> warp onto the reference layout via ArUco tags
    -> threshold/mask -> find junction centers -> find ROI vertices ->
    un-warp the vertices back into the original frame -> save as bboxes.
    Also writes contour.png and cont_with_circles.png to the working
    directory for visual inspection.
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('video',
                            type=str,
                            help='The path to a video with ROIs to detect.',
                            )
    arg_parser.add_argument('outfile',
                            type=str,
                            help='The path to a file in which to write the '
                                 'detected ROIs.',
                            )
    arg_parser.add_argument('-f', '--frame',
                            dest='frame',
                            type=int,
                            default=1,
                            help='The frame number in the video to use for '
                                 'ROI detection (default=1)',
                            )
    arg_parser.add_argument('-y', '--year',
                            dest='year',
                            type=str,
                            # Restrict to the years templates exist for; the
                            # original fell through to a NameError on any
                            # other value.
                            choices=['2021', '2023', '2025'],
                            help='The year the video was taken',
                            )
    args = arg_parser.parse_args()
    print(args.year)

    # Read the requested frame of the video as an image
    if not os.path.isfile(args.video):
        arg_parser.error(f'{args.video} is not a valid file.')
    video = cv2.VideoCapture(args.video)
    # BUG FIX: the original parsed --frame but always read frame 1.
    # Seek to the requested (1-based) frame before reading.
    video.set(cv2.CAP_PROP_POS_FRAMES, args.frame - 1)
    ret, frame = video.read()
    video.release()  # release the capture handle once we have the frame
    if not ret:
        arg_parser.error('The video only has {} frames.'.format(args.frame-1))
    gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

    # Load the year-specific reference data. All template filenames follow
    # the same year-suffixed pattern, so the per-year if/elif chain is folded
    # into f-strings (args.year is constrained by `choices` above).
    # Aruco tag coordinates (preferably all 7 visible):
    coord1 = np.array(np.loadtxt(f"templates/tag_coordinates_{args.year}.txt")).astype(int)
    # ROI center coordinates:
    reference = np.array(np.loadtxt(f"templates/center_coordinates_{args.year}.txt")).astype(int)
    csv_file = f"templates/dictionary_{args.year}.csv"

    # Build the index -> lab-label mapping and the orientation table
    Dict = {}
    Orientation = []
    with open(csv_file, 'r') as file:
        reader = csv.reader(file)
        next(reader)  # skip the header row
        for index, row in enumerate(reader):
            Dict[int(index)] = int(row[0])
            Orientation.append(tuple(row))

    # Warp the query frame onto the reference layout
    M, result = warp(gray, coord1)
    frame_mask = mask(result)
    query = nodes(frame_mask)
    # Reorder detected centers to match the reference ordering
    newpoints = centers(reference, query)

    cont = contour(frame_mask)
    # Save the contour image for visual inspection
    current_directory = os.getcwd()
    save_location = os.path.join(current_directory, "contour.png")
    cv2.imwrite(save_location, cont)
    verts = vertices(cont, newpoints, Dict, Orientation)
    # Save the contour with the ROI circles overlaid, also for inspection
    result = drawCircles(cont, newpoints)
    image_path = os.path.join(current_directory, "cont_with_circles.png")
    cv2.imwrite(image_path, result)
    # To print coordinates on mouse click, show the frame with
    # cv2.imshow/cv2.setMouseCallback("Image", get_coordinates).

    # Undo the homography to get vertex coordinates in the original frame
    print(verts)
    pts2 = np.array(verts, np.float32)
    polys = np.array(cv2.perspectiveTransform(pts2, np.linalg.pinv(M))).astype(int)

    # Save vertices to outfile
    rois = [bbox.BBox.from_verts(poly, 3) for poly in polys]
    bbox.save_rois(rois, args.outfile)
if __name__ == '__main__':
    # Script entry point: detect ROIs in a video frame and save them.
    # Uncomment below (with real paths) to regenerate reference tag coordinates.
    # create_aruco_coords(file, outfile)
    main()