Skip to content

Commit b1229ff

Browse files
authored
Merge pull request #842 from jluebbe/export
support exporting resource/driver information
2 parents e3180cb + 193ebaf commit b1229ff

File tree

11 files changed

+296
-2
lines changed

11 files changed

+296
-2
lines changed

labgrid/binding.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,22 @@ def wrapper(self, *_args, **_kwargs):
9797

9898
return wrapper
9999

100+
@classmethod
101+
def check_bound(cls, func):
102+
@wraps(func)
103+
def wrapper(self, *_args, **_kwargs):
104+
if self.state is BindingState.active:
105+
raise StateError(
106+
f'{self} is active, but must be deactivated to call {func.__qualname__}' # pylint: disable=line-too-long
107+
)
108+
elif self.state is not BindingState.bound:
109+
raise StateError(
110+
f'{self} has not been bound, {func.__qualname__} cannot be called in state "{self.state.name}"' # pylint: disable=line-too-long
111+
)
112+
return func(self, *_args, **_kwargs)
113+
114+
return wrapper
115+
100116
class NamedBinding:
101117
"""
102118
Marks a binding (or binding set) as requiring an explicit name.

labgrid/driver/common.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,19 @@ def get_priority(self, protocol):
4545

4646
return 0
4747

48+
def get_export_name(self):
49+
"""Get the name to be used for exported variables.
50+
51+
Falls back to the class name if the driver has no name.
52+
"""
53+
if self.name:
54+
return self.name
55+
return self.__class__.__name__
56+
57+
def get_export_vars(self):
58+
"""Get a dictionary of variables to be exported."""
59+
return {}
60+
4861

4962
def check_file(filename, *, command_prefix=[]):
5063
if subprocess.call(command_prefix + ['test', '-r', filename]) != 0:

labgrid/driver/networkinterfacedriver.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@ def on_deactivate(self):
3535
self.wrapper = None
3636
self.proxy = None
3737

38+
@Driver.check_bound
39+
def get_export_vars(self):
40+
return {
41+
"host": self.iface.host,
42+
"ifname": self.iface.ifname or "",
43+
}
44+
3845
# basic
3946
@Driver.check_active
4047
@step()

labgrid/driver/provider.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@ class BaseProviderDriver(Driver):
1414
def __attrs_post_init__(self):
1515
super().__attrs_post_init__()
1616

17+
@Driver.check_bound
18+
def get_export_vars(self):
19+
return {
20+
"host": self.provider.host,
21+
"internal": self.provider.internal,
22+
"external": self.provider.external,
23+
}
24+
1725
@Driver.check_active
1826
@step(args=['filename'], result=True)
1927
def stage(self, filename):

labgrid/driver/serialdriver.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,20 @@ def on_activate(self):
6969
def on_deactivate(self):
7070
self.close()
7171

72+
@Driver.check_bound
73+
def get_export_vars(self):
74+
vars = {
75+
"speed": str(self.port.speed)
76+
}
77+
if isinstance(self.port, SerialPort):
78+
vars["port"] = self.port.port
79+
else:
80+
host, port = proxymanager.get_host_and_port(self.port)
81+
vars["host"] = host
82+
vars["port"] = str(port)
83+
vars["protocol"] = self.port.protocol
84+
return vars
85+
7286
def _read(self, size: int = 1, timeout: float = 0.0):
7387
"""
7488
Reads 'size' or more bytes from the serialport

labgrid/exceptions.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,11 @@ class NoResourceFoundError(NoSupplierFoundError):
3030
pass
3131

3232

33+
@attr.s(eq=False)
34+
class NoStrategyFoundError(NoSupplierFoundError):
35+
pass
36+
37+
3338
@attr.s(eq=False)
3439
class RegistrationError(Exception):
3540
msg = attr.ib(validator=attr.validators.instance_of(str))

