Skip to content

Commit 452cafb

Browse files
author
David Erb
committed
adds connect_lock
1 parent 42fd130 commit 452cafb

File tree

1 file changed

+73
-63
lines changed

1 file changed

+73
-63
lines changed

src/dls_normsql/aiosqlite.py

Lines changed: 73 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
logger = logging.getLogger(__name__)
2525

26+
connect_lock = asyncio.Lock()
2627

2728
# ----------------------------------------------------------------------------------------
2829
def sqlite_regexp_callback(pattern, input):
@@ -71,76 +72,79 @@ async def connect(self):
7172
Connect to database at filename given in constructor.
7273
"""
7374

74-
should_create_schemas = False
75+
async with connect_lock:
76+
should_create_schemas = False
7577

76-
# File doesn't exist yet?
77-
if not os.path.isfile(self.__filename):
78-
# Create directory for the file.
79-
await self.__create_directory(self.__filename)
80-
# After connection, we must create the schemas.
81-
should_create_schemas = True
78+
# File doesn't exist yet?
79+
if not os.path.isfile(self.__filename):
80+
# Create directory for the file.
81+
await self.__create_directory(self.__filename)
82+
# After connection, we must create the schemas.
83+
should_create_schemas = True
8284

83-
logger.debug(f"connecting to {self.__filename}")
85+
logger.debug(f"connecting to {self.__filename}")
8486

85-
self.__connection = await aiosqlite.connect(self.__filename)
86-
self.__connection.row_factory = aiosqlite.Row
87+
self.__connection = await aiosqlite.connect(self.__filename)
88+
self.__connection.row_factory = aiosqlite.Row
8789

88-
# rows = await self.query("PRAGMA journal_mode", why="query journal mode")
89-
# logger.debug(f"journal mode rows {json.dumps(rows)}")
90+
# rows = await self.query("PRAGMA journal_mode", why="query journal mode")
91+
# logger.debug(f"journal mode rows {json.dumps(rows)}")
9092

91-
# rows = await self.query("PRAGMA journal_mode=OFF", why="turn OFF journal mode")
92-
# logger.debug(f"journal mode OFF rows {json.dumps(rows)}")
93+
# rows = await self.query("PRAGMA journal_mode=OFF", why="turn OFF journal mode")
94+
# logger.debug(f"journal mode OFF rows {json.dumps(rows)}")
9395

94-
# rows = await self.query("PRAGMA journal_mode", why="query journal mode")
95-
# logger.debug(f"journal mode rows {json.dumps(rows)}")
96+
# rows = await self.query("PRAGMA journal_mode", why="query journal mode")
97+
# logger.debug(f"journal mode rows {json.dumps(rows)}")
9698

97-
# rows = await self.query("SELECT * from mainTable", why="main table check")
99+
# rows = await self.query("SELECT * from mainTable", why="main table check")
98100

99-
await self.__connection.create_function("regexp", 2, sqlite_regexp_callback)
100-
logger.debug("created regexp function")
101+
await self.__connection.create_function("regexp", 2, sqlite_regexp_callback)
102+
logger.debug("created regexp function")
101103

102-
await self.add_table_definitions()
104+
await self.add_table_definitions()
103105

104-
if should_create_schemas:
105-
await self.create_schemas()
106-
await self.insert(Tablenames.REVISION, [{"number": self.LATEST_REVISION}])
107-
# TODO: Set permission on sqlite file from configuration.
108-
os.chmod(self.__filename, 0o666)
109-
else:
110-
try:
111-
records = await self.query(
112-
f"SELECT number FROM {Tablenames.REVISION}",
113-
why="get database revision",
106+
if should_create_schemas:
107+
await self.create_schemas()
108+
await self.insert(
109+
Tablenames.REVISION, [{"number": self.LATEST_REVISION}]
114110
)
115-
if len(records) == 0:
111+
# TODO: Set permission on sqlite file from configuration.
112+
os.chmod(self.__filename, 0o666)
113+
else:
114+
try:
115+
records = await self.query(
116+
f"SELECT number FROM {Tablenames.REVISION}",
117+
why="get database revision",
118+
)
119+
if len(records) == 0:
120+
old_revision = 0
121+
else:
122+
old_revision = records[0]["number"]
123+
except Exception as exception:
124+
logger.warning(
125+
f"could not get revision, presuming legacy database with no table: {exception}"
126+
)
116127
old_revision = 0
117-
else:
118-
old_revision = records[0]["number"]
119-
except Exception as exception:
120-
logger.warning(
121-
f"could not get revision, presuming legacy database with no table: {exception}"
122-
)
123-
old_revision = 0
124128

125-
if old_revision < self.LATEST_REVISION:
126-
logger.debug(
127-
f"need to update old revision {old_revision}"
128-
f" to latest revision {self.LATEST_REVISION}"
129-
)
130-
for revision in range(old_revision, self.LATEST_REVISION):
131-
logger.debug(f"updating to revision {revision+1}")
132-
await self.apply_revision(revision + 1)
133-
await self.update(
134-
Tablenames.REVISION,
135-
{"number": self.LATEST_REVISION},
136-
"1 = 1",
137-
why="update database revision",
138-
)
129+
if old_revision < self.LATEST_REVISION:
130+
logger.debug(
131+
f"need to update old revision {old_revision}"
132+
f" to latest revision {self.LATEST_REVISION}"
133+
)
134+
for revision in range(old_revision, self.LATEST_REVISION):
135+
logger.debug(f"updating to revision {revision+1}")
136+
await self.apply_revision(revision + 1)
137+
await self.update(
138+
Tablenames.REVISION,
139+
{"number": self.LATEST_REVISION},
140+
"1 = 1",
141+
why="update database revision",
142+
)
139143

140-
# Emit the name of the database file for positive confirmation on console.
141-
logger.info(
142-
f"{callsign(self)} database file is {self.__filename} revision {self.LATEST_REVISION}"
143-
)
144+
# Emit the name of the database file for positive confirmation on console.
145+
logger.info(
146+
f"{callsign(self)} database file is {self.__filename} revision {self.LATEST_REVISION}"
147+
)
144148

145149
# ----------------------------------------------------------------------------------------
146150
async def apply_revision(self, revision):
@@ -233,9 +237,11 @@ async def create_table(self, table, should_commit: Optional[bool] = True):
233237
% (table.name, field_name, table.name, field_name)
234238
)
235239

236-
await self.__connection.execute(
237-
"CREATE TABLE %s(%s)" % (table.name, ", ".join(fields_sql))
238-
)
240+
sql = "CREATE TABLE %s\n(%s)" % (table.name, ",\n ".join(fields_sql))
241+
242+
logger.info("\n%s\n%s" % (sql, "\n".join(indices_sql)))
243+
244+
await self.__connection.execute(sql)
239245

240246
for sql in indices_sql:
241247
await self.__connection.execute(sql)
@@ -250,7 +256,7 @@ async def insert(
250256
rows,
251257
why=None,
252258
should_commit: Optional[bool] = True,
253-
):
259+
) -> None:
254260
"""
255261
Insert one or more rows.
256262
Each row is a dictionary.
@@ -270,6 +276,7 @@ async def insert(
270276

271277
values_rows = []
272278

279+
logger.info(f"inserting {table.name} fields {list(table.fields.keys())}")
273280
insertable_fields = []
274281
for field in table.fields:
275282
# The first row is expected to define the keys for all rows inserted.
@@ -283,10 +290,13 @@ async def insert(
283290
for row in rows:
284291
values_row = []
285292
for field in table.fields:
286-
if field in row:
293+
if field == CommonFieldnames.CREATED_ON:
294+
created_on = row.get(field)
295+
if created_on is None:
296+
created_on = now
297+
values_row.append(created_on)
298+
elif field in row:
287299
values_row.append(row[field])
288-
elif field == CommonFieldnames.CREATED_ON:
289-
values_row.append(row.get(field, now))
290300
values_rows.append(values_row)
291301

292302
sql = "INSERT INTO %s (%s) VALUES (%s)" % (

0 commit comments

Comments
 (0)