Skip to content

Commit 06c05a2

Browse files
committed
add thread safety lock and tests
1 parent 086bc51 commit 06c05a2

File tree

2 files changed

+218
-39
lines changed

2 files changed

+218
-39
lines changed

py/selenium/webdriver/common/bidi/browsing_context.py

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
import threading
1819
from dataclasses import dataclass
1920
from typing import Any, Callable, Optional, Union
2021

@@ -538,6 +539,8 @@ def __init__(self, conn, event_configs: dict[str, EventConfig]):
538539
self.subscriptions: dict = {}
539540
self._bidi_to_class = {config.bidi_event: config.event_class for config in event_configs.values()}
540541
self._available_events = ", ".join(sorted(event_configs.keys()))
542+
# Thread safety lock for subscription operations
543+
self._subscription_lock = threading.Lock()
541544

542545
def validate_event(self, event: str) -> EventConfig:
543546
event_config = self.event_configs.get(event)
@@ -553,10 +556,11 @@ def subscribe_to_event(self, bidi_event: str, contexts: Optional[list[str]] = No
553556
bidi_event: The BiDi event name.
554557
contexts: Optional browsing context IDs to subscribe to.
555558
"""
556-
if bidi_event not in self.subscriptions:
557-
session = Session(self.conn)
558-
self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts))
559-
self.subscriptions[bidi_event] = []
559+
with self._subscription_lock:
560+
if bidi_event not in self.subscriptions:
561+
session = Session(self.conn)
562+
self.conn.execute(session.subscribe(bidi_event, browsing_contexts=contexts))
563+
self.subscriptions[bidi_event] = []
560564

561565
def unsubscribe_from_event(self, bidi_event: str) -> None:
562566
"""Unsubscribe from a BiDi event if no more callbacks exist.
@@ -565,19 +569,22 @@ def unsubscribe_from_event(self, bidi_event: str) -> None:
565569
----------
566570
bidi_event: The BiDi event name.
567571
"""
568-
callback_list = self.subscriptions.get(bidi_event)
569-
if callback_list is not None and not callback_list:
570-
session = Session(self.conn)
571-
self.conn.execute(session.unsubscribe(bidi_event))
572-
del self.subscriptions[bidi_event]
572+
with self._subscription_lock:
573+
callback_list = self.subscriptions.get(bidi_event)
574+
if callback_list is not None and not callback_list:
575+
session = Session(self.conn)
576+
self.conn.execute(session.unsubscribe(bidi_event))
577+
del self.subscriptions[bidi_event]
573578

574579
def add_callback_to_tracking(self, bidi_event: str, callback_id: int) -> None:
575-
self.subscriptions[bidi_event].append(callback_id)
580+
with self._subscription_lock:
581+
self.subscriptions[bidi_event].append(callback_id)
576582

577583
def remove_callback_from_tracking(self, bidi_event: str, callback_id: int) -> None:
578-
callback_list = self.subscriptions.get(bidi_event)
579-
if callback_list and callback_id in callback_list:
580-
callback_list.remove(callback_id)
584+
with self._subscription_lock:
585+
callback_list = self.subscriptions.get(bidi_event)
586+
if callback_list and callback_id in callback_list:
587+
callback_list.remove(callback_id)
581588

582589
def add_event_handler(self, event: str, callback: Callable, contexts: Optional[list[str]] = None) -> int:
583590
event_config = self.validate_event(event)
@@ -606,21 +613,22 @@ def remove_event_handler(self, event: str, callback_id: int) -> None:
606613

607614
def clear_event_handlers(self) -> None:
608615
"""Clear all event handlers from the browsing context."""
609-
if not self.subscriptions:
610-
return
616+
with self._subscription_lock:
617+
if not self.subscriptions:
618+
return
611619

612-
session = Session(self.conn)
620+
session = Session(self.conn)
613621

614-
for bidi_event, callback_ids in list(self.subscriptions.items()):
615-
event_class = self._bidi_to_class.get(bidi_event)
616-
if event_class:
617-
# Remove all callbacks for this event
618-
for callback_id in callback_ids:
619-
self.conn.remove_callback(event_class, callback_id)
622+
for bidi_event, callback_ids in list(self.subscriptions.items()):
623+
event_class = self._bidi_to_class.get(bidi_event)
624+
if event_class:
625+
# Remove all callbacks for this event
626+
for callback_id in callback_ids:
627+
self.conn.remove_callback(event_class, callback_id)
620628

621-
self.conn.execute(session.unsubscribe(bidi_event))
629+
self.conn.execute(session.unsubscribe(bidi_event))
622630

623-
self.subscriptions.clear()
631+
self.subscriptions.clear()
624632

625633

626634
class BrowsingContext:

py/test/selenium/webdriver/common/bidi_browsing_context_tests.py

Lines changed: 186 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -931,25 +931,196 @@ def on_context_created_2(info):
931931

932932

933933
def test_event_handler_thread_safety(driver):
934-
"""Test event handlers are thread-safe."""
934+
"""Test thread safety with multiple non-atomic operations in callbacks."""
935+
import concurrent.futures
936+
import time
937+
935938
events_received = []
936-
event_lock = threading.Lock()
939+
context_counts = {}
940+
event_type_counts = {}
941+
processing_times = []
942+
consistency_errors = []
943+
thread_errors = []
944+
945+
data_lock = threading.Lock()
946+
callback_ids = []
947+
registration_complete = threading.Event()
948+
949+
def complex_event_callback(info):
950+
"""Callback with multiple non-atomic operations that require thread synchronization."""
951+
start_time = time.time()
952+
time.sleep(0.02) # Create race condition window
953+
954+
with data_lock:
955+
# Multiple operations that could race without proper locking
956+
initial_event_count = len(events_received)
957+
_ = sum(context_counts.values()) if context_counts else 0
958+
_ = sum(event_type_counts.values()) if event_type_counts else 0
937959

938-
def on_context_created(info):
939-
with event_lock:
940960
events_received.append(info)
941961

942-
callback_id = driver.browsing_context.add_event_handler("context_created", on_context_created)
962+
context_id = info.context
963+
if context_id not in context_counts:
964+
context_counts[context_id] = 0
965+
context_counts[context_id] += 1
966+
967+
event_type = info.__class__.__name__
968+
if event_type not in event_type_counts:
969+
event_type_counts[event_type] = 0
970+
event_type_counts[event_type] += 1
971+
972+
processing_time = time.time() - start_time
973+
processing_times.append(processing_time)
974+
975+
# Verify data consistency
976+
final_event_count = len(events_received)
977+
final_context_total = sum(context_counts.values())
978+
final_type_total = sum(event_type_counts.values())
979+
final_processing_count = len(processing_times)
980+
981+
expected_count = initial_event_count + 1
982+
if not (
983+
final_event_count == final_context_total == final_type_total == final_processing_count == expected_count
984+
):
985+
error_msg = (
986+
f"Data consistency error! Events: {final_event_count}, "
987+
f"Contexts: {final_context_total}, Types: {final_type_total}, "
988+
f"Times: {final_processing_count}, Expected: {expected_count}"
989+
)
990+
consistency_errors.append(error_msg)
991+
992+
def register_handler(thread_id):
993+
try:
994+
callback_id = driver.browsing_context.add_event_handler("context_created", complex_event_callback)
995+
with data_lock:
996+
callback_ids.append(callback_id)
997+
if len(callback_ids) == 5:
998+
registration_complete.set()
999+
return callback_id
1000+
except Exception as e:
1001+
with data_lock:
1002+
thread_errors.append(f"Thread {thread_id}: Registration failed: {e}")
1003+
return None
1004+
1005+
def remove_handler(callback_id, thread_id):
1006+
try:
1007+
driver.browsing_context.remove_event_handler("context_created", callback_id)
1008+
except Exception as e:
1009+
with data_lock:
1010+
thread_errors.append(f"Thread {thread_id}: Removal failed: {e}")
1011+
1012+
initial_context = driver.browsing_context.create(type=WindowTypes.TAB)
1013+
1014+
# Concurrent registration
1015+
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
1016+
futures = {}
1017+
for i in range(5):
1018+
future = executor.submit(register_handler, f"reg-{i}")
1019+
futures[future] = f"reg-{i}"
1020+
1021+
for future in futures:
1022+
thread_id = futures[future]
1023+
try:
1024+
future.result(timeout=15)
1025+
except concurrent.futures.TimeoutError:
1026+
with data_lock:
1027+
thread_errors.append(f"Thread {thread_id}: Registration timed out")
1028+
except Exception as e:
1029+
with data_lock:
1030+
thread_errors.append(f"Thread {thread_id}: Registration exception: {e}")
1031+
1032+
registration_complete.wait(timeout=5)
1033+
1034+
with data_lock:
1035+
successful_registrations = len(callback_ids)
1036+
1037+
# Trigger events while handlers are active
1038+
if successful_registrations > 0:
1039+
test_contexts = []
1040+
for i in range(3):
1041+
try:
1042+
context = driver.browsing_context.create(type=WindowTypes.TAB)
1043+
test_contexts.append(context)
1044+
time.sleep(0.1)
1045+
except Exception as e:
1046+
thread_errors.append(f"Failed to create test context {i}: {e}")
1047+
1048+
time.sleep(1.0) # Allow event processing
1049+
1050+
for context in test_contexts:
1051+
try:
1052+
driver.browsing_context.close(context)
1053+
except Exception:
1054+
pass
1055+
1056+
# Concurrent removal
1057+
if callback_ids:
1058+
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
1059+
futures = {}
1060+
for i, callback_id in enumerate(callback_ids):
1061+
future = executor.submit(remove_handler, callback_id, f"rem-{i}")
1062+
futures[future] = f"rem-{i}"
1063+
1064+
for future in futures:
1065+
thread_id = futures[future]
1066+
try:
1067+
future.result(timeout=15)
1068+
except concurrent.futures.TimeoutError:
1069+
with data_lock:
1070+
thread_errors.append(f"Thread {thread_id}: Removal timed out")
1071+
except Exception as e:
1072+
with data_lock:
1073+
thread_errors.append(f"Thread {thread_id}: Removal exception: {e}")
1074+
1075+
time.sleep(0.5)
1076+
1077+
# Verify handlers are removed
1078+
with data_lock:
1079+
events_before_removal_test = len(events_received)
1080+
1081+
try:
1082+
post_removal_context = driver.browsing_context.create(type=WindowTypes.TAB)
1083+
time.sleep(0.8)
1084+
driver.browsing_context.close(post_removal_context)
1085+
except Exception as e:
1086+
thread_errors.append(f"Failed to create post-removal test context: {e}")
1087+
1088+
with data_lock:
1089+
events_after_removal = len(events_received) - events_before_removal_test
9431090

944-
# Create multiple contexts in rapid succession
945-
context_ids = []
946-
for i in range(3):
947-
context_id = driver.browsing_context.create(type=WindowTypes.TAB)
948-
context_ids.append(context_id)
1091+
# Cleanup
1092+
try:
1093+
driver.browsing_context.close(initial_context)
1094+
except Exception as e:
1095+
thread_errors.append(f"Cleanup error: {e}")
9491096

950-
# Verify all events were received (might be 1 more than 3 due to default context)
951-
assert len(events_received) >= 3
1097+
# Assertions
1098+
all_errors = thread_errors + consistency_errors
1099+
if all_errors:
1100+
pytest.fail("Thread safety test failed with errors:\n" + "\n".join(all_errors))
9521101

953-
for context_id in context_ids:
954-
driver.browsing_context.close(context_id)
955-
driver.browsing_context.remove_event_handler("context_created", callback_id)
1102+
assert successful_registrations > 0, f"No handlers were successfully registered (got {successful_registrations})"
1103+
assert len(events_received) > 0, "No events were received during test"
1104+
1105+
# Verify data consistency across multiple counters
1106+
with data_lock:
1107+
total_context_events = sum(context_counts.values()) if context_counts else 0
1108+
total_type_events = sum(event_type_counts.values()) if event_type_counts else 0
1109+
1110+
assert len(events_received) == total_context_events, (
1111+
f"Context count mismatch: {len(events_received)} vs {total_context_events}"
1112+
)
1113+
assert len(events_received) == total_type_events, (
1114+
f"Type count mismatch: {len(events_received)} vs {total_type_events}"
1115+
)
1116+
assert len(events_received) == len(processing_times), (
1117+
f"Processing time count mismatch: {len(events_received)} vs {len(processing_times)}"
1118+
)
1119+
1120+
# Verify handlers were properly removed
1121+
assert events_after_removal == 0, f"Handlers still active after removal! Got {events_after_removal} events"
1122+
1123+
# Verify event object
1124+
for i, event in enumerate(events_received):
1125+
assert hasattr(event, "context"), f"Event {i} missing 'context' attribute"
1126+
assert isinstance(event.context, str), f"Event {i} 'context' is not string: {type(event.context)}"

0 commit comments

Comments
 (0)