Skip to content

Commit c19614f

Browse files
Merge pull request #59 from welli7ngton/mnt/annotate_display_methods
MNT: Add type annotations and improve function signatures.
2 parents d42167e + 16efc50 commit c19614f

File tree

2 files changed

+131
-87
lines changed

2 files changed

+131
-87
lines changed

botcity/core/bot.py

Lines changed: 92 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
import subprocess
77
import time
88
import webbrowser
9-
from typing import Union, Tuple, Optional, List
9+
from typing import Union, Tuple, Optional, List, Dict, Generator, Any
1010

11+
from numpy import ndarray
1112

1213
import pyperclip
1314
from botcity.base import BaseBot, State
@@ -121,7 +122,7 @@ def app(self, app: Union["Application", "WindowSpecification"]):
121122
# Display
122123
##########
123124

124-
def add_image(self, label, path):
125+
def add_image(self, label: str, path: str) -> None:
125126
"""
126127
Add an image into the state image map.
127128
@@ -131,7 +132,7 @@ def add_image(self, label, path):
131132
"""
132133
self.state.map_images[label] = path
133134

134-
def get_image_from_map(self, label):
135+
def get_image_from_map(self, label: str) -> Image.Image:
135136
"""
136137
Return an image from teh state image map.
137138
@@ -149,25 +150,25 @@ def get_image_from_map(self, label):
149150

150151
def find_multiple(
151152
self,
152-
labels,
153-
x=None,
154-
y=None,
155-
width=None,
156-
height=None,
153+
labels: List,
154+
x: int = 0,
155+
y: int = 0,
156+
width: Optional[int] = None,
157+
height: Optional[int] = None,
157158
*,
158-
threshold=None,
159-
matching=0.9,
160-
waiting_time=10000,
161-
best=True,
162-
grayscale=False,
163-
):
159+
threshold: Optional[int] = None,
160+
matching: float = 0.9,
161+
waiting_time: int = 10000,
162+
best: bool = True,
163+
grayscale: bool = False,
164+
) -> Dict:
164165
"""
165166
Find multiple elements defined by label on screen until a timeout happens.
166167
167168
Args:
168169
labels (list): A list of image identifiers
169-
x (int, optional): Search region start position x. Defaults to 0.
170-
y (int, optional): Search region start position y. Defaults to 0.
170+
x (int): Search region start position x. Defaults to 0.
171+
y (int): Search region start position y. Defaults to 0.
171172
width (int, optional): Search region width. Defaults to screen width.
172173
height (int, optional): Search region height. Defaults to screen height.
173174
threshold (int, optional): The threshold to be applied when doing grayscale search.
@@ -190,8 +191,6 @@ def _to_dict(lbs, elems):
190191
return {k: v for k, v in zip(lbs, elems)}
191192

192193
screen_w, screen_h = self._fix_display_size()
193-
x = x or 0
194-
y = y or 0
195194
w = width or screen_w
196195
h = height or screen_h
197196

@@ -253,7 +252,14 @@ def _fix_display_size(self) -> Tuple[int, int]:
253252

254253
return int(width * 2), int(height * 2)
255254

256-
def _find_multiple_helper(self, haystack, region, confidence, grayscale, needle):
255+
def _find_multiple_helper(
256+
self,
257+
haystack: Image.Image,
258+
region: Tuple[int, int, int, int],
259+
confidence: float,
260+
grayscale: bool,
261+
needle: Union[Image.Image, ndarray, str],
262+
) -> Union[cv2find.Box, None]:
257263
ele = cv2find.locate_all_opencv(
258264
needle, haystack, region=region, confidence=confidence, grayscale=grayscale
259265
)
@@ -265,18 +271,18 @@ def _find_multiple_helper(self, haystack, region, confidence, grayscale, needle)
265271

