Skip to content

Commit 9985d70

Browse files
committed
run osm2pgsql-replication in BDD tests as a module
This gives us the possibility to monkeypatch parts of the code.
1 parent deae7e0 commit 9985d70

File tree

3 files changed

+28
-14
lines changed

3 files changed

+28
-14
lines changed

scripts/osm2pgsql-replication

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -656,9 +656,12 @@ def get_parser():
656656
return parser
657657

658658

659-
def main():
659+
def main(prog_args=None):
660660
parser = get_parser()
661-
args = parser.parse_args()
661+
try:
662+
args = parser.parse_args(args=prog_args)
663+
except SystemExit:
664+
return 1
662665

663666
if missing_modules:
664667
LOG.fatal("Missing required Python libraries %(mods)s.\n\n"

tests/bdd/environment.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from pathlib import Path
99
import subprocess
1010
import tempfile
11+
import importlib.util
12+
from importlib.machinery import SourceFileLoader
1113

1214
from behave import *
1315
import psycopg2
@@ -89,6 +91,14 @@ def before_all(context):
8991
context.test_data_dir = Path(context.config.userdata['TEST_DATA_DIR']).resolve()
9092
context.default_data_dir = Path(context.config.userdata['SRC_DIR']).resolve()
9193

94+
# Set up replication script.
95+
replicationfile = str(Path(context.config.userdata['REPLICATION_SCRIPT']).resolve())
96+
spec = importlib.util.spec_from_loader('osm2pgsql_replication',
97+
SourceFileLoader( 'osm2pgsql_replication',replicationfile))
98+
assert spec, f"File not found: {replicationfile}"
99+
context.osm2pgsql_replication = importlib.util.module_from_spec(spec)
100+
spec.loader.exec_module(context.osm2pgsql_replication)
101+
92102

93103
def before_scenario(context, scenario):
94104
""" Set up a fresh, empty test database.

tests/bdd/steps/steps_execute.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@
99
"""
1010
from io import StringIO
1111
from pathlib import Path
12-
import os
12+
import sys
1313
import subprocess
14+
import contextlib
15+
import logging
1416

1517
def get_import_file(context):
1618
if context.import_file is not None:
@@ -76,7 +78,7 @@ def run_osm2pgsql(context, output):
7678

7779

7880
def run_osm2pgsql_replication(context):
79-
cmdline = [str(Path(context.config.userdata['REPLICATION_SCRIPT']).resolve())]
81+
cmdline = []
8082
# convert table items to CLI arguments and inject constants to placeholders
8183
if context.table:
8284
cmdline.extend(f.format(**context.config.userdata) for f in context.table.headings if f)
@@ -86,19 +88,18 @@ def run_osm2pgsql_replication(context):
8688
if '-d' not in cmdline and '--database' not in cmdline:
8789
cmdline.extend(('-d', context.config.userdata['TEST_DB']))
8890

89-
# on Windows execute script directly with python, because shebang is not recognised
90-
if os.name == 'nt':
91-
cmdline.insert(0, "python")
9291

93-
proc = subprocess.Popen(cmdline, cwd=str(context.workdir),
94-
stdin=subprocess.PIPE,
95-
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
92+
serr = StringIO()
93+
log_handler = logging.StreamHandler(serr)
94+
context.osm2pgsql_replication.LOG.addHandler(log_handler)
95+
with contextlib.redirect_stdout(StringIO()) as sout:
96+
retval = context.osm2pgsql_replication.main(cmdline)
97+
context.osm2pgsql_replication.LOG.removeHandler(log_handler)
9698

97-
outdata = proc.communicate()
99+
context.osm2pgsql_outdata = [sout.getvalue(), serr.getvalue()]
100+
print(context.osm2pgsql_outdata)
98101

99-
context.osm2pgsql_outdata = [d.decode('utf-8').replace('\\n', '\n') for d in outdata]
100-
101-
return proc.returncode
102+
return retval
102103

103104

104105
@given("no lua tagtransform")

0 commit comments

Comments
 (0)