labgrid/remote/client.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22
coordinator, acquire a place and interact with the connected resources"""
33
import argparse
44
import asyncio
5+
import atexit
56
import contextlib
7+
import enum
68
import os
79
import subprocess
810
import traceback
911
import logging
1012
import signal
1113
import sys
14+
import shlex
15+
import json
1216
from textwrap import indent
1317
from socket import gethostname
1418
from getpass import getuser
@@ -27,6 +31,7 @@
2731
from .. import Target, target_factory
2832
from ..util.proxy import proxymanager
2933
from ..util.helper import processwrapper
34+
from ..util import atomic_replace
3035
from ..driver import Mode
3136

3237
txaio.use_asyncio()
@@ -1327,6 +1332,34 @@ async def print_reservations(self):
13271332
print(f"Reservation '{res.token}':")
13281333
res.show(level=1)
13291334

1335+
async def export(self, place, target):
1336+
exported = target.export()
1337+
exported["LG__CLIENT_PID"] = str(os.getpid())
1338+
if self.args.format is ExportFormat.SHELL:
1339+
lines = []
1340+
for k, v in sorted(exported.items()):
1341+
lines.append(f"{k}={shlex.quote(v)}")
1342+
data = "\n".join(lines)
1343+
elif self.args.format is ExportFormat.SHELL_EXPORT:
1344+
lines = []
1345+
for k, v in sorted(exported.items()):
1346+
lines.append(f"export {k}={shlex.quote(v)}")
1347+
data = "\n".join(lines)+"\n"
1348+
elif self.args.format is ExportFormat.JSON:
1349+
data = json.dumps(exported)
1350+
if self.args.filename == "-":
1351+
sys.stdout.write(data)
1352+
else:
1353+
atomic_replace(self.args.filename, data.encode())
1354+
print(f"Exported to {self.args.filename}", file=sys.stderr)
1355+
try:
1356+
print("Waiting for CTRL+C or SIGTERM...", file=sys.stderr)
1357+
while True:
1358+
await asyncio.sleep(1.0)
1359+
except GeneratorExit:
1360+
print("Exiting...\n", file=sys.stderr)
1361+
export.needs_target = True
1362+
13301363

13311364
def start_session(url, realm, extra):
13321365
from autobahn.asyncio.wamp import ApplicationRunner
@@ -1421,6 +1454,16 @@ def __call__(self, parser, namespace, value, option_string):
14211454
v.append((local, remote))
14221455
setattr(namespace, self.dest, v)
14231456

1457+
1458+
class ExportFormat(enum.Enum):
1459+
SHELL = "shell"
1460+
SHELL_EXPORT = "shell-export"
1461+
JSON = "json"
1462+
1463+
def __str__(self):
1464+
return self.value
1465+
1466+
14241467
def main():
14251468
processwrapper.enable_logging()
14261469
logging.basicConfig(
@@ -1756,6 +1799,13 @@ def main():
17561799
subparser = subparsers.add_parser('reservations', help="list current reservations")
17571800
subparser.set_defaults(func=ClientSession.print_reservations)
17581801

1802+
subparser = subparsers.add_parser('export', help="export driver information to a file (needs environment with drivers)")
1803+
subparser.add_argument('--format', dest='format',
1804+
type=ExportFormat, choices=ExportFormat, default=ExportFormat.SHELL_EXPORT,
1805+
help="output format (default: %(default)s)")
1806+
subparser.add_argument('filename', help='output filename')
1807+
subparser.set_defaults(func=ClientSession.export)
1808+
17591809
# make any leftover arguments available for some commands
17601810
args, leftover = parser.parse_known_args()
17611811
if args.command not in ['ssh', 'rsync', 'forward']:
@@ -1806,6 +1856,8 @@ def main():
18061856
if args.command and args.command != 'help':
18071857
exitcode = 0
18081858
try:
1859+
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
1860+
18091861
session = start_session(args.crossbar, os.environ.get("LG_CROSSBAR_REALM", "realm1"),
18101862
extra)
18111863
try:

labgrid/strategy/common.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,10 @@ def transition(self, status):
4848

4949
def force(self, status):
5050
raise NotImplementedError(f"Strategy.force() is not implemented for {self.__class__.__name__}")
51+
52+
def prepare_export(self):
53+
"""By default, export all drivers bound by the strategy."""
54+
name_map = {}
55+
for name in self.bindings.keys():
56+
name_map[getattr(self, name)] = name
57+
return name_map

labgrid/target.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from .binding import BindingError, BindingState
99
from .driver import Driver
10-
from .exceptions import NoSupplierFoundError, NoDriverFoundError, NoResourceFoundError
10+
from .exceptions import NoSupplierFoundError, NoDriverFoundError, NoResourceFoundError, NoStrategyFoundError
1111
from .resource import Resource
1212
from .strategy import Strategy
1313
from .util import Timeout
@@ -212,6 +212,24 @@ def get_driver(self, cls, *, name=None, activate=True):
212212
"""
213213
return self._get_driver(cls, name=name, activate=activate)
214214

