Skip to content

Commit fba6a54

Browse files
committed
Adjusting callsites of wait_for_event and making Source into a typed dict instead of its own class.
1 parent e692c5f commit fba6a54

File tree

9 files changed

+106
-102
lines changed

9 files changed

+106
-102
lines changed

lldb/packages/Python/lldbsuite/test/tools/lldb-dap/dap_server.py

Lines changed: 65 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,44 @@ class Response(TypedDict):
5757
ProtocolMessage = Union[Event, Request, Response]
5858

5959

60+
class Source(TypedDict, total=False):
61+
name: str
62+
path: str
63+
sourceReference: int
64+
65+
@staticmethod
66+
def build(
67+
*,
68+
name: Optional[str] = None,
69+
path: Optional[str] = None,
70+
source_reference: Optional[int] = None,
71+
) -> "Source":
72+
"""Builds a source from the given name, path or source_reference."""
73+
if not name and not path and not source_reference:
74+
raise ValueError(
75+
"Source.build requires either name, path, or source_reference"
76+
)
77+
78+
s = Source()
79+
if name:
80+
s["name"] = name
81+
if path:
82+
if not name:
83+
s["name"] = os.path.basename(path)
84+
s["path"] = path
85+
if source_reference is not None:
86+
s["sourceReference"] = source_reference
87+
return s
88+
89+
6090
class Breakpoint(TypedDict, total=False):
6191
id: int
6292
verified: bool
93+
source: Source
94+
95+
@staticmethod
96+
def is_verified(src: "Breakpoint") -> bool:
97+
return src.get("verified", False)
6398

6499

65100
def dump_memory(base_addr, data, num_per_line, outfile):
@@ -142,58 +177,6 @@ def dump_dap_log(log_file: Optional[str]) -> None:
142177
print("========= END =========", file=sys.stderr)
143178

144179

145-
class Source(object):
146-
def __init__(
147-
self,
148-
path: Optional[str] = None,
149-
source_reference: Optional[int] = None,
150-
raw_dict: Optional[dict[str, Any]] = None,
151-
):
152-
self._name = None
153-
self._path = None
154-
self._source_reference = None
155-
self._raw_dict = None
156-
157-
if path is not None:
158-
self._name = os.path.basename(path)
159-
self._path = path
160-
elif source_reference is not None:
161-
self._source_reference = source_reference
162-
elif raw_dict is not None:
163-
self._raw_dict = raw_dict
164-
else:
165-
raise ValueError("Either path or source_reference must be provided")
166-
167-
def __str__(self):
168-
return f"Source(name={self.name}, path={self.path}), source_reference={self.source_reference})"
169-
170-
def as_dict(self):
171-
if self._raw_dict is not None:
172-
return self._raw_dict
173-
174-
source_dict = {}
175-
if self._name is not None:
176-
source_dict["name"] = self._name
177-
if self._path is not None:
178-
source_dict["path"] = self._path
179-
if self._source_reference is not None:
180-
source_dict["sourceReference"] = self._source_reference
181-
return source_dict
182-
183-
184-
class Breakpoint(object):
185-
def __init__(self, obj):
186-
self._breakpoint = obj
187-
188-
def is_verified(self):
189-
"""Check if the breakpoint is verified."""
190-
return self._breakpoint.get("verified", False)
191-
192-
def source(self):
193-
"""Get the source of the breakpoint."""
194-
return self._breakpoint.get("source", {})
195-
196-
197180
class NotSupportedError(KeyError):
198181
"""Raised if a feature is not supported due to its capabilities."""
199182

@@ -225,7 +208,7 @@ def __init__(
225208
# session state
226209
self.init_commands = init_commands
227210
self.exit_status: Optional[int] = None
228-
self.capabilities: Optional[Dict] = None
211+
self.capabilities: Dict = {}
229212
self.initialized: bool = False
230213
self.configuration_done_sent: bool = False
231214
self.process_event_body: Optional[Dict] = None
@@ -455,8 +438,6 @@ def _handle_event(self, packet: Event) -> None:
455438
# Breakpoint events are sent when a breakpoint is resolved
456439
self._update_verified_breakpoints([body["breakpoint"]])
457440
elif event == "capabilities" and body:
458-
if self.capabilities is None:
459-
self.capabilities = {}
460441
# Update the capabilities with new ones from the event.
461442
self.capabilities.update(body["capabilities"])
462443

@@ -467,13 +448,14 @@ def _handle_reverse_request(self, request: Request) -> None:
467448
arguments = request.get("arguments")
468449
if request["command"] == "runInTerminal" and arguments is not None:
469450
in_shell = arguments.get("argsCanBeInterpretedByShell", False)
451+
print("spawning...", arguments["args"])
470452
proc = subprocess.Popen(
471453
arguments["args"],
472454
env=arguments.get("env", {}),
473455
cwd=arguments.get("cwd", None),
474456
stdin=subprocess.DEVNULL,
475-
stdout=subprocess.DEVNULL,
476-
stderr=subprocess.DEVNULL,
457+
stdout=sys.stderr,
458+
stderr=sys.stderr,
477459
shell=in_shell,
478460
)
479461
body = {}
@@ -488,7 +470,6 @@ def _handle_reverse_request(self, request: Request) -> None:
488470
"request_seq": request["seq"],
489471
"success": True,
490472
"command": "runInTerminal",
491-
"message": None,
492473
"body": body,
493474
}
494475
)
@@ -520,9 +501,7 @@ def _update_verified_breakpoints(self, breakpoints: list[Breakpoint]):
520501
if "id" not in bp:
521502
continue
522503

