Skip to content

Commit e6c2f36

Browse files
author
David Erb
committed
fixes revision upgrade logic
1 parent 55decae commit e6c2f36

File tree

4 files changed

+145
-99
lines changed

4 files changed

+145
-99
lines changed

src/dls_normsql/aiomysql.py

Lines changed: 20 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,6 @@ def __init__(self, specification, database_definition_object):
7777

7878
self.__tables = {}
7979

80-
# Deriving class has not established its latest revision?
81-
if not hasattr(self, "LATEST_REVISION") or self.LATEST_REVISION is None:
82-
# Presume it is 1.
83-
self.LATEST_REVISION = 1
84-
8580
self.__backup_restore_lock = asyncio.Lock()
8681

8782
# Last undo position.
@@ -147,12 +142,13 @@ async def connect(self, should_drop_database=False):
147142
if should_create_schemas:
148143
await self.create_schemas()
149144
await self.insert(
150-
Tablenames.REVISION, [{"number": self.LATEST_REVISION}]
145+
Tablenames.REVISION,
146+
[{"number": self.__database_definition_object.LATEST_REVISION}],
151147
)
152148

153149
# Emit the name of the database file for positive confirmation on console.
154150
logger.info(
155-
f"{callsign(self)} database name is {self.__database_name} code revision {self.LATEST_REVISION}"
151+
f"{callsign(self)} database name is {self.__database_name} database definition revision {self.__database_definition_object.LATEST_REVISION}"
156152
)
157153

158154
# ----------------------------------------------------------------------------------------
@@ -170,36 +166,38 @@ async def apply_revisions(self):
170166
why="get database revision",
171167
)
172168
if len(records) == 0:
173-
old_revision = 0
169+
database_revision = 0
174170
else:
175-
old_revision = records[0]["number"]
171+
database_revision = records[0]["number"]
176172
except Exception as exception:
177173
logger.warning(
178174
f"could not get revision, presuming legacy database with no table: {exception}"
179175
)
180-
old_revision = 0
176+
database_revision = 0
181177

182-
if old_revision < self.LATEST_REVISION:
178+
if database_revision < self.__database_definition_object.LATEST_REVISION:
183179
# Backup before applying revisions.
184180
logger.debug(
185-
f"[BKREVL] backing up before updating from revision {old_revision} to revision {self.LATEST_REVISION}"
181+
f"[BKREVL] backing up before updating from database revision {database_revision}"
182+
f" to definition revision {self.__database_definition_object.LATEST_REVISION}"
186183
)
187184

188185
await self.backup()
189186

190-
for revision in range(old_revision, self.LATEST_REVISION):
191-
logger.debug(f"updating to revision {revision+1}")
187+
for revision in range(
188+
database_revision, self.__database_definition_object.LATEST_REVISION
189+
):
192190
await self.apply_revision(revision + 1)
193191
await self.update(
194192
Tablenames.REVISION,
195-
{"number": self.LATEST_REVISION},
193+
{"number": self.__database_definition_object.LATEST_REVISION},
196194
"1 = 1",
197195
why="update database revision",
198196
)
199197
else:
200198
logger.debug(
201-
f"[BKREVL] no need to update persistent revision {old_revision}"
202-
f" which matches code revision {self.LATEST_REVISION}"
199+
f"[BKREVL] no need to update database revision {database_revision}"
200+
f" which matches definition revision {self.__database_definition_object.LATEST_REVISION}"
203201
)
204202

205203
# ----------------------------------------------------------------------------------------
@@ -212,6 +210,9 @@ async def apply_revision(self, revision):
212210
await self.create_table(Tablenames.REVISION)
213211
await self.insert(Tablenames.REVISION, [{"revision": revision}])
214212

213+
# Let the database definition object do its thing.
214+
await self.__database_definition_object.apply_revision(self, revision)
215+
215216
# ----------------------------------------------------------------------------------------
216217
async def disconnect(self):
217218

@@ -551,39 +552,7 @@ async def backup(self):
551552
"""
552553

553554
async with self.__backup_restore_lock:
554-
# Prune all the restores which were orphaned.
555-
directory = self.__backup_directory
556-
if directory is None:
557-
raise RuntimeError("no backup directory supplied in confirmation")
558-
559-
basename, suffix = os.path.splitext(os.path.basename(self.__database_name))
560-
561-
filenames = glob.glob(f"{directory}/{basename}.*{suffix}")
562-
563-
filenames.sort(reverse=True)
564-
565-
for restore in range(self.__last_restore):
566-
logger.debug(
567-
f"[BACKPRU] removing {restore}-th restore {filenames[restore]}"
568-
)
569-
os.remove(filenames[restore])
570-
571-
self.__last_restore = 0
572-
573-
timestamp = isodatetime_filename()
574-
to_filename = f"{directory}/{basename}.{timestamp}{suffix}"
575-
576-
await self.disconnect()
577-
try:
578-
await self.__create_directory(to_filename)
579-
shutil.copy2(self.__database_name, to_filename)
580-
logger.debug(f"backed up to {to_filename}")
581-
except Exception:
582-
raise RuntimeError(
583-
f"copy {self.__database_name} to {to_filename} failed"
584-
)
585-
finally:
586-
await self.connect()
555+
pass
587556

588557
# ----------------------------------------------------------------------------------------
589558
async def restore(self, nth):
@@ -592,37 +561,7 @@ async def restore(self, nth):
592561
"""
593562