266272
def find(
267273
self,
268-
label,
269-
x=None,
270-
y=None,
271-
width=None,
272-
height=None,
274+
label: str,
275+
x: Optional[int] = None,
276+
y: Optional[int] = None,
277+
width: Optional[int] = None,
278+
height: Optional[int] = None,
273279
*,
274-
threshold=None,
275-
matching=0.9,
276-
waiting_time=10000,
277-
best=True,
278-
grayscale=False,
279-
):
280+
threshold: Optional[int] = None,
281+
matching: float = 0.9,
282+
waiting_time: int = 10000,
283+
best: bool = True,
284+
grayscale: bool = False,
285+
) -> Union[cv2find.Box, None]:
280286
"""
281287
Find an element defined by label on screen until a timeout happens.
282288
@@ -315,18 +321,18 @@ def find(
315321

316322
def find_until(
317323
self,
318-
label,
319-
x=None,
320-
y=None,
321-
width=None,
322-
height=None,
324+
label: str,
325+
x: Optional[int] = None,
326+
y: Optional[int] = None,
327+
width: Optional[int] = None,
328+
height: Optional[int] = None,
323329
*,
324-
threshold=None,
325-
matching=0.9,
326-
waiting_time=10000,
327-
best=True,
328-
grayscale=False,
329-
):
330+
threshold: Optional[int] = None,
331+
matching: float = 0.9,
332+
waiting_time: int = 10000,
333+
best: bool = True,
334+
grayscale: bool = False,
335+
) -> Union[cv2find.Box, None]:
330336
"""
331337
Find an element defined by label on screen until a timeout happens.
332338
@@ -399,17 +405,17 @@ def find_until(
399405

400406
def find_all(
401407
self,
402-
label,
403-
x=None,
404-
y=None,
405-
width=None,
406-
height=None,
408+
label: str,
409+
x: Optional[int] = None,
410+
y: Optional[int] = None,
411+
width: Optional[int] = None,
412+
height: Optional[int] = None,
407413
*,
408-
threshold=None,
409-
matching=0.9,
410-
waiting_time=10000,
411-
grayscale=False,
412-
):
414+
threshold: Optional[int] = None,
415+
matching: float = 0.9,
416+
waiting_time: int = 10000,
417+
grayscale: bool = False,
418+
) -> Generator[cv2find.Box, Any, None]:
413419
"""
414420
Find all elements defined by label on screen until a timeout happens.
415421
@@ -433,7 +439,9 @@ def find_all(
433439
None if not found.
434440
"""
435441

436-
def deduplicate(elems):
442+
def deduplicate(
443+
elems: list[Generator[cv2find.Box, Any, None]]
444+
) -> list[Generator[cv2find.Box, Any, None]]:
437445
def find_same(item, items):
438446
x_start = item.left
439447
x_end = item.left + item.width
@@ -504,17 +512,17 @@ def find_same(item, items):
504512

505513
def find_text(
506514
self,
507-
label,
508-
x=None,
509-
y=None,
510-
width=None,
511-
height=None,
515+
label: str,
516+
x: Optional[int] = None,
517+
y: Optional[int] = None,
518+
width: Optional[int] = None,
519+
height: Optional[int] = None,
512520
*,
513-
threshold=None,
514-
matching=0.9,
515-
waiting_time=10000,
516-
best=True,
517-
):
521+
threshold: Optional[int] = None,
522+
matching: float = 0.9,
523+
waiting_time: int = 10000,
524+
best: bool = True,
525+
) -> Union[cv2find.Box, None]:
518526
"""
519527
Find an element defined by label on screen until a timeout happens.
520528
@@ -549,7 +557,9 @@ def find_text(
549557
grayscale=True,
550558
)
551559

552-
def find_process(self, name: str = None, pid: str = None) -> Process:
560+
def find_process(
561+
self, name: Optional[str] = None, pid: Optional[str] = None
562+
) -> Union[Process, None]:
553563
"""
554564
Find a process by name or PID
555565
@@ -570,7 +580,7 @@ def find_process(self, name: str = None, pid: str = None) -> Process:
570580
pass
571581
return None
572582

573-
def terminate_process(self, process: Process):
583+
def terminate_process(self, process: Process) -> None:
574584
"""
575585
Terminate the process via the received Process object.
576586
@@ -582,7 +592,7 @@ def terminate_process(self, process: Process):
582592
if process.is_running():
583593
raise Exception("Terminate process failed")
584594

585-
def get_last_element(self):
595+
def get_last_element(self) -> cv2find.Box:
586596
"""
587597
Return the last element found.
588598
@@ -681,8 +691,15 @@ def save_screenshot(self, path: str) -> None:
681691
self.screenshot(path)
682692

683693
def get_element_coords(
684-
self, label, x=None, y=None, width=None, height=None, matching=0.9, best=True
685-
):
694+
self,
695+
label: str,
696+
x: Optional[int] = None,
697+
y: Optional[int] = None,
698+
width: Optional[int] = None,
699+
height: Optional[int] = None,
700+
matching: float = 0.9,
701+
best: bool = True,
702+
) -> Union[Tuple[int, int], Tuple[None, None]]:
686703
"""
687704
Find an element defined by label on screen and returns its coordinates.
688705
@@ -736,7 +753,14 @@ def get_element_coords(
736753
return ele.left, ele.top
737754

738755
def get_element_coords_centered(
739-
self, label, x=None, y=None, width=None, height=None, matching=0.9, best=True
756+
self,
757+
label: str,
758+
x: Optional[int] = None,
759+
y: Optional[int] = None,
760+
width: Optional[int] = None,
761+
height: Optional[int] = None,
762+
matching: float = 0.9,
763+
best: bool = True,
740764
):
741765
"""
742766
Find an element defined by label on screen and returns its centered coordinates.