523-
self.resolved_breakpoints[str(breakpoint["id"])] = Breakpoint(
524-
breakpoint
525-
)
504+
self.resolved_breakpoints[str(bp["id"])] = bp
526505

527506
def send_packet(self, packet: ProtocolMessage) -> int:
528507
"""Takes a dictionary representation of a DAP request and send the request to the debug adapter.
@@ -563,7 +542,7 @@ def _send_recv(self, request: Request) -> Optional[Response]:
563542
return response
564543

565544
def receive_response(self, seq: int) -> Optional[Response]:
566-
"""Waits for the a response with the associated request_sec."""
545+
"""Waits for a response with the associated request_sec."""
567546

568547
def predicate(p: ProtocolMessage):
569548
return p["type"] == "response" and p["request_seq"] == seq
@@ -605,7 +584,7 @@ def wait_for_stopped(
605584
def wait_for_breakpoint_events(self, timeout: Optional[float] = None):
606585
breakpoint_events: list[Event] = []
607586
while True:
608-
event = self.wait_for_event("breakpoint", timeout=timeout)
587+
event = self.wait_for_event(["breakpoint"], timeout=timeout)
609588
if not event:
610589
break
611590
breakpoint_events.append(event)
@@ -616,7 +595,7 @@ def wait_for_breakpoints_to_be_verified(
616595
):
617596
"""Wait for all breakpoints to be verified. Return all unverified breakpoints."""
618597
while any(id not in self.resolved_breakpoints for id in breakpoint_ids):
619-
breakpoint_event = self.wait_for_event("breakpoint", timeout=timeout)
598+
breakpoint_event = self.wait_for_event(["breakpoint"], timeout=timeout)
620599
if breakpoint_event is None:
621600
break
622601

@@ -625,18 +604,18 @@ def wait_for_breakpoints_to_be_verified(
625604
for id in breakpoint_ids
626605
if (
627606
id not in self.resolved_breakpoints
628-
or not self.resolved_breakpoints[id].is_verified()
607+
or not Breakpoint.is_verified(self.resolved_breakpoints[id])
629608
)
630609
]
631610

632611
def wait_for_exited(self, timeout: Optional[float] = None):
633-
event_dict = self.wait_for_event("exited", timeout=timeout)
612+
event_dict = self.wait_for_event(["exited"], timeout=timeout)
634613
if event_dict is None:
635614
raise ValueError("didn't get exited event")
636615
return event_dict
637616

638617
def wait_for_terminated(self, timeout: Optional[float] = None):
639-
event_dict = self.wait_for_event("terminated", timeout)
618+
event_dict = self.wait_for_event(["terminated"], timeout)
640619
if event_dict is None:
641620
raise ValueError("didn't get terminated event")
642621
return event_dict
@@ -990,7 +969,7 @@ def request_writeMemory(self, memoryReference, data, offset=0, allowPartial=Fals
990969
"type": "request",
991970
"arguments": args_dict,
992971
}
993-
return self.send_recv(command_dict)
972+
return self._send_recv(command_dict)
994973

995974
def request_evaluate(self, expression, frameIndex=0, threadId=None, context=None):
996975
stackFrame = self.get_stackFrame(frameIndex=frameIndex, threadId=threadId)
@@ -1041,7 +1020,7 @@ def request_initialize(self, sourceInitFile=False):
10411020
response = self._send_recv(command_dict)
10421021
if response:
10431022
if "body" in response:
1044-
self.capabilities = response["body"]
1023+
self.capabilities.update(response.get("body", {}))
10451024
return response
10461025

10471026
def request_launch(
@@ -1182,7 +1161,7 @@ def request_setBreakpoints(self, source: Source, line_array, data=None):
11821161
It contains optional location/hitCondition/logMessage parameters.
11831162
"""
11841163
args_dict = {
1185-
"source": source.as_dict(),
1164+
"source": source,
11861165
"sourceModified": False,
11871166
}
11881167
if line_array is not None:
@@ -1311,9 +1290,9 @@ def request_modules(
13111290
):
13121291
args_dict = {}
13131292

