Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 92 additions & 68 deletions botcity/core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
import subprocess
import time
import webbrowser
from typing import Union, Tuple, Optional, List
from typing import Union, Tuple, Optional, List, Dict, Generator, Any

from numpy import ndarray

import pyperclip
from botcity.base import BaseBot, State
Expand Down Expand Up @@ -121,7 +122,7 @@ def app(self, app: Union["Application", "WindowSpecification"]):
# Display
##########

def add_image(self, label, path):
def add_image(self, label: str, path: str) -> None:
"""
Add an image into the state image map.

Expand All @@ -131,7 +132,7 @@ def add_image(self, label, path):
"""
self.state.map_images[label] = path

def get_image_from_map(self, label):
def get_image_from_map(self, label: str) -> Image.Image:
"""
Return an image from teh state image map.

Expand All @@ -149,25 +150,25 @@ def get_image_from_map(self, label):

def find_multiple(
self,
labels,
x=None,
y=None,
width=None,
height=None,
labels: List,
x: int = 0,
y: int = 0,
width: Optional[int] = None,
height: Optional[int] = None,
*,
threshold=None,
matching=0.9,
waiting_time=10000,
best=True,
grayscale=False,
):
threshold: Optional[int] = None,
matching: float = 0.9,
waiting_time: int = 10000,
best: bool = True,
grayscale: bool = False,
) -> Dict:
"""
Find multiple elements defined by label on screen until a timeout happens.

Args:
labels (list): A list of image identifiers
x (int, optional): Search region start position x. Defaults to 0.
y (int, optional): Search region start position y. Defaults to 0.
x (int): Search region start position x. Defaults to 0.
y (int): Search region start position y. Defaults to 0.
width (int, optional): Search region width. Defaults to screen width.
height (int, optional): Search region height. Defaults to screen height.
threshold (int, optional): The threshold to be applied when doing grayscale search.
Expand All @@ -190,8 +191,6 @@ def _to_dict(lbs, elems):
return {k: v for k, v in zip(lbs, elems)}

screen_w, screen_h = self._fix_display_size()
x = x or 0
y = y or 0
w = width or screen_w
h = height or screen_h

Expand Down Expand Up @@ -253,7 +252,14 @@ def _fix_display_size(self) -> Tuple[int, int]:

return int(width * 2), int(height * 2)

def _find_multiple_helper(self, haystack, region, confidence, grayscale, needle):
def _find_multiple_helper(
self,
haystack: Image.Image,
region: Tuple[int, int, int, int],
confidence: float,
grayscale: bool,
needle: Union[Image.Image, ndarray, str],
) -> Union[cv2find.Box, None]:
ele = cv2find.locate_all_opencv(
needle, haystack, region=region, confidence=confidence, grayscale=grayscale
)
Expand All @@ -265,18 +271,18 @@ def _find_multiple_helper(self, haystack, region, confidence, grayscale, needle)

def find(
self,
label,
x=None,
y=None,
width=None,
height=None,
label: str,
x: Optional[int] = None,
y: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
*,
threshold=None,
matching=0.9,
waiting_time=10000,
best=True,
grayscale=False,
):
threshold: Optional[int] = None,
matching: float = 0.9,
waiting_time: int = 10000,
best: bool = True,
grayscale: bool = False,
) -> Union[cv2find.Box, None]:
"""
Find an element defined by label on screen until a timeout happens.

Expand Down Expand Up @@ -315,18 +321,18 @@ def find(

def find_until(
self,
label,
x=None,
y=None,
width=None,
height=None,
label: str,
x: Optional[int] = None,
y: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
*,
threshold=None,
matching=0.9,
waiting_time=10000,
best=True,
grayscale=False,
):
threshold: Optional[int] = None,
matching: float = 0.9,
waiting_time: int = 10000,
best: bool = True,
grayscale: bool = False,
) -> Union[cv2find.Box, None]:
"""
Find an element defined by label on screen until a timeout happens.

Expand Down Expand Up @@ -399,17 +405,17 @@ def find_until(

def find_all(
self,
label,
x=None,
y=None,
width=None,
height=None,
label: str,
x: Optional[int] = None,
y: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
*,
threshold=None,
matching=0.9,
waiting_time=10000,
grayscale=False,
):
threshold: Optional[int] = None,
matching: float = 0.9,
waiting_time: int = 10000,
grayscale: bool = False,
) -> Generator[cv2find.Box, Any, None]:
"""
Find all elements defined by label on screen until a timeout happens.

Expand All @@ -433,7 +439,9 @@ def find_all(
None if not found.
"""

def deduplicate(elems):
def deduplicate(
elems: list[Generator[cv2find.Box, Any, None]]
) -> list[Generator[cv2find.Box, Any, None]]:
def find_same(item, items):
x_start = item.left
x_end = item.left + item.width
Expand Down Expand Up @@ -504,17 +512,17 @@ def find_same(item, items):

def find_text(
self,
label,
x=None,
y=None,
width=None,
height=None,
label: str,
x: Optional[int] = None,
y: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
*,
threshold=None,
matching=0.9,
waiting_time=10000,
best=True,
):
threshold: Optional[int] = None,
matching: float = 0.9,
waiting_time: int = 10000,
best: bool = True,
) -> Union[cv2find.Box, None]:
"""
Find an element defined by label on screen until a timeout happens.

Expand Down Expand Up @@ -549,7 +557,9 @@ def find_text(
grayscale=True,
)

def find_process(self, name: str = None, pid: str = None) -> Process:
def find_process(
self, name: Optional[str] = None, pid: Optional[str] = None
) -> Union[Process, None]:
"""
Find a process by name or PID

