Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 127 additions & 79 deletions lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
import subprocess
import signal
import sys
import threading
import warnings
import selectors
import time
from typing import (
Any,
Expand Down Expand Up @@ -139,6 +139,35 @@ def dump_memory(base_addr, data, num_per_line, outfile):
outfile.write("\n")


def read_packet(
f: IO[bytes], trace_file: Optional[IO[str]] = None
) -> Optional[ProtocolMessage]:
"""Decode a JSON packet that starts with the content length and is
followed by the JSON bytes from a file 'f'. Returns None on EOF.
"""
line = f.readline().decode("utf-8")
if len(line) == 0:
return None # EOF.

# Watch for line that starts with the prefix
prefix = "Content-Length: "
if line.startswith(prefix):
# Decode length of JSON bytes
length = int(line[len(prefix) :])
# Skip empty line
separator = f.readline().decode()
if separator != "":
Exception("malformed DAP content header, unexpected line: " + separator)
# Read JSON bytes
json_str = f.read(length).decode()
if trace_file:
trace_file.write("from adapter:\n%s\n" % (json_str))
# Decode the JSON bytes into a python dictionary
return json.loads(json_str)

raise Exception("unexpected malformed message from lldb-dap: " + line)


def packet_type_is(packet, packet_type):
return "type" in packet and packet["type"] == packet_type

Expand Down Expand Up @@ -170,8 +199,16 @@ def __init__(
self.log_file = log_file
self.send = send
self.recv = recv
self.selector = selectors.DefaultSelector()
self.selector.register(recv, selectors.EVENT_READ)

# Packets that have been received and processed but have not yet been
# requested by a test case.
self._pending_packets: List[Optional[ProtocolMessage]] = []
# Received packets that have not yet been processed.
self._recv_packets: List[Optional[ProtocolMessage]] = []
# Used as a mutex for _recv_packets and for notify when _recv_packets
# changes.
self._recv_condition = threading.Condition()
self._recv_thread = threading.Thread(target=self._read_packet_thread)

# session state
self.init_commands = init_commands
Expand All @@ -197,6 +234,9 @@ def __init__(
# keyed by breakpoint id
self.resolved_breakpoints: dict[str, Breakpoint] = {}

# trigger enqueue thread
self._recv_thread.start()

@classmethod
def encode_content(cls, s: str) -> bytes:
return ("Content-Length: %u\r\n\r\n%s" % (len(s), s)).encode("utf-8")
Expand All @@ -212,46 +252,17 @@ def validate_response(cls, command, response):
f"seq mismatch in response {command['seq']} != {response['request_seq']}"
)

def _read_packet(
self,
timeout: float = DEFAULT_TIMEOUT,
) -> Optional[ProtocolMessage]:
"""Decode a JSON packet that starts with the content length and is
followed by the JSON bytes from self.recv. Returns None on EOF.
"""

ready = self.selector.select(timeout)
if not ready:
warnings.warn(
"timeout occurred waiting for a packet, check if the test has a"
" negative assertion and see if it can be inverted.",
stacklevel=4,
)
return None # timeout

line = self.recv.readline().decode("utf-8")
if len(line) == 0:
return None # EOF.

# Watch for line that starts with the prefix
prefix = "Content-Length: "
if line.startswith(prefix):
# Decode length of JSON bytes
length = int(line[len(prefix) :])
# Skip empty line
separator = self.recv.readline().decode()
if separator != "":
Exception("malformed DAP content header, unexpected line: " + separator)
# Read JSON bytes
json_str = self.recv.read(length).decode()
if self.trace_file:
self.trace_file.write(
"%s from adapter:\n%s\n" % (time.time(), json_str)
)
# Decode the JSON bytes into a python dictionary
return json.loads(json_str)

raise Exception("unexpected malformed message from lldb-dap: " + line)
def _read_packet_thread(self):
try:
while True:
packet = read_packet(self.recv, trace_file=self.trace_file)
# `packet` will be `None` on EOF. We want to pass it down to
# handle_recv_packet anyway so the main thread can handle unexpected
# termination of lldb-dap and stop waiting for new packets.
if not self._handle_recv_packet(packet):
break
finally:
dump_dap_log(self.log_file)

def get_modules(
self, start_module: Optional[int] = None, module_count: Optional[int] = None
Expand Down Expand Up @@ -299,6 +310,34 @@ def collect_output(
output += self.get_output(category, clear=clear)
return output

def _enqueue_recv_packet(self, packet: Optional[ProtocolMessage]):
with self.recv_condition:
self.recv_packets.append(packet)
self.recv_condition.notify()

def _handle_recv_packet(self, packet: Optional[ProtocolMessage]) -> bool:
"""Handles an incoming packet.

Called by the read thread that is waiting for all incoming packets
to store the incoming packet in "self._recv_packets" in a thread safe
way. This function will then signal the "self._recv_condition" to
indicate a new packet is available.

Args:
packet: A new packet to store.

Returns:
True if the caller should keep calling this function for more
packets.
"""
with self._recv_condition:
self._recv_packets.append(packet)
self._recv_condition.notify()
# packet is None on EOF
return packet is not None and not (
packet["type"] == "response" and packet["command"] == "disconnect"
)

def _recv_packet(
self,
*,
Expand All @@ -322,34 +361,46 @@ def _recv_packet(
The first matching packet for the given predicate, if specified,
otherwise None.
"""
deadline = time.time() + timeout

while time.time() < deadline:
packet = self._read_packet(timeout=deadline - time.time())
if packet is None:
return None
self._process_recv_packet(packet)
if not predicate or predicate(packet):
return packet

def _process_recv_packet(self, packet) -> None:
assert (
threading.current_thread != self._recv_thread
), "Must not be called from the _recv_thread"

def process_until_match():
self._process_recv_packets()
for i, packet in enumerate(self._pending_packets):
if packet is None:
# We need to return a truthy value to break out of the
# wait_for, use `EOFError` as an indicator of EOF.
return EOFError()
if predicate and predicate(packet):
self._pending_packets.pop(i)
return packet

with self._recv_condition:
packet = self._recv_condition.wait_for(process_until_match, timeout)
return None if isinstance(packet, EOFError) else packet

def _process_recv_packets(self) -> None:
"""Process received packets, updating the session state."""
if packet and ("seq" not in packet or packet["seq"] == 0):
warnings.warn(
f"received a malformed packet, expected 'seq != 0' for {packet!r}"
)
# Handle events that may modify any stateful properties of
# the DAP session.
if packet and packet["type"] == "event":
self._handle_event(packet)
elif packet and packet["type"] == "request":
# Handle reverse requests and keep processing.
self._handle_reverse_request(packet)
with self._recv_condition:
for packet in self._recv_packets:
if packet and ("seq" not in packet or packet["seq"] == 0):
warnings.warn(
f"received a malformed packet, expected 'seq != 0' for {packet!r}"
)
# Handle events that may modify any stateful properties of
# the DAP session.
if packet and packet["type"] == "event":
self._handle_event(packet)
elif packet and packet["type"] == "request":
# Handle reverse requests and keep processing.
self._handle_reverse_request(packet)
# Move the packet to the pending queue.
self._pending_packets.append(packet)
self._recv_packets.clear()

def _handle_event(self, packet: Event) -> None:
"""Handle any events that modify debug session state we track."""
self.events.append(packet)

event = packet["event"]
body: Optional[Dict] = packet.get("body", None)

Expand Down Expand Up @@ -402,8 +453,6 @@ def _handle_event(self, packet: Event) -> None:
self.invalidated_event = packet
elif event == "memory":
self.memory_event = packet
elif event == "module":
self.module_events.append(packet)

def _handle_reverse_request(self, request: Request) -> None:
if request in self.reverse_requests:
Expand Down Expand Up @@ -472,14 +521,18 @@ def send_packet(self, packet: ProtocolMessage) -> int:

Returns the seq number of the request.
"""
packet["seq"] = self.sequence
self.sequence += 1
# Set the seq for requests.
if packet["type"] == "request":
packet["seq"] = self.sequence
self.sequence += 1
else:
packet["seq"] = 0

# Encode our command dictionary as a JSON string
json_str = json.dumps(packet, separators=(",", ":"))

if self.trace_file:
self.trace_file.write("%s to adapter:\n%s\n" % (time.time(), json_str))
self.trace_file.write("to adapter:\n%s\n" % (json_str))

length = len(json_str)
if length > 0:
Expand Down Expand Up @@ -860,8 +913,6 @@ def request_restart(self, restartArguments=None):
if restartArguments:
command_dict["arguments"] = restartArguments

# Clear state, the process is about to restart...
self._process_continued(True)
response = self._send_recv(command_dict)
# Caller must still call wait_for_stopped.
return response
Expand Down Expand Up @@ -1428,10 +1479,8 @@ def request_testGetTargetBreakpoints(self):

def terminate(self):
self.send.close()
self.recv.close()
self.selector.close()
if self.log_file:
dump_dap_log(self.log_file)
if self._recv_thread.is_alive():
self._recv_thread.join()

def request_setInstructionBreakpoints(self, memory_reference=[]):
breakpoints = []
Expand Down Expand Up @@ -1528,7 +1577,6 @@ def launch(
stdout=subprocess.PIPE,
stderr=sys.stderr,
env=adapter_env,
bufsize=0,
)

if connection is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
# DAP tests as a whole have been flakey on the Windows on Arm bot. See:
# https://github.com/llvm/llvm-project/issues/137660
@skipIf(oslist=["windows"], archs=["aarch64"])
# The Arm Linux bot needs stable resources before it can run these tests reliably.
@skipIf(oslist=["linux"], archs=["arm$"])
class DAPTestCaseBase(TestBase):
# set timeout based on whether ASAN was enabled or not. Increase
# timeout by a factor of 10 if ASAN is enabled.
Expand Down Expand Up @@ -418,7 +416,7 @@ def continue_to_next_stop(self):
return self.dap_server.wait_for_stopped()

def continue_to_breakpoint(self, breakpoint_id: str):
self.continue_to_breakpoints([breakpoint_id])
self.continue_to_breakpoints((breakpoint_id))

def continue_to_breakpoints(self, breakpoint_ids):
self.do_continue()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,20 +81,24 @@ def test_breakpoint_events(self):
breakpoint["verified"], "expect foo breakpoint to not be verified"
)

# Flush the breakpoint events.
self.dap_server.wait_for_breakpoint_events()

# Continue to the breakpoint
self.continue_to_breakpoint(foo_bp_id)
self.continue_to_next_stop() # foo_bp2
self.continue_to_breakpoint(main_bp_id)
self.continue_to_exit()
self.continue_to_breakpoints(dap_breakpoint_ids)

bp_events = [e for e in self.dap_server.events if e["event"] == "breakpoint"]
verified_breakpoint_ids = []
unverified_breakpoint_ids = []
for breakpoint_event in self.dap_server.wait_for_breakpoint_events():
breakpoint = breakpoint_event["body"]["breakpoint"]
id = breakpoint["id"]
if breakpoint["verified"]:
verified_breakpoint_ids.append(id)
else:
unverified_breakpoint_ids.append(id)

main_bp_events = [
e for e in bp_events if e["body"]["breakpoint"]["id"] == main_bp_id
]
foo_bp_events = [
e for e in bp_events if e["body"]["breakpoint"]["id"] == foo_bp_id
]
self.assertIn(main_bp_id, unverified_breakpoint_ids)
self.assertIn(foo_bp_id, unverified_breakpoint_ids)

self.assertTrue(main_bp_events)
self.assertTrue(foo_bp_events)
self.assertIn(main_bp_id, verified_breakpoint_ids)
self.assertIn(foo_bp_id, verified_breakpoint_ids)
2 changes: 1 addition & 1 deletion lldb/test/API/tools/lldb-dap/launch/TestDAP_launch.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,6 @@ def test_debuggerRoot(self):
self.build_and_launch(
program, debuggerRoot=program_parent_dir, initCommands=commands
)
self.continue_to_exit()
output = self.get_console()
self.assertTrue(output and len(output) > 0, "expect console output")
lines = output.splitlines()
Expand All @@ -172,6 +171,7 @@ def test_debuggerRoot(self):
% (program_parent_dir, line[len(prefix) :]),
)
self.assertTrue(found, "verified lldb-dap working directory")
self.continue_to_exit()

def test_sourcePath(self):
"""
Expand Down
Loading
Loading