1314-
if start_module:
1293+
if start_module is not None:
13151294
args_dict["startModule"] = start_module
1316-
if module_count:
1295+
if module_count is not None:
13171296
args_dict["moduleCount"] = module_count
13181297

13191298
return self._send_recv(
@@ -1352,13 +1331,25 @@ def request_stackTrace(
13521331
print("[%3u] %s" % (idx, name))
13531332
return response
13541333

1355-
def request_source(self, sourceReference):
1334+
def request_source(
1335+
self, *, source: Optional[Source] = None, sourceReference: Optional[int] = None
1336+
):
13561337
"""Request a source from a 'Source' reference."""
1338+
if (
1339+
source is None
1340+
and sourceReference is None
1341+
or (source is not None and sourceReference is not None)
1342+
):
1343+
raise ValueError("request_source requires either source or sourceReference")
1344+
elif source:
1345+
sourceReference = source["sourceReference"]
1346+
elif sourceReference:
1347+
source = {"sourceReference": sourceReference}
13571348
command_dict = {
13581349
"command": "source",
13591350
"type": "request",
13601351
"arguments": {
1361-
"source": {"sourceReference": sourceReference},
1352+
"source": source,
13621353
# legacy version of the request
13631354
"sourceReference": sourceReference,
13641355
},

lldb/packages/Python/lldbsuite/test/tools/lldb-dap/lldbdap_testcase.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ def set_source_breakpoints_assembly(
6767
self, source_reference, lines, data=None, wait_for_resolve=True
6868
):
6969
return self.set_source_breakpoints_from_source(
70-
Source(source_reference=source_reference), lines, data, wait_for_resolve
70+
Source.build(source_reference=source_reference),
71+
lines,
72+
data,
73+
wait_for_resolve,
7174
)
7275

7376
def set_source_breakpoints_from_source(

lldb/test/API/tools/lldb-dap/breakpoint-assembly/TestDAP_breakpointAssembly.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Test lldb-dap setBreakpoints request in assembly source references.
33
"""
44

5-
65
from lldbsuite.test.decorators import *
76
from dap_server import Source
87
import lldbdap_testcase
@@ -52,7 +51,7 @@ def test_break_on_invalid_source_reference(self):
5251

5352
# Verify that setting a breakpoint on an invalid source reference fails
5453
response = self.dap_server.request_setBreakpoints(
55-
Source(source_reference=-1), [1]
54+
Source.build(source_reference=-1), [1]
5655
)
5756
self.assertIsNotNone(response)
5857
breakpoints = response["body"]["breakpoints"]
@@ -69,7 +68,7 @@ def test_break_on_invalid_source_reference(self):
6968

7069
# Verify that setting a breakpoint on a source reference that is not created fails
7170
response = self.dap_server.request_setBreakpoints(
72-
Source(source_reference=200), [1]
71+
Source.build(source_reference=200), [1]
7372
)
7473
self.assertIsNotNone(response)
7574
breakpoints = response["body"]["breakpoints"]
@@ -116,7 +115,7 @@ def test_persistent_assembly_breakpoint(self):
116115

117116
persistent_breakpoint_source = self.dap_server.resolved_breakpoints[
118117
persistent_breakpoint_ids[0]
119-
].source()
118+
]["source"]
120119
self.assertIn(
121120
"adapterData",
122121
persistent_breakpoint_source,
@@ -139,7 +138,7 @@ def test_persistent_assembly_breakpoint(self):
139138
self.dap_server.request_initialize()
140139
self.dap_server.request_launch(program)
141140
new_session_breakpoints_ids = self.set_source_breakpoints_from_source(
142-
Source(raw_dict=persistent_breakpoint_source),
141+
Source(persistent_breakpoint_source),
143142
[persistent_breakpoint_line],
144143
)
145144

lldb/test/API/tools/lldb-dap/breakpoint-events/TestDAP_breakpointEvents.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def test_breakpoint_events(self):
5858
# Set breakpoints and verify that they got set correctly
5959
dap_breakpoint_ids = []
6060
response = self.dap_server.request_setBreakpoints(
61-
Source(main_source_path), [main_bp_line]
61+
Source.build(path=main_source_path), [main_bp_line]
6262
)
6363
self.assertTrue(response["success"])
6464
breakpoints = response["body"]["breakpoints"]
@@ -70,7 +70,7 @@ def test_breakpoint_events(self):
7070
)
7171

7272
response = self.dap_server.request_setBreakpoints(
73-
Source(foo_source_path), [foo_bp1_line]
73+
Source.build(path=foo_source_path), [foo_bp1_line]
7474
)
7575
self.assertTrue(response["success"])
7676
breakpoints = response["body"]["breakpoints"]

0 commit comments

Comments
 (0)