Expand All @@ -570,7 +580,7 @@ def find_process(self, name: str = None, pid: str = None) -> Process:
pass
return None

def terminate_process(self, process: Process):
def terminate_process(self, process: Process) -> None:
"""
Terminate the process via the received Process object.

Expand All @@ -582,7 +592,7 @@ def terminate_process(self, process: Process):
if process.is_running():
raise Exception("Terminate process failed")

def get_last_element(self):
def get_last_element(self) -> cv2find.Box:
"""
Return the last element found.

Expand Down Expand Up @@ -681,8 +691,15 @@ def save_screenshot(self, path: str) -> None:
self.screenshot(path)

def get_element_coords(
self, label, x=None, y=None, width=None, height=None, matching=0.9, best=True
):
self,
label: str,
x: Optional[int] = None,
y: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
matching: float = 0.9,
best: bool = True,
) -> Union[Tuple[int, int], Tuple[None, None]]:
"""
Find an element defined by label on screen and returns its coordinates.

Expand Down Expand Up @@ -736,7 +753,14 @@ def get_element_coords(
return ele.left, ele.top

def get_element_coords_centered(
self, label, x=None, y=None, width=None, height=None, matching=0.9, best=True
self,
label: str,
x: Optional[int] = None,
y: Optional[int] = None,
width: Optional[int] = None,
height: Optional[int] = None,
matching: float = 0.9,
best: bool = True,
):
"""
Find an element defined by label on screen and returns its centered coordinates.
Expand Down
58 changes: 39 additions & 19 deletions botcity/core/cv2find.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,16 @@
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import collections
import cv2
import numpy
from PIL.Image import Image
from typing import Union, Tuple, Optional, Generator, Any

RUNNING_CV_2 = cv2.__version__[0] < '3'
RUNNING_CV_2 = cv2.__version__[0] < "3"

Box = collections.namedtuple('Box', 'left top width height')
Box = collections.namedtuple("Box", "left top width height")

if RUNNING_CV_2:
LOAD_COLOR = cv2.CV_LOAD_IMAGE_COLOR
Expand All @@ -46,7 +49,9 @@
LOAD_GRAYSCALE = cv2.IMREAD_GRAYSCALE


def _load_cv2(img, grayscale=False):
def _load_cv2(
img: Union[Image, numpy.ndarray, str], grayscale: bool = False
) -> numpy.ndarray:
"""
TODO
"""
Expand All @@ -66,28 +71,37 @@ def _load_cv2(img, grayscale=False):
else:
img_cv = cv2.imread(img, LOAD_COLOR)
if img_cv is None:
raise IOError("Failed to read %s because file is missing, "
"has improper permissions, or is an "
"unsupported or invalid format" % img)
raise IOError(
"Failed to read %s because file is missing, "
"has improper permissions, or is an "
"unsupported or invalid format" % img
)
elif isinstance(img, numpy.ndarray):
# don't try to convert an already-gray image to gray
if grayscale and len(img.shape) == 3: # and img.shape[2] == 3:
img_cv = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
else:
img_cv = img
elif hasattr(img, 'convert'):
elif hasattr(img, "convert"):
# assume its a PIL.Image, convert to cv format
img_array = numpy.array(img.convert('RGB'))
img_array = numpy.array(img.convert("RGB"))
img_cv = img_array[:, :, ::-1].copy() # -1 does RGB -> BGR
if grayscale:
img_cv = cv2.cvtColor(img_cv, cv2.COLOR_BGR2GRAY)
else:
raise TypeError('expected an image filename, OpenCV numpy array, or PIL image')
raise TypeError("expected an image filename, OpenCV numpy array, or PIL image")
return img_cv


def locate_all_opencv(needle_image, haystack_image, grayscale=False, limit=10000, region=None, step=1,
confidence=0.999):
def locate_all_opencv(
needle_image: Union[Image, numpy.ndarray, str],
haystack_image: Union[Image, numpy.ndarray, str],
grayscale: bool = False,
limit: int = 10000,
region: Optional[Tuple[int, int, int, int]] = None,
step: int = 1,
confidence: float = 0.999,
) -> Generator[Box, Any, None]:
"""
TODO - rewrite this
faster but more memory-intensive than pure python
Expand All @@ -107,15 +121,19 @@ def locate_all_opencv(needle_image, haystack_image, grayscale=False, limit=10000

if region:
haystack_image = haystack_image[
region[1]:region[1] + region[3],
region[0]:region[0] + region[2]
]
region[1]: region[1] + region[3], region[0]: region[0] + region[2]
]
else:
region = (0, 0) # full image; these values used in the yield statement
if (haystack_image.shape[0] < needle_image.shape[0] or
haystack_image.shape[1] < needle_image.shape[1]):
region = (0, 0, 0, 0) # full image; these values used in the yield statement

if (
haystack_image.shape[0] < needle_image.shape[0]
or haystack_image.shape[1] < needle_image.shape[1]
):
# avoid semi-cryptic OpenCV error below if bad size
raise ValueError('needle dimension(s) exceed the haystack image or region dimensions')
raise ValueError(
"needle dimension(s) exceed the haystack image or region dimensions"
)

if step == 2:
confidence *= 0.95
Expand All @@ -138,6 +156,8 @@ def locate_all_opencv(needle_image, haystack_image, grayscale=False, limit=10000
matchy = matches[0] * step + region[1]

# Order results before sending back
ordered = sorted(zip(matchx, matchy), key=lambda p: result[p[1]][p[0]], reverse=True)
ordered = sorted(
zip(matchx, matchy), key=lambda p: result[p[1]][p[0]], reverse=True
)
for x, y in ordered:
yield Box(x, y, needle_width, needle_height)
Loading