Skip to content

Commit b10fb2b

Browse files
10-fold computation speed-up with numba and other optimizations
+ documentation
1 parent d017a98 commit b10fb2b

File tree

6 files changed

+83
-38
lines changed

6 files changed

+83
-38
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,6 @@ This is only checked when a new piece appears, so you need to hold the key.
3535
2 - no hard drop
3636
3 - let the piece land on its own, the bot is always scared
3737
Number of computing paths for the next piece:
38-
z, x, c - 1, 3, 5
38+
z, x, c - 1, 4, 8
3939
n - try to clean the field
4040
m - disable cleaning mode (focus on getting tetrises)

config.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1-
# pieces are encoded as
2-
# 0 - line, 1 - square, 2 - T(flip), 3 - |__, 4 - __|, 5 - -|_,6 - _|-
1+
import numpy as np
32
from src.display_consts import DisplayConsts
43

4+
5+
# pieces are encoded as
6+
# 0 - line, 1 - square, 2 - T(flip), 3 - |__, 4 - __|, 5 - -|_,6 - _|-
57
PIECE_NAMES = ['line', 'square', 'T(flip)', '|__', '__|', '-|_', '_|-']
68

79

@@ -10,7 +12,7 @@ def name_piece(piece: int) -> str:
1012

1113

1214
# in BGR
13-
original_colors = [(0, 0, 0) for _ in range(7)]
15+
original_colors = np.zeros((7, 3), np.int)
1416
original_colors[0] = (230, 228, 180)
1517
original_colors[1] = (182, 228, 247)
1618
original_colors[2] = (177, 99, 140)
@@ -20,7 +22,7 @@ def name_piece(piece: int) -> str:
2022
original_colors[6] = (171, 240, 177)
2123

2224
# tetr.io colors in RGB
23-
tetrio_colors = [(0, 0, 0) for _ in range(7)]
25+
tetrio_colors = np.zeros((7, 3), np.int)
2426
tetrio_colors[0] = (36, 214, 150)
2527
tetrio_colors[1] = (210, 171, 42)
2628
tetrio_colors[2] = (212, 67, 195)
@@ -50,9 +52,9 @@ def name_piece(piece: int) -> str:
5052
'debug status': 1, # greater is more information, 0 is none
5153
'key press delay': 0.02, # increase if facing misclicks, decrease to go faster
5254
'tetrio garbage': True,
53-
'starting choices for 2nd': 3,
55+
'starting choices for 2nd': 8,
5456
# if true, looks at another frame to check for correct piece placement
55-
# reduces speed, improves robustness
57+
# reduces speed, increases robustness
5658
'confirm placement': True,
5759
'play safe': False, # ai is even more robust
5860
'play for survival': False, # if true, starts in 'cleaning' mode
@@ -75,7 +77,7 @@ def name_piece(piece: int) -> str:
7577

7678

7779
def configure_fast():
78-
CONFIG['starting choices for 2nd'] = 4
80+
CONFIG['starting choices for 2nd'] = 8
7981
CONFIG['confirm placement'] = False
8082
CONFIG['play for survival'] = True
8183
CONFIG['override delay'] = True

src/AI_main.py

Lines changed: 26 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import List
22

3+
from numba import jit
4+
35
from find_landings import all_landings
46
import numpy as np
57
from direct_keys import *
@@ -49,6 +51,7 @@ def clear_line(field):
4951
return field, full_cnt
5052

