Skip to content

Commit 17f330a

Browse files
authored
Merge pull request #380 from informatics-lab/fix-sqlite3-datetime
separate argument sanitization from query generation
2 parents c298063 + 8a9ec52 commit 17f330a

File tree

3 files changed

+166
-17
lines changed

3 files changed

+166
-17
lines changed

forest/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
.. automodule:: forest.presets
2424
2525
"""
26-
__version__ = '0.17.2'
26+
__version__ = '0.17.3'
2727

2828
from .config import *
2929
from . import (

forest/db/database.py

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
pass
66
import netCDF4
77
import jinja2
8+
import numpy as np
9+
import pandas as pd
810
from .connection import Connection
911

1012

@@ -373,10 +375,22 @@ def insert_pressure(self, path, variable, pressure, i):
373375

374376
def valid_times(self, pattern, variable, initial_time):
375377
"""Valid times associated with search criteria"""
378+
initial_time = self.sanitize_time(initial_time)
379+
query = self.valid_times_query(pattern, variable, initial_time)
380+
self.cursor.execute(query, dict(
381+
variable=variable,
382+
pattern=pattern,
383+
initial_time=initial_time))
384+
rows = self.cursor.fetchall()
385+
return [time for time, in rows]
386+
387+
@staticmethod
388+
def valid_times_query(pattern, variable, initial_time):
389+
"""Valid times SQL query syntax"""
376390
# Note: SQL injection possible if not properly escaped
377391
# use ? and :name syntax in template
378392
environment = jinja2.Environment(extensions=['jinja2.ext.do'])
379-
query = environment.from_string("""
393+
return environment.from_string("""
380394
{% set EQNS = [] %}
381395
{% if initial_time is not none %}
382396
{% do EQNS.append('file.reference = :initial_time') %}
@@ -402,19 +416,37 @@ def valid_times(self, pattern, variable, initial_time):
402416
initial_time=initial_time,
403417
variable=variable,
404418
pattern=pattern)
419+
420+
@staticmethod
421+
def sanitize_time(value):
422+
"""Query-compatible equivalent of value"""
423+
fmt = "%Y-%m-%d %H:%M:%S"
424+
if value is None:
425+
return value
426+
elif isinstance(value, str):
427+
return value
428+
elif isinstance(value, np.datetime64):
429+
return pd.to_datetime(str(value)).strftime(fmt)
430+
else:
431+
return value.strftime(fmt)
432+
433+
def pressures(self, pattern=None, variable=None, initial_time=None):
434+
"""Select pressures from database"""
435+
initial_time = self.sanitize_time(initial_time)
436+
query = self.pressures_query(pattern, variable, initial_time)
405437
self.cursor.execute(query, dict(
406438
variable=variable,
407439
pattern=pattern,
408440
initial_time=initial_time))
409441
rows = self.cursor.fetchall()
410442
return [time for time, in rows]
411443

412-
def pressures(self, pattern=None, variable=None, initial_time=None):
413-
"""Select pressures from database"""
444+
@staticmethod
445+
def pressures_query(pattern, variable, initial_time):
414446
# Note: SQL injection possible if not properly escaped
415447
# use ? and :name syntax in template
416448
environment = jinja2.Environment(extensions=['jinja2.ext.do'])
417-
query = environment.from_string("""
449+
return environment.from_string("""
418450
{% set EQNS = [] %}
419451
{% if variable is not none %}
420452
{% do EQNS.append('v.name = :variable') %}
@@ -445,12 +477,6 @@ def pressures(self, pattern=None, variable=None, initial_time=None):
445477
variable=variable,
446478
pattern=pattern,
447479
initial_time=initial_time)
448-
self.cursor.execute(query, dict(
449-
variable=variable,
450-
pattern=pattern,
451-
initial_time=initial_time))
452-
rows = self.cursor.fetchall()
453-
return [time for time, in rows]
454480

455481
def fetch_times(self, path, variable):
456482
"""Helper method to find times related to a variable"""

test/test_db_database.py

Lines changed: 129 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
11
from unittest.mock import Mock, sentinel
2+
import pytest
3+
import datetime as dt
4+
import cftime
5+
import numpy as np
26
import re
37

48
import forest.db.database as database
@@ -18,12 +22,21 @@ def _assert_query_and_params(db, expected_query, expected_params):
1822
db.cursor.execute.assert_called_once()
1923
args, kwargs = db.cursor.execute.call_args
2024
query, params = args
21-
query = re.sub(r'\s+', ' ', query).strip()
22-
assert query == expected_query
25+
assert_query_equal(query, expected_query)
2326
assert params == expected_params
2427
assert kwargs == {}
2528

2629

30+
def assert_query_equal(left, right):
31+
left, right = single_spaced(left), single_spaced(right)
32+
assert left == right
33+
34+
35+
def single_spaced(query):
36+
query = query.replace("\n", "")
37+
return re.sub(r'\s+', ' ', query).strip()
38+
39+
2740
def test_Database_valid_times__defaults():
2841
db = _create_db()
2942

@@ -39,7 +52,7 @@ def test_Database_valid_times__all_args():
3952
db = _create_db()
4053

4154
valid_times = db.valid_times(sentinel.pattern, sentinel.variable,
42-
sentinel.initial_time)
55+
dt.datetime(2020, 1, 1))
4356

4457
_assert_query_and_params(
4558
db, 'SELECT time.value FROM time'
@@ -49,10 +62,94 @@ def test_Database_valid_times__all_args():
4962
' WHERE file.reference = :initial_time'
5063
' AND file.name GLOB :pattern AND v.name = :variable',
5164
{'pattern': sentinel.pattern, 'variable': sentinel.variable,
52-
'initial_time':sentinel.initial_time})
65+
'initial_time': "2020-01-01 00:00:00"})
5366
assert valid_times == [sentinel.value1, sentinel.value2]
5467

5568

69+
@pytest.mark.parametrize("pattern, variable, initial_time, expect", [
70+
(None, None, None, """
71+
SELECT time.value FROM time
72+
"""),
73+
(sentinel.pattern, None, None, """
74+
SELECT time.value FROM time
75+
JOIN variable_to_time AS vt ON vt.time_id = time.id
76+
JOIN variable AS v ON vt.variable_id = v.id
77+
JOIN file ON v.file_id = file.id
78+
WHERE file.name GLOB :pattern
79+
"""),
80+
(None, sentinel.variable, None, """
81+
SELECT time.value FROM time
82+
JOIN variable_to_time AS vt ON vt.time_id = time.id
83+
JOIN variable AS v ON vt.variable_id = v.id
84+
JOIN file ON v.file_id = file.id
85+
WHERE v.name = :variable
86+
"""),
87+
(sentinel.pattern, sentinel.variable, None, """
88+
SELECT time.value FROM time
89+
JOIN variable_to_time AS vt ON vt.time_id = time.id
90+
JOIN variable AS v ON vt.variable_id = v.id
91+
JOIN file ON v.file_id = file.id
92+
WHERE file.name GLOB :pattern AND v.name = :variable
93+
"""),
94+
(sentinel.pattern, sentinel.variable, sentinel.initial_time, """
95+
SELECT time.value FROM time
96+
JOIN variable_to_time AS vt ON vt.time_id = time.id
97+
JOIN variable AS v ON vt.variable_id = v.id
98+
JOIN file ON v.file_id = file.id
99+
WHERE file.reference = :initial_time
100+
AND file.name GLOB :pattern AND v.name = :variable
101+
"""),
102+
])
103+
def test_valid_times_query(pattern, variable, initial_time, expect):
104+
result = database.Database.valid_times_query(pattern, variable, initial_time)
105+
assert_query_equal(expect, result)
106+
107+
108+
@pytest.mark.parametrize("pattern, variable, initial_time, expect", [
109+
(None, None, None, """
110+
SELECT DISTINCT value FROM pressure
111+
ORDER BY value
112+
"""),
113+
(sentinel.pattern, None, None, """
114+
SELECT DISTINCT pressure.value FROM pressure
115+
JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id
116+
JOIN variable AS v ON v.id = vp.variable_id
117+
JOIN file ON v.file_id = file.id
118+
WHERE file.name GLOB :pattern
119+
ORDER BY value
120+
"""),
121+
(None, sentinel.variable, None, """
122+
SELECT DISTINCT pressure.value FROM pressure
123+
JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id
124+
JOIN variable AS v ON v.id = vp.variable_id
125+
JOIN file ON v.file_id = file.id
126+
WHERE v.name = :variable
127+
ORDER BY value
128+
"""),
129+
(sentinel.pattern, sentinel.variable, None, """
130+
SELECT DISTINCT pressure.value FROM pressure
131+
JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id
132+
JOIN variable AS v ON v.id = vp.variable_id
133+
JOIN file ON v.file_id = file.id
134+
WHERE v.name = :variable AND file.name GLOB :pattern
135+
ORDER BY value
136+
"""),
137+
(sentinel.pattern, sentinel.variable, sentinel.initial_time, """
138+
SELECT DISTINCT pressure.value FROM pressure
139+
JOIN variable_to_pressure AS vp ON vp.pressure_id = pressure.id
140+
JOIN variable AS v ON v.id = vp.variable_id
141+
JOIN file ON v.file_id = file.id
142+
WHERE v.name = :variable
143+
AND file.name GLOB :pattern
144+
AND file.reference = :initial_time
145+
ORDER BY value
146+
"""),
147+
])
148+
def test_pressures_query(pattern, variable, initial_time, expect):
149+
result = database.Database.pressures_query(pattern, variable, initial_time)
150+
assert_query_equal(expect, result)
151+
152+
56153
def test_Database_pressures__defaults():
57154
db = _create_db()
58155

@@ -69,7 +166,7 @@ def test_Database_pressures__all_args():
69166
db = _create_db()
70167

71168
pressures = db.pressures(sentinel.pattern, sentinel.variable,
72-
sentinel.initial_time)
169+
dt.datetime(2020, 1, 1))
73170

74171
_assert_query_and_params(
75172
db, 'SELECT DISTINCT pressure.value FROM pressure'
@@ -80,5 +177,31 @@ def test_Database_pressures__all_args():
80177
' AND file.reference = :initial_time'
81178
' ORDER BY value',
82179
{'pattern': sentinel.pattern, 'variable': sentinel.variable,
83-
'initial_time':sentinel.initial_time})
180+
'initial_time': "2020-01-01 00:00:00"})
84181
assert pressures == [sentinel.value1, sentinel.value2]
182+
183+
184+
185+
@pytest.mark.parametrize("initial_time", [
186+
pytest.param(dt.datetime(2020, 1, 1), id="datetime"),
187+
pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"),
188+
pytest.param(np.datetime64("2020-01-01", "ns"), id="np.datetime64"),
189+
])
190+
def test_Database_valid_times_given_datetime_like_objects(initial_time):
191+
initial_datetime = dt.datetime(2020, 1, 1)
192+
valid_times = [dt.datetime(2020, 1, 1, 12)]
193+
db = database.Database.connect(":memory:")
194+
db.insert_file_name("file.nc", initial_datetime)
195+
db.insert_times("file.nc", "air_temperature", valid_times)
196+
result = db.valid_times("file.nc", "air_temperature", initial_time)
197+
expect = ["2020-01-01 12:00:00"]
198+
assert expect == result
199+
200+
201+
@pytest.mark.parametrize("time", [
202+
pytest.param(dt.datetime(2020, 1, 1), id="datetime"),
203+
pytest.param(cftime.DatetimeGregorian(2020, 1, 1), id="cftime"),
204+
pytest.param(np.datetime64("2020-01-01", "ns"), id="np.datetime64"),
205+
])
206+
def test_Database_sanitize_datetime_like_objects(time):
207+
assert database.Database.sanitize_time(time) == "2020-01-01 00:00:00"

0 commit comments

Comments
 (0)