Skip to content

Commit 49f1868

Browse files
committed
use psycopg SQL module for correct quoting
Resolves a potentional SQL injection issue with the prefix parameter.
1 parent 797ceaa commit 49f1868

File tree

1 file changed

+19
-19
lines changed

1 file changed

+19
-19
lines changed

scripts/osm2pgsql-replication

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,11 @@ missing_modules = []
3434

3535
try:
3636
import psycopg2 as psycopg
37+
from psycopg2 import sql
3738
except ImportError:
3839
try:
3940
import psycopg
41+
from psycopg import sql
4042
except ImportError:
4143
missing_modules.append('psycopg2')
4244

@@ -90,7 +92,8 @@ def compute_database_date(conn, prefix):
9092
# First, find the way with the highest ID in the database
9193
# Using nodes would be more reliable but those are not cached by osm2pgsql.
9294
with conn.cursor() as cur:
93-
cur.execute("SELECT max(id) FROM {}_ways".format(prefix))
95+
table = sql.Identifier(f'{prefix}_ways')
96+
cur.execute(sql.SQL("SELECT max(id) FROM {}").format(table))
9497
osmid = cur.fetchone()[0] if cur.rowcount == 1 else None
9598

9699
if osmid is None:
@@ -127,13 +130,13 @@ def setup_replication_state(conn, table, base_url, seq, date):
127130
the given state.
128131
"""
129132
with conn.cursor() as cur:
130-
cur.execute('DROP TABLE IF EXISTS "{}"'.format(table))
131-
cur.execute("""CREATE TABLE "{}"
132-
(url TEXT,
133-
sequence INTEGER,
134-
importdate TIMESTAMP WITH TIME ZONE)
135-
""".format(table))
136-
cur.execute('INSERT INTO "{}" VALUES(%s, %s, %s)'.format(table),
133+
cur.execute(sql.SQL('DROP TABLE IF EXISTS {}').format(table))
134+
cur.execute(sql.SQL("""CREATE TABLE {}
135+
(url TEXT,
136+
sequence INTEGER,
137+
importdate TIMESTAMP WITH TIME ZONE)
138+
""").format(table))
139+
cur.execute(sql.SQL('INSERT INTO {} VALUES(%s, %s, %s)').format(table),
137140
(base_url, seq, date))
138141
conn.commit()
139142

@@ -144,10 +147,10 @@ def update_replication_state(conn, table, seq, date):
144147
"""
145148
with conn.cursor() as cur:
146149
if date is not None:
147-
cur.execute('UPDATE "{}" SET sequence=%s, importdate=%s'.format(table),
150+
cur.execute(sql.SQL('UPDATE {} SET sequence=%s, importdate=%s').format(table),
148151
(seq, date))
149152
else:
150-
cur.execute('UPDATE "{}" SET sequence=%s'.format(table),
153+
cur.execute(sql.SQL('UPDATE {} SET sequence=%s').format(table),
151154
(seq,))
152155

153156
conn.commit()
@@ -197,12 +200,12 @@ def status(conn, args):
197200
results = {}
198201

199202
with conn.cursor() as cur:
200-
cur.execute('SELECT * FROM pg_tables where tablename = %s', (args.table, ))
203+
cur.execute('SELECT * FROM pg_tables where tablename = %s', (args.table_name, ))
201204
if cur.rowcount < 1:
202205
results['status'] = 1
203206
results['error'] = "Cannot find replication status table. Run 'osm2pgsql-replication init' first."
204207
else:
205-
cur.execute('SELECT * FROM "{}"'.format(args.table))
208+
cur.execute(sql.SQL('SELECT * FROM {}').format(args.table))
206209
if cur.rowcount != 1:
207210
results['status'] = 2
208211
results['error'] = "Updates not set up correctly. Run 'osm2pgsql-updates init' first."
@@ -342,13 +345,13 @@ def update(conn, args):
342345
after the updates have been downloaded.
343346
"""
344347
with conn.cursor() as cur:
345-
cur.execute('SELECT * FROM pg_tables where tablename = %s', (args.table, ))
348+
cur.execute('SELECT * FROM pg_tables where tablename = %s', (args.table_name, ))
346349
if cur.rowcount < 1:
347350
LOG.fatal("Cannot find replication status table. "
348351
"Run 'osm2pgsql-replication init' first.")
349352
return 1
350353

351-
cur.execute('SELECT * FROM "{}"'.format(args.table))
354+
cur.execute(sql.SQL('SELECT * FROM {}').format(args.table))
352355
if cur.rowcount != 1:
353356
LOG.fatal("Updates not set up correctly. Run 'osm2pgsql-updates init' first.")
354357
return 1
@@ -518,11 +521,8 @@ def main():
518521
datefmt='%Y-%m-%d %H:%M:%S',
519522
level=max(4 - args.verbose, 1) * 10)
520523

521-
if '"' in args.prefix:
522-
LOG.fatal("Prefix must not contain quotation marks.")
523-
return 1
524-
525-
args.table = '{}_replication_status'.format(args.prefix)
524+
args.table_name = f'{args.prefix}_replication_status'
525+
args.table = sql.Identifier(args.table_name)
526526

527527
conn = connect(args)
528528
ret = args.handler(conn, args)

0 commit comments

Comments
 (0)