Skip to content

Commit 57c723a

Browse files
authored
add locks around transaction closing (#497)
1 parent b1c4188 commit 57c723a

File tree

1 file changed

+51
-29
lines changed

1 file changed

+51
-29
lines changed

SDMeta/Cache/SqliteDataSource.cs

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
1-
using Dapper;
1+
using Dapper;
22
using Microsoft.Data.Sqlite;
33
using Microsoft.Extensions.Logging;
44
using System;
55
using System.Collections.Generic;
66
using System.Linq;
7+
using System.Threading;
78

89
namespace SDMeta.Cache
910
{
@@ -12,6 +13,7 @@ public partial class SqliteDataSource : IImageFileDataSource
1213
const string TableName = "PngFilesv2";
1314
private string FTSTableName = $"FTS5{TableName}";
1415
private SqliteTransaction? transaction;
16+
private readonly Lock transactionLock = new();
1517

1618
private readonly string[] columns =
1719
[
@@ -71,17 +73,19 @@ private SqliteConnection GetConnection()
7173
return connection;
7274
}
7375

74-
private T ExecuteOnConnection<T>(Func<SqliteConnection, T> func)
76+
private T ExecuteOnConnection<T>(Func<SqliteConnection, SqliteTransaction?, T> func)
7577
{
76-
if (this.transaction?.Connection != null)
78+
lock (transactionLock)
7779
{
78-
return func(transaction.Connection);
79-
}
80-
else
81-
{
82-
using var connection = GetConnection();
83-
return func(connection);
80+
var currentTransaction = this.transaction;
81+
if (currentTransaction?.Connection != null)
82+
{
83+
return func(currentTransaction.Connection, currentTransaction);
84+
}
8485
}
86+
87+
using var connection = GetConnection();
88+
return func(connection, null);
8589
}
8690

8791
private string GetInsertSql()
@@ -102,11 +106,11 @@ public void Initialize()
102106
var tabledef = GetTableDefinition();
103107

104108
// Setup table if absent https://learn.microsoft.com/en-us/dotnet/standard/data/sqlite/types
105-
ExecuteOnConnection(connection => connection.Execute(@$"CREATE TABLE IF NOT EXISTS {TableName} (
109+
ExecuteOnConnection((connection, _) => connection.Execute(@$"CREATE TABLE IF NOT EXISTS {TableName} (
106110
{tabledef.Select(p => $"{p.Column} {p.DataType}{(p.IsPrimaryKey ? " PRIMARY KEY" : "")}").ToCommaSeparated()}
107111
);"));
108112

109-
ExecuteOnConnection(connection => connection.Execute(@$"CREATE VIRTUAL TABLE IF NOT EXISTS {FTSTableName} USING fts5({ftscolumns.ToCommaSeparated()});"));
113+
ExecuteOnConnection((connection, _) => connection.Execute(@$"CREATE VIRTUAL TABLE IF NOT EXISTS {FTSTableName} USING fts5({ftscolumns.ToCommaSeparated()});"));
110114
logger.LogInformation("Initalization completed");
111115
}
112116

@@ -135,7 +139,7 @@ public IEnumerable<ImageFileSummary> Query(QueryParams queryParams)
135139
modelHash = queryParams.ModelFilter?.ModelHash,
136140
};
137141

138-
var reader = ExecuteOnConnection(connection =>
142+
var reader = ExecuteOnConnection((connection, _) =>
139143
connection.Query<ImageFileSummary>(sql, param)
140144
);
141145
return reader;
@@ -212,7 +216,7 @@ private static string BuildOrderByClause(QuerySortBy querySort)
212216

213217
public ImageFile? ReadImageFile(string realFileName)
214218
{
215-
var reader = ExecuteOnConnection(connection => connection.QueryFirstOrDefault<DataRow>(
219+
var reader = ExecuteOnConnection((connection, _) => connection.QueryFirstOrDefault<DataRow>(
216220
$@"SELECT *
217221
FROM {TableName}
218222
WHERE FileName = @FileName
@@ -230,10 +234,10 @@ private static string BuildOrderByClause(QuerySortBy querySort)
230234

231235
public void WriteImageFile(ImageFile info)
232236
{
233-
ExecuteOnConnection(connection => connection.Execute(
237+
ExecuteOnConnection((connection, tx) => connection.Execute(
234238
insertSql.Value,
235239
FromModel(info),
236-
this.transaction
240+
tx
237241
));
238242
}
239243

@@ -257,25 +261,42 @@ private DataRow FromModel(ImageFile info)
257261

258262
public void BeginTransaction()
259263
{
260-
this.transaction ??= GetConnection().BeginTransaction();
264+
lock (transactionLock)
265+
{
266+
this.transaction ??= GetConnection().BeginTransaction();
267+
}
261268
}
262269

263270
public void CommitTransaction()
264271
{
265-
if (this.transaction != null)
272+
SqliteTransaction? transactionToCommit;
273+
274+
lock (transactionLock)
266275
{
267-
var connection = this.transaction.Connection;
268-
this.transaction.Commit();
269-
this.transaction.Dispose();
276+
transactionToCommit = this.transaction;
270277
this.transaction = null;
271-
connection?.Close();
278+
}
279+
280+
if (transactionToCommit == null)
281+
{
282+
return;
283+
}
284+
285+
try
286+
{
287+
transactionToCommit.Commit();
288+
}
289+
finally
290+
{
291+
var connection = transactionToCommit.Connection;
292+
transactionToCommit.Dispose();
272293
connection?.Dispose();
273294
}
274295
}
275296

276297
public IEnumerable<ModelSummary> GetModelSummaryList()
277298
{
278-
var reader = ExecuteOnConnection(connection => connection.Query<ModelSummary>(
299+
var reader = ExecuteOnConnection((connection, _) => connection.Query<ModelSummary>(
279300
$@"SELECT Model, ModelHash, Count(*) as Count
280301
FROM {TableName}
281302
GROUP BY Model, ModelHash
@@ -287,7 +308,7 @@ ORDER BY 3 DESC"
287308

288309
public IEnumerable<string> GetAllFilenames()
289310
{
290-
var reader = ExecuteOnConnection(connection => connection.Query<string>(
311+
var reader = ExecuteOnConnection((connection, _) => connection.Query<string>(
291312
$@"SELECT Filename
292313
FROM {TableName}
293314
WHERE [Exists] = 1"
@@ -298,28 +319,28 @@ public IEnumerable<string> GetAllFilenames()
298319

299320
public void Truncate()
300321
{
301-
ExecuteOnConnection(connection => connection.Execute($"DELETE FROM {TableName}"));
302-
ExecuteOnConnection(connection => connection.Execute($"DELETE FROM {FTSTableName}"));
322+
ExecuteOnConnection((connection, _) => connection.Execute($"DELETE FROM {TableName}"));
323+
ExecuteOnConnection((connection, _) => connection.Execute($"DELETE FROM {FTSTableName}"));
303324
}
304325

305326
public void PostUpdateProcessing()
306327
{
307-
ExecuteOnConnection(connection =>
328+
ExecuteOnConnection((connection, tx) =>
308329
connection.Execute(
309330
$@"INSERT INTO {FTSTableName} (FileName, Prompt, PromptFormat, Version)
310331
SELECT FileName, Prompt, PromptFormat, Version FROM {TableName}
311332
WHERE FileName NOT IN (SELECT FileName from {FTSTableName})",
312-
this.transaction));
333+
tx));
313334

314-
ExecuteOnConnection(connection =>
335+
ExecuteOnConnection((connection, tx) =>
315336
connection.Execute(
316337
$@"UPDATE {FTSTableName} SET
317338
Prompt = p.Prompt,
318339
PromptFormat = p.PromptFormat,
319340
Version = p.Version
320341
FROM {TableName} p
321342
WHERE {FTSTableName}.FileName = p.FileName and {FTSTableName}.Version != p.Version",
322-
this.transaction));
343+
tx));
323344
}
324345
}
325346

@@ -333,3 +354,4 @@ public static string ToCommaSeparated(this IEnumerable<string> list)
333354

334355
internal record struct ColumnDefinition(string Column, string Parameter, string DataType, bool IsPrimaryKey);
335356
}
357+

0 commit comments

Comments
 (0)