594563
async with self.__backup_restore_lock:
595-
directory = self.__backup_directory
596-
if directory is None:
597-
raise RuntimeError("no backup directory supplied in confirmation")
598-
599-
basename, suffix = os.path.splitext(os.path.basename(self.__database_name))
600-
601-
filenames = glob.glob(f"{directory}/{basename}.*{suffix}")
602-
603-
filenames.sort(reverse=True)
604-
605-
if nth >= len(filenames):
606-
raise RuntimeError(
607-
f"restoration index {nth} is more than available {len(filenames)}"
608-
)
609-
610-
from_filename = filenames[nth]
611-
612-
await self.disconnect()
613-
try:
614-
shutil.copy2(from_filename, self.__database_name)
615-
logger.debug(
616-
f"restored nth {nth} out of {len(filenames)} from {from_filename}"
617-
)
618-
except Exception:
619-
raise RuntimeError(
620-
f"copy {from_filename} to {self.__database_name} failed"
621-
)
622-
finally:
623-
await self.connect()
624-
625-
self.__last_restore = nth
564+
pass
626565

627566

628567
# ----------------------------------------------------------------------------------------

src/dls_normsql/aiosqlite.py

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -65,11 +65,6 @@ def __init__(self, specification, database_definition_object):
6565

6666
self.__tables = {}
6767

68-
# Deriving class has not established its latest revision?
69-
if not hasattr(self, "LATEST_REVISION") or self.LATEST_REVISION is None:
70-
# Presume it is 1.
71-
self.LATEST_REVISION = 1
72-
7368
self.__backup_restore_lock = asyncio.Lock()
7469

7570
# Last undo position.
@@ -126,14 +121,15 @@ async def connect(self, should_drop_database=False):
126121
if should_create_schemas:
127122
await self.create_schemas()
128123
await self.insert(
129-
Tablenames.REVISION, [{"number": self.LATEST_REVISION}]
124+
Tablenames.REVISION,
125+
[{"number": self.__database_definition_object.LATEST_REVISION}],
130126
)
131127
# TODO: Set permission on sqlite file from configuration.
132128
os.chmod(self.__filename, 0o666)
133129

134130
# Emit the name of the database file for positive confirmation on console.
135131
logger.info(
136-
f"{callsign(self)} database file is {self.__filename} code revision {self.LATEST_REVISION}"
132+
f"{callsign(self)} database file is {self.__filename} database definition revision {self.__database_definition_object.LATEST_REVISION}"
137133
)
138134

139135
# ----------------------------------------------------------------------------------------
@@ -151,36 +147,38 @@ async def apply_revisions(self):
151147
why="get database revision",
152148
)
153149
if len(records) == 0:
154-
old_revision = 0
150+
database_revision = 0
155151
else:
156-
old_revision = records[0]["number"]
152+
database_revision = records[0]["number"]
157153
except Exception as exception:
158154
logger.warning(
159155
f"could not get revision, presuming legacy database with no table: {exception}"
160156
)
161-
old_revision = 0
157+
database_revision = 0
162158

163-
if old_revision < self.LATEST_REVISION:
159+
if database_revision < self.__database_definition_object.LATEST_REVISION:
164160
# Backup before applying revisions.
165161
logger.debug(
166-
f"[BKREVL] backing up before updating from revision {old_revision} to revision {self.LATEST_REVISION}"
162+
f"[BKREVL] backing up before updating from database revision {database_revision}"
163+
f" to definition revision {self.__database_definition_object.LATEST_REVISION}"
167164
)
168165

169166
await self.backup()
170167