215+
def get_strategy(self):
216+
"""
217+
Helper function to get the strategy of the target.
218+
219+
Returns the Strategy, if exactly one exists and raises a
220+
NoStrategyFoundError otherwise.
221+
"""
222+
found = []
223+
for drv in self.drivers:
224+
if not isinstance(drv, Strategy):
225+
continue
226+
found.append(drv)
227+
if not found:
228+
raise NoStrategyFoundError(f"no Strategy found in {self}")
229+
elif len(found) > 1:
230+
raise NoStrategyFoundError(f"multiple Strategies found in {self}")
231+
return found[0]
232+
215233
def __getitem__(self, key):
216234
"""
217235
Syntactic sugar to access drivers by class (optionally filtered by
@@ -469,6 +487,39 @@ def _atexit_cleanup(self):
469487
"method on targets yourself to handle exceptions explictly.")
470488
print(f"Error: {e}")
471489

490+
def export(self):
491+
"""
492+
Export information from drivers.
493+
494+
All drivers are deactivated before being exported.
495+
496+
The Strategy can decide for which driver the export method is called and
497+
with which name. Otherwise, all drivers are exported.
498+
"""
499+
try:
500+
name_map = self.get_strategy().prepare_export()
501+
selection = set(name_map.keys())
502+
except NoStrategyFoundError:
503+
name_map = {}
504+
selection = set(driver for driver in self.drivers if not isinstance(driver, Strategy))
505+
506+
assert len(name_map) == len(set(name_map.values())), "duplicate export name"
507+
508+
# drivers need to be deactivated for export to avoid conflicts
509+
self.deactivate_all_drivers()
510+
511+
export_vars = {}
512+
for driver in selection:
513+
name = name_map.get(driver)
514+
if not name:
515+
name = driver.get_export_name()
516+
for k, v in driver.get_export_vars().items():
517+
assert isinstance(k, str), f"key {k} from {driver} is not a string"
518+
assert isinstance(v, str), f"value {v} for key {k} from {driver} is not a string"
519+
export_vars[f"LG__{name}_{k}".upper()] = v
520+
return export_vars
521+
522+
472523
def cleanup(self):
473524
"""Clean up conntected drivers and resources in reversed order"""
474525
self.deactivate_all_drivers()

tests/test_export.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
import pytest
2+
3+
from labgrid.resource import Resource, NetworkSerialPort
4+
from labgrid.resource.remote import RemoteNetworkInterface, RemoteTFTPProvider
5+
from labgrid.driver import Driver, SerialDriver, NetworkInterfaceDriver, TFTPProviderDriver
6+
from labgrid.strategy import Strategy
7+
from labgrid.binding import StateError
8+
9+
10+
class ResourceA(Resource):
11+
pass
12+
13+
14+
class DriverA(Driver):
15+
bindings = {"res": ResourceA}
16+
17+
@Driver.check_bound
18+
def get_export_vars(self):
19+
return {
20+
"a": "b",
21+
}
22+
23+
24+
class StrategyA(Strategy):
25+
bindings = {
26+
"drv": DriverA,
27+
}
28+
29+
30+
def test_export(target):
31+
ra = ResourceA(target, "resource")
32+
d = DriverA(target, "driver")
33+
s = StrategyA(target, "strategy")
34+
35+
exported = target.export()
36+
assert exported == {
37+
"LG__DRV_A": "b",
38+
}
39+
40+
target.activate(d)
41+
with pytest.raises(StateError):
42+
d.get_export_vars()
43+
44+
45+
class StrategyB(Strategy):
46+
bindings = {
47+
"drv": DriverA,
48+
}
49+
50+
def prepare_export(self):
51+
return {
52+
self.drv: "custom_name",
53+
}
54+
55+
56+
def test_export_custom(target):
57+
ra = ResourceA(target, "resource")
58+
d = DriverA(target, "driver")
59+
s = StrategyB(target, "strategy")
60+
61+
exported = target.export()
62+
assert exported == {
63+
"LG__CUSTOM_NAME_A": "b",
64+
}
65+
66+
67+
def test_export_network_serial(target):
68+
NetworkSerialPort(target, None, host='testhost', port=12345, speed=115200)
69+
SerialDriver(target, None)
70+
71+
exported = target.export()
72+
assert exported == {
73+
'LG__SERIALDRIVER_HOST': 'testhost',
74+
'LG__SERIALDRIVER_PORT': '12345',
75+
'LG__SERIALDRIVER_PROTOCOL': 'rfc2217',
76+
'LG__SERIALDRIVER_SPEED': '115200'
77+
}
78+
79+
80+
def test_export_remote_network_interface(target):
81+
RemoteNetworkInterface(target, None, host='testhost', ifname='wlan0')
82+
NetworkInterfaceDriver(target, "netif")
83+
84+
exported = target.export()
85+
assert exported == {
86+
'LG__NETIF_HOST': 'testhost',
87+
'LG__NETIF_IFNAME': 'wlan0'
88+
}
89+
90+
91+
def test_export_remote_tftp_provider(target):
92+
RemoteTFTPProvider(target, None, host='testhost', internal='/srv/tftp/testboard/', external='testboard/')
93+
TFTPProviderDriver(target, "tftp")
94+
95+
exported = target.export()
96+
assert exported == {
97+
'LG__TFTP_HOST': 'testhost',
98+
'LG__TFTP_INTERNAL': '/srv/tftp/testboard/',
99+
'LG__TFTP_EXTERNAL': 'testboard/',
100+
}

0 commit comments

Comments
 (0)