botcity/core/cv2find.py

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,16 @@
3030
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
3131
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3232
"""
33+
3334
import collections
3435
import cv2
3536
import numpy
37+
from PIL.Image import Image
38+
from typing import Union, Tuple, Optional, Generator, Any
3639

37-
RUNNING_CV_2 = cv2.__version__[0] < '3'
40+
RUNNING_CV_2 = cv2.__version__[0] < "3"
3841

39-
Box = collections.namedtuple('Box', 'left top width height')
42+
Box = collections.namedtuple("Box", "left top width height")
4043

4144
if RUNNING_CV_2:
4245
LOAD_COLOR = cv2.CV_LOAD_IMAGE_COLOR
@@ -46,7 +49,9 @@
4649
LOAD_GRAYSCALE = cv2.IMREAD_GRAYSCALE
4750

4851

49-
def _load_cv2(img, grayscale=False):
52+
def _load_cv2(
53+
img: Union[Image, numpy.ndarray, str], grayscale: bool = False
54+
) -> numpy.ndarray:
5055
"""
5156
TODO
5257
"""
@@ -66,28 +71,37 @@ def _load_cv2(img, grayscale=False):
6671
else:
6772
img_cv = cv2.imread(img, LOAD_COLOR)
6873
if img_cv is None:
69-
raise IOError("Failed to read %s because file is missing, "
70-
"has improper permissions, or is an "
71-
"unsupported or invalid format" % img)
74+
raise IOError(
75+
"Failed to read %s because file is missing, "
76+
"has improper permissions, or is an "
77+
"unsupported or invalid format" % img
78+
)
7279
elif isinstance(img, numpy.ndarray):
7380
# don't try to convert an already-gray image to gray
7481
if grayscale and len(img.shape) == 3: # and img.shape[2] == 3:
7582
img_cv = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
7683
else:
7784
img_cv = img
78-
elif hasattr(img, 'convert'):
85+
elif hasattr(img, "convert"):
7986
# assume its a PIL.Image, convert to cv format
80-
img_array = numpy.array(img.convert('RGB'))
87+
img_array = numpy.array(img.convert("RGB"))
8188
img_cv = img_array[:, :, ::-1].copy() # -1 does RGB -> BGR
8289
if grayscale:
8390
img_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
8491
else:
85-
raise TypeError('expected an image filename, OpenCV numpy array, or PIL image')
92+
raise TypeError("expected an image filename, OpenCV numpy array, or PIL image")
8693
return img_cv
8794

8895

89-
def locate_all_opencv(needle_image, haystack_image, grayscale=False, limit=10000, region=None, step=1,
90-
confidence=0.999):
96+
def locate_all_opencv(
97+
needle_image: Union[Image, numpy.ndarray, str],
98+
haystack_image: Union[Image, numpy.ndarray, str],
99+
grayscale: bool = False,
100+
limit: int = 10000,
101+
region: Optional[Tuple[int, int, int, int]] = None,
102+
step: int = 1,
103+
confidence: float = 0.999,
104+
) -> Generator[Box, Any, None]:
91105
"""
92106
TODO - rewrite this
93107
faster but more memory-intensive than pure python
@@ -107,15 +121,19 @@ def locate_all_opencv(needle_image, haystack_image, grayscale=False, limit=10000
107121

108122
if region:
109123
haystack_image = haystack_image[
110-
region[1]:region[1] + region[3],
111-
region[0]:region[0] + region[2]
112-
]
124+
region[1]: region[1] + region[3], region[0]: region[0] + region[2]
125+
]
113126
else:
114-
region = (0, 0) # full image; these values used in the yield statement
115-
if (haystack_image.shape[0] < needle_image.shape[0] or
116-
haystack_image.shape[1] < needle_image.shape[1]):
127+
region = (0, 0, 0, 0) # full image; these values used in the yield statement
128+
129+
if (
130+
haystack_image.shape[0] < needle_image.shape[0]
131+
or haystack_image.shape[1] < needle_image.shape[1]
132+
):
117133
# avoid semi-cryptic OpenCV error below if bad size
118-
raise ValueError('needle dimension(s) exceed the haystack image or region dimensions')
134+
raise ValueError(
135+
"needle dimension(s) exceed the haystack image or region dimensions"
136+
)
119137

120138
if step == 2:
121139
confidence *= 0.95
@@ -138,6 +156,8 @@ def locate_all_opencv(needle_image, haystack_image, grayscale=False, limit=10000
138156
matchy = matches[0] * step + region[1]
139157

140158
# Order results before sending back
141-
ordered = sorted(zip(matchx, matchy), key=lambda p: result[p[1]][p[0]], reverse=True)
159+
ordered = sorted(
160+
zip(matchx, matchy), key=lambda p: result[p[1]][p[0]], reverse=True
161+
)
142162
for x, y in ordered:
143163
yield Box(x, y, needle_width, needle_height)

0 commit comments

Comments
 (0)