5153
@staticmethod
54+
@jit(nopython=True)
5255
def find_roofs(field: np.array) -> (int, int, np.array, int):
5356
"""
5457
finds blank squares under landed pieces
@@ -70,6 +73,7 @@ def find_roofs(field: np.array) -> (int, int, np.array, int):
7073
return blank_cnt, int(np.max(tops[:, 0])), tops[:, 0], blank_depth
7174

7275
@staticmethod
76+
@jit(nopython=True)
7377
def almost_full_line(field):
7478
score = 0
7579
for i in range(len(field)):
@@ -81,6 +85,7 @@ def almost_full_line(field):
8185
return score
8286

8387
@staticmethod
88+
@jit(nopython=True)
8489
def find_pit(field, tops):
8590
gap_idx = []
8691
for i in range(len(field)):
@@ -104,6 +109,7 @@ def find_pit(field, tops):
104109
return max_pit_h
105110

106111
@staticmethod
112+
@jit(nopython=True)
107113
def find_hole(tops):
108114
cnt_hole = 0
109115
tops = np.insert(tops, 0, 20)
@@ -223,6 +229,14 @@ def choose_action(self, field: np.array, piece_idx, can_hold) -> Position:
223229
return result
224230

225231
def choose_action_depth2(self, field: np.array, piece_idx: int, next_piece: int, can_hold: bool) -> Position:
232+
"""
233+
finds best action considering the next piece as well
234+
:param field:
235+
:param piece_idx:
236+
:param next_piece:
237+
:param can_hold:
238+
:return:
239+
"""
226240
if self.choices_for_2nd == 1:
227241
# can simplify
228242
return self.choose_action(field, piece_idx, can_hold)
@@ -249,14 +263,18 @@ def choose_action_depth2(self, field: np.array, piece_idx: int, next_piece: int,
249263
sub_score_hold = self.calc_best(results[i].field, piece_idx)[0].score
250264
results[i].next_score = max(sub_score, sub_score_hold)
251265

252-
# sort by total score, prioritizing tetris on current turn
253-
# (not another move, and then tetris)
254-
results.sort(key=lambda x: x.next_score + x.score + 1000 * x.expect_tetris, reverse=True)
255-
if results[0].piece == self.held_piece and self.held_piece != piece_idx:
266+
# take best by total score, prioritizing tetris on current turn
267+
# (instead of another move, and then tetris)
268+
optimal = max(results, key=lambda x: x.next_score + x.score + 1000 * x.expect_tetris)
269+
if optimal.piece == self.held_piece and self.held_piece != piece_idx:
256270
self.hold_piece(piece_idx)
257-
return results[0]
271+
return optimal
258272

259273
def place_piece(self, piece: int, rotation: int, x_pos: int, height: int, rot_now=0, x_pos_now=3, depth=0):
274+
"""
275+
puts the piece into correct position (before lowering)
276+
optionally verifies placement
277+
"""
260278
if depth == 3:
261279
if CONFIG['debug status'] >= 1:
262280
print('depth 3 reached in place_piece')
@@ -322,7 +340,7 @@ def runtime_tuning(self):
322340
2 - medium speed, no hard placing, turn on at level 6
323341
3 - for the late game, always scared
324342
control number of paths for the next piece:
325-
z, x, c - 1, 3, 5
343+
z, x, c - 1, 4, 8
326344
n - try to clean the field
327345
m - disable cleaning mode (try to get tetrises)
328346
@@ -342,9 +360,9 @@ def runtime_tuning(self):
342360
if keyboard.is_pressed('z'):
343361
self.choices_for_2nd = 1
344362
elif keyboard.is_pressed('x'):
345-
self.choices_for_2nd = 3
363+
self.choices_for_2nd = 4
346364
elif keyboard.is_pressed('c'):
347-
self.choices_for_2nd = 5
365+
self.choices_for_2nd = 8
348366

349367
if keyboard.is_pressed('n'):
350368
self.clearing = True

src/find_landings.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
from typing import List
33

44
import numpy as np
5+
from numba import jit
56

67
from position import Position
78
from figures import array_of_figures as pieces
89

910

11+
@jit(nopython=True)
1012
def check_collision(field, piece, piece_pos, piece_idx):
1113
r = 4
1214
if piece_idx != 0:
15+
# only the line takes 4x4 grid, so for other pieces only check 3x3
1316
r -= 1
1417
for i in range(r):
1518
for j in range(r):
@@ -20,18 +23,28 @@ def check_collision(field, piece, piece_pos, piece_idx):
2023
return False
2124

2225

23-
def land(field: np.array, piece: np.array, pos_now: List[int], piece_idx: int) -> np.array:
24-
res = deepcopy(field)
25-
while not check_collision(res, piece, pos_now, piece_idx):
26+
@jit(nopython=True)
27+
def land(field: np.array, piece: np.array, x_pos: int, piece_idx: int) -> np.array:
28+
"""
29+
simulates a piece falling from a set position onto the field
30+
helps predict the outcome of an action
31+
:param field: np.array, WILL BE MODIFIED
32+
:param piece: array that tells the shape of the piece (with rotation applied)
33+
:param x_pos: horizontal position of the piece
34+
:param piece_idx: piece index
35+
:return: resulting field
36+
"""
37+
pos_now = [0, x_pos]
38+
while not check_collision(field, piece, pos_now, piece_idx):
2639
pos_now[0] += 1
2740
if pos_now[0] == 0:
2841
return None
2942
pos_now[0] -= 1
3043
for i in range(4):
3144
for j in range(4):
3245
if i + pos_now[0] < len(field) and j + pos_now[1] < len(field[0]):
33-
res[i + pos_now[0]][j + pos_now[1]] += piece[i][j]
34-
return res
46+
field[i + pos_now[0]][j + pos_now[1]] += piece[i][j]
47+
return field
3548

3649

3750
def all_landings(field: np.array, piece_index: int) -> List[Position]:
@@ -44,7 +57,7 @@ def all_landings(field: np.array, piece_index: int) -> List[Position]:
4457
results = []
4558
for rotation in range(len(pieces[piece_index])):
4659
for x_pos in range(-3, 10):
47-
res = land(field, pieces[piece_index][rotation], [0, x_pos], piece_index)
60+
res = land(deepcopy(field), pieces[piece_index][rotation], x_pos, piece_index)
4861
if res is not None:
4962
results.append(Position(res, rotation, x_pos, piece_index))
5063
return results

src/main.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,12 @@ def main():
1111
can_hold_flag = True
1212
expected_rwd = 0
1313
ai = AI()
14-
placement = None
14+
position = None
15+
16+
# call jit-compiling functions to compile them
17+
ai.calc_best(np.zeros((20, 10), dtype=np.int), 0)
18+
field, _ = get_field()
19+
print("Compilation complete")
1520

1621
# infinite playing cycle
1722
while True:
@@ -30,14 +35,14 @@ def main():
3035
if ai.held_piece == -1:
3136
ai.hold_piece(piece_idx)
3237
can_hold_flag = False
33-
time.sleep(0.2)
3438
continue
39+
3540
# shenanigans for better parsing of the original game
3641
if CONFIG['game'] == 'original':
37-
if placement is not None and placement.expect_tetris:
42+
if position is not None and position.expect_tetris:
3843
# hoping that it was not a misclick, not taking a screenshot because TETRIS blocks the view
3944
field = np.zeros((3, 10), dtype=np.int)
40-
field = np.concatenate((field, ai.clear_line(placement.field)[0]))
45+
field = np.concatenate((field, ai.clear_line(position.field)[0]))
4146
time.sleep(0.2)
4247
elif not ai.scared:
4348
field, next_piece = get_field()
@@ -53,27 +58,29 @@ def main():
5358
if CONFIG['debug status'] >= 2:
5459
print(field)
5560
print(f'current score {actual_score}')
61+
62+
# next piece is not recognized
5663
if next_piece == -1:
5764
if CONFIG['debug status'] >= 1:
5865
print("unknown next")
5966
next_piece = 1 # assume square as it is the most neutral one
6067

6168
calc_start_time = time.time()
6269
# compute best outcome
63-
placement = ai.choose_action_depth2(field[3:], piece_idx, next_piece, can_hold_flag)
70+
position = ai.choose_action_depth2(field[3:], piece_idx, next_piece, can_hold_flag)
6471

6572
if CONFIG['debug status'] >= 1:
66-
# useful info
73+
# print useful info
6774
print('calculation took', time.time() - calc_start_time)
68-
print(f'chosen placement for {name_piece(placement.piece)}: '
69-
f'({placement.rotation}, {placement.x_pos}) with score {placement.score}')
70-
print(f'next figure {name_piece(next_piece)} should give {placement.next_score}')
71-
if placement.expect_tetris:
75+
print(f'chosen placement for {name_piece(position.piece)}: '
76+
f'({position.rotation}, {position.x_pos}) with score {position.score}')
77+
print(f'next figure {name_piece(next_piece)} should give {position.next_score}')
78+
if position.expect_tetris:
7279
print('expecting TETRIS')
7380

74-
expected_rwd = ai.get_score(ai.clear_line(placement.field)[0])[0]
81+
expected_rwd = ai.get_score(ai.clear_line(position.field)[0])[0]
7582
# emulate key presses to place the piece
76-
ai.place_piece(placement.piece, placement.rotation, placement.x_pos, ai.find_roofs(placement.field)[1])
83+
ai.place_piece(position.piece, position.rotation, position.x_pos, ai.find_roofs(position.field)[1])
7784
# wait for everything to settle down
7885
ai.place_piece_delay()
7986

@@ -83,5 +90,5 @@ def main():
8390

8491

8592
if __name__ == '__main__':
86-
time.sleep(2)
93+
time.sleep(1)
8794
main()

src/scan_field.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
import matplotlib.pyplot as plt
22
import numpy as np
33
from mss import mss
4+
from numba import jit
5+
46
from config import CONFIG
57

68
screen_capture = mss()
79
monitor = {"left": 0, "top": 0, "width": CONFIG['screen width'], "height": CONFIG['screen height']}
10+
piece_colors = CONFIG['piece colors']
811

912

1013
def print_image(arr, figure_size=10):
@@ -45,11 +48,13 @@ def simplified(pixels: np.array) -> np.array:
4548
return field
4649

4750

51+
@jit(nopython=True)
4852
def cmp_pixel(p1, p2):
4953
return abs(p1[0] - p2[0]) + abs(p1[1] - p2[1]) + abs(p1[2] - p2[2])
5054

5155

52-
def get_figure_by_color(screen):
56+
@jit(nopython=True)
57+
def get_figure_by_color(screen: np.array):
5358
"""
5459
finds next piece class based on color
5560
shape recognition is difficult because the piece is not aligned with a grid
@@ -59,8 +64,8 @@ def get_figure_by_color(screen):
5964
for i in range(len(screen)):
6065
for j in range(len(screen[0])):
6166
pixel = screen[i, j][:3]
62-
for piece_idx in range(len(CONFIG['piece colors'])):
63-
distance = cmp_pixel(CONFIG['piece colors'][piece_idx][::-1], pixel)
67+
for piece_idx in range(len(piece_colors)):
68+
distance = cmp_pixel(piece_colors[piece_idx][::-1], pixel)
6469
if distance < 10:
6570
return piece_idx
6671
return -1

0 commit comments

Comments
 (0)