Skip to content

Commit e3e9652

Browse files
authored
Windows: refactored how internal handles are stored (#250)
1 parent ba25420 commit e3e9652

File tree

3 files changed

+87
-72
lines changed

3 files changed

+87
-72
lines changed

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ See Git checking messages for full history.
44

55
## 8.0.4 (2023/xx/xx)
66
- Linux: add failure handling to `XOpenDisplay()` call (#247)
7+
- Windows: refactored how internal handles are stored
8+
- Windows: removed side effects when leaving the context manager, resources are all freed
79
- CI: run tests via xvfb-run on GitHub Actions (#248)
810
- tests: Use `PyVirtualDisplay` instead of `xvfbwrapper` (#249)
9-
- :heart: contributors: @mgorny
11+
- :heart: contributors: @mgorny, @CTPaHHuK-HEbA
1012

1113
## 8.0.3 (2023/04/15)
1214
- added support for Python 3.12

src/mss/windows.py

Lines changed: 43 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
"""
55
import ctypes
66
import sys
7-
import threading
8-
from ctypes import POINTER, WINFUNCTYPE, Structure, c_void_p
7+
from ctypes import POINTER, WINFUNCTYPE, Structure, c_int, c_void_p
98
from ctypes.wintypes import (
109
BOOL,
1110
DOUBLE,
@@ -22,7 +21,8 @@
2221
UINT,
2322
WORD,
2423
)
25-
from typing import Any, Dict, Optional
24+
from threading import local
25+
from typing import Any, Optional
2626

2727
from .base import MSSBase
2828
from .exception import ScreenShotError
@@ -78,27 +78,22 @@ class BITMAPINFO(Structure):
7878
"BitBlt": ("gdi32", [HDC, INT, INT, INT, INT, HDC, INT, INT, DWORD], BOOL),
7979
"CreateCompatibleBitmap": ("gdi32", [HDC, INT, INT], HBITMAP),
8080
"CreateCompatibleDC": ("gdi32", [HDC], HDC),
81+
"DeleteDC": ("gdi32", [HDC], HDC),
8182
"DeleteObject": ("gdi32", [HGDIOBJ], INT),
8283
"EnumDisplayMonitors": ("user32", [HDC, c_void_p, MONITORNUMPROC, LPARAM], BOOL),
8384
"GetDeviceCaps": ("gdi32", [HWND, INT], INT),
8485
"GetDIBits": ("gdi32", [HDC, HBITMAP, UINT, UINT, c_void_p, POINTER(BITMAPINFO), UINT], BOOL),
8586
"GetSystemMetrics": ("user32", [INT], INT),
8687
"GetWindowDC": ("user32", [HWND], HDC),
88+
"ReleaseDC": ("user32", [HWND, HDC], c_int),
8789
"SelectObject": ("gdi32", [HDC, HGDIOBJ], HGDIOBJ),
8890
}
8991

9092

9193
class MSS(MSSBase):
9294
"""Multiple ScreenShots implementation for Microsoft Windows."""
9395

94-
__slots__ = {"_bbox", "_bmi", "_data", "gdi32", "user32"}
95-
96-
# Class attributes instanced one time to prevent resource leaks.
97-
bmp = None
98-
memdc = None
99-
100-
# A dict to maintain *srcdc* values created by multiple threads.
101-
_srcdc_dict: Dict[threading.Thread, int] = {}
96+
__slots__ = {"gdi32", "user32", "_handles"}
10297

10398
def __init__(self, /, **kwargs: Any) -> None:
10499
"""Windows initialisations."""
@@ -110,12 +105,12 @@ def __init__(self, /, **kwargs: Any) -> None:
110105
self._set_cfunctions()
111106
self._set_dpi_awareness()
112107

113-
self._bbox = {"height": 0, "width": 0}
114-
self._data: ctypes.Array[ctypes.c_char] = ctypes.create_string_buffer(0)
115-
116-
srcdc = self._get_srcdc()
117-
if not MSS.memdc:
118-
MSS.memdc = self.gdi32.CreateCompatibleDC(srcdc)
108+
# Available thread-specific variables
109+
self._handles = local()
110+
self._handles.region_height_width = (0, 0)
111+
self._handles.bmp = None
112+
self._handles.srcdc = self.user32.GetWindowDC(0)
113+
self._handles.memdc = self.gdi32.CreateCompatibleDC(self._handles.srcdc)
119114

120115
bmi = BITMAPINFO()
121116
bmi.bmiHeader.biSize = ctypes.sizeof(BITMAPINFOHEADER)
@@ -124,7 +119,21 @@ def __init__(self, /, **kwargs: Any) -> None:
124119
bmi.bmiHeader.biCompression = 0 # 0 = BI_RGB (no compression)
125120
bmi.bmiHeader.biClrUsed = 0 # See grab.__doc__ [3]
126121
bmi.bmiHeader.biClrImportant = 0 # See grab.__doc__ [3]
127-
self._bmi = bmi
122+
self._handles.bmi = bmi
123+
124+
def close(self) -> None:
125+
# Clean-up
126+
if self._handles.bmp:
127+
self.gdi32.DeleteObject(self._handles.bmp)
128+
self._handles.bmp = None
129+
130+
if self._handles.memdc:
131+
self.gdi32.DeleteDC(self._handles.memdc)
132+
self._handles.memdc = None
133+
134+
if self._handles.srcdc:
135+
self.user32.ReleaseDC(0, self._handles.srcdc)
136+
self._handles.srcdc = None
128137

129138
def _set_cfunctions(self) -> None:
130139
"""Set all ctypes functions and attach them to attributes."""
@@ -149,26 +158,9 @@ def _set_dpi_awareness(self) -> None:
149158
# These applications are not automatically scaled by the system.
150159
ctypes.windll.shcore.SetProcessDpiAwareness(2)
151160
elif (6, 0) <= version < (6, 3):
152-
# Windows Vista, 7, 8 and Server 2012
161+
# Windows Vista, 7, 8, and Server 2012
153162
self.user32.SetProcessDPIAware()
154163

155-
def _get_srcdc(self) -> int:
156-
"""
157-
Retrieve a thread-safe HDC from GetWindowDC().
158-
In multithreading, if the thread that creates *srcdc* is dead, *srcdc* will
159-
no longer be valid to grab the screen. The *srcdc* attribute is replaced
160-
with *_srcdc_dict* to maintain the *srcdc* values in multithreading.
161-
Since the current thread and main thread are always alive, reuse their *srcdc* value first.
162-
"""
163-
cur_thread, main_thread = threading.current_thread(), threading.main_thread()
164-
current_srcdc = MSS._srcdc_dict.get(cur_thread) or MSS._srcdc_dict.get(main_thread)
165-
if current_srcdc:
166-
srcdc = current_srcdc
167-
else:
168-
srcdc = self.user32.GetWindowDC(0)
169-
MSS._srcdc_dict[cur_thread] = srcdc
170-
return srcdc
171-
172164
def _monitors_impl(self) -> None:
173165
"""Get positions of monitors. It will populate self._monitors."""
174166

@@ -240,35 +232,26 @@ def _grab_impl(self, monitor: Monitor, /) -> ScreenShot:
240232
Thanks to http://stackoverflow.com/a/3688682
241233
"""
242234

243-
srcdc, memdc = self._get_srcdc(), MSS.memdc
235+
srcdc, memdc = self._handles.srcdc, self._handles.memdc
236+
gdi = self.gdi32
244237
width, height = monitor["width"], monitor["height"]
245238

246-
if (self._bbox["height"], self._bbox["width"]) != (height, width):
247-
self._bbox = monitor
248-
self._bmi.bmiHeader.biWidth = width
249-
self._bmi.bmiHeader.biHeight = -height # Why minus? [1]
250-
self._data = ctypes.create_string_buffer(width * height * 4) # [2]
251-
if MSS.bmp:
252-
self.gdi32.DeleteObject(MSS.bmp)
253-
MSS.bmp = self.gdi32.CreateCompatibleBitmap(srcdc, width, height)
254-
self.gdi32.SelectObject(memdc, MSS.bmp)
255-
256-
self.gdi32.BitBlt(
257-
memdc,
258-
0,
259-
0,
260-
width,
261-
height,
262-
srcdc,
263-
monitor["left"],
264-
monitor["top"],
265-
SRCCOPY | CAPTUREBLT,
266-
)
267-
bits = self.gdi32.GetDIBits(memdc, MSS.bmp, 0, height, self._data, self._bmi, DIB_RGB_COLORS)
239+
if self._handles.region_height_width != (height, width):
240+
self._handles.region_height_width = (height, width)
241+
self._handles.bmi.bmiHeader.biWidth = width
242+
self._handles.bmi.bmiHeader.biHeight = -height # Why minus? [1]
243+
self._handles.data = ctypes.create_string_buffer(width * height * 4) # [2]
244+
if self._handles.bmp:
245+
gdi.DeleteObject(self._handles.bmp)
246+
self._handles.bmp = gdi.CreateCompatibleBitmap(srcdc, width, height)
247+
gdi.SelectObject(memdc, self._handles.bmp)
248+
249+
gdi.BitBlt(memdc, 0, 0, width, height, srcdc, monitor["left"], monitor["top"], SRCCOPY | CAPTUREBLT)
250+
bits = gdi.GetDIBits(memdc, self._handles.bmp, 0, height, self._handles.data, self._handles.bmi, DIB_RGB_COLORS)
268251
if bits != height:
269252
raise ScreenShotError("gdi32.GetDIBits() failed.")
270253

271-
return self.cls_image(bytearray(self._data), monitor)
254+
return self.cls_image(bytearray(self._handles.data), monitor)
272255

273256
def _cursor_impl(self) -> Optional[ScreenShot]:
274257
"""Retrieve all cursor data. Pixels have to be RGB."""

src/tests/test_windows.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,34 +24,46 @@ def test_implementation(monkeypatch):
2424

2525
def test_region_caching():
2626
"""The region to grab is cached, ensure this is well-done."""
27-
from mss.windows import MSS
28-
2927
with mss.mss() as sct:
30-
# Reset the current BMP
31-
if MSS.bmp:
32-
sct.gdi32.DeleteObject(MSS.bmp)
33-
MSS.bmp = None
34-
3528
# Grab the area 1
3629
region1 = {"top": 0, "left": 0, "width": 200, "height": 200}
3730
sct.grab(region1)
38-
bmp1 = id(MSS.bmp)
31+
bmp1 = id(sct._handles.bmp)
3932

4033
# Grab the area 2, the cached BMP is used
4134
# Same sizes but different positions
4235
region2 = {"top": 200, "left": 200, "width": 200, "height": 200}
4336
sct.grab(region2)
44-
bmp2 = id(MSS.bmp)
37+
bmp2 = id(sct._handles.bmp)
4538
assert bmp1 == bmp2
4639

4740
# Grab the area 2 again, the cached BMP is used
4841
sct.grab(region2)
49-
assert bmp2 == id(MSS.bmp)
42+
assert bmp2 == id(sct._handles.bmp)
43+
44+
45+
def test_region_not_caching():
46+
"""The region to grab is not bad cached previous grab."""
47+
grab1 = mss.mss()
48+
grab2 = mss.mss()
49+
50+
region1 = {"top": 0, "left": 0, "width": 100, "height": 100}
51+
region2 = {"top": 0, "left": 0, "width": 50, "height": 1}
52+
grab1.grab(region1)
53+
bmp1 = id(grab1._handles.bmp)
54+
grab2.grab(region2)
55+
bmp2 = id(grab2._handles.bmp)
56+
assert bmp1 != bmp2
57+
58+
# Grab the area 1, is not bad cached BMP previous grab the area 2
59+
grab1.grab(region1)
60+
bmp1 = id(grab1._handles.bmp)
61+
assert bmp1 != bmp2
5062

5163

5264
def run_child_thread(loops):
5365
for _ in range(loops):
54-
with mss.mss() as sct:
66+
with mss.mss() as sct: # New sct for every loop
5567
sct.grab(sct.monitors[1])
5668

5769

@@ -66,3 +78,21 @@ def test_thread_safety():
6678
thread2.start()
6779
thread1.join()
6880
thread2.join()
81+
82+
83+
def run_child_thread_bbox(loops, bbox):
84+
with mss.mss() as sct: # One sct for all loops
85+
for _ in range(loops):
86+
sct.grab(bbox)
87+
88+
89+
def test_thread_safety_regions():
90+
"""Thread safety test for different regions
91+
The following code will throw a ScreenShotError exception if thread-safety is not guaranted.
92+
"""
93+
thread1 = threading.Thread(target=run_child_thread_bbox, args=(100, (0, 0, 100, 100)))
94+
thread2 = threading.Thread(target=run_child_thread_bbox, args=(100, (0, 0, 50, 1)))
95+
thread1.start()
96+
thread2.start()
97+
thread1.join()
98+
thread2.join()

0 commit comments

Comments
 (0)