171-
for revision in range(old_revision, self.LATEST_REVISION):
172-
logger.debug(f"updating to revision {revision+1}")
168+
for revision in range(
169+
database_revision, self.__database_definition_object.LATEST_REVISION
170+
):
173171
await self.apply_revision(revision + 1)
174172
await self.update(
175173
Tablenames.REVISION,
176-
{"number": self.LATEST_REVISION},
174+
{"number": self.__database_definition_object.LATEST_REVISION},
177175
"1 = 1",
178176
why="update database revision",
179177
)
180178
else:
181179
logger.debug(
182-
f"[BKREVL] no need to update persistent revision {old_revision}"
183-
f" which matches code revision {self.LATEST_REVISION}"
180+
f"[BKREVL] no need to update database revision {database_revision}"
181+
f" which matches definition revision {self.__database_definition_object.LATEST_REVISION}"
184182
)
185183

186184
# ----------------------------------------------------------------------------------------
@@ -194,7 +192,7 @@ async def apply_revision(self, revision):
194192
await self.insert(Tablenames.REVISION, [{"revision": revision}])
195193

196194
# Let the database definition object do its thing.
197-
self.__database_definition_object.apply_revisions(self)
195+
await self.__database_definition_object.apply_revision(self, revision)
198196

199197
# ----------------------------------------------------------------------------------------
200198
async def disconnect(self):

tests/my_database_definition.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,13 @@ async def apply_revision(self, database, revision):
2525

2626
logger.debug(f"applying revision {revision}")
2727

28+
if revision == 4:
29+
await database.execute("CREATE TABLE `my_table2` (`number` INTEGER)")
30+
31+
await database.execute(
32+
f"ALTER TABLE my_table2 ADD COLUMN string TEXT",
33+
)
34+
2835
# ----------------------------------------------------------------------------------------
2936
async def add_table_definitions(self, database):
3037
"""

tests/test_revision.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import logging
2+
3+
import pytest
4+
from dls_utilpack.envvar import Envvar
5+
6+
from dls_normsql.constants import ClassTypes, RevisionFieldnames, Tablenames
7+
from dls_normsql.databases import Databases
8+
from tests.base_tester import BaseTester
9+
from tests.my_database_definition import MyDatabaseDefinition
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
# ----------------------------------------------------------------------------------------
15+
class TestRevisionAiosqlite:
16+
def test(self, logging_setup, output_directory):
17+
"""
18+
Tests the sqlite implementation of Database.
19+
"""
20+
21+
# Database specification.
22+
database_specification = {
23+
"type": ClassTypes.AIOSQLITE,
24+
"filename": f"{output_directory}/database.sqlite",
25+
}
26+
27+
# Test direct SQL access to the database.
28+
RevisionTester().main(
29+
database_specification,
30+
output_directory,
31+
)
32+
33+
34+
# ----------------------------------------------------------------------------------------
35+
class TestRevisionAiomysql:
36+
def test(self, logging_setup, output_directory):
37+
"""
38+
Tests the mysql implementation of Database.
39+
"""
40+
41+
host = Envvar("MYSQL_HOST", default="127.0.0.1")
42+
assert host.is_set
43+
port = Envvar("MYSQL_PORT", default=3306)
44+
assert port.is_set
45+
46+
# Database specification.
47+
database_specification = {
48+
"type": ClassTypes.AIOMYSQL,
49+
"type_specific_tbd": {
50+
"database_name": "dls_normsql_pytest",
51+
"host": "$MYSQL_HOST",
52+
"port": "$MYSQL_PORT",
53+
"username": "root",
54+
"password": "root",
55+
},
56+
}
57+
58+
# Test direct SQL access to the database.
59+
RevisionTester().main(
60+
database_specification,
61+
output_directory,
62+
)
63+
64+
65+
# ----------------------------------------------------------------------------------------
66+
class RevisionTester(BaseTester):
67+
"""
68+
Test direct SQL access to the database.
69+
"""
70+
71+
async def _main_coroutine(self, database_specification, output_directory):
72+
""" """
73+
74+
database_definition_object = MyDatabaseDefinition()
75+
databases = Databases()
76+
database1 = databases.build_object(
77+
database_specification, database_definition_object
78+
)
79+
80+
try:
81+
# Connect to database.
82+
await database1.connect(should_drop_database=True)
83+
84+
with pytest.raises(Exception):
85+
# We should not have the mytable2.
86+
records = await database1.query("SELECT * FROM my_table2")
87+
88+
# Back level the database's revision.
89+
await database1.update(
90+
Tablenames.REVISION, {RevisionFieldnames.NUMBER: "3"}, "1=1"
91+
)
92+
93+
# Apply the revisions up to the definition revision which is 4.
94+
await database1.apply_revisions()
95+
96+
# Now we should have the mytable2.
97+
records = await database1.query("SELECT * FROM my_table2")
98+
assert len(records) == 0
99+
100+
finally:
101+
# Disonnect from the databases... necessary to allow asyncio loop to exit.
102+
await database1.disconnect()

0 commit comments

Comments
 (0)