Skip to content

Commit 0f1bc8c

Browse files
committed
Implement transaction support in UnitOfWork with new methods for transaction management and repository handling
1 parent 3adf7a7 commit 0f1bc8c

File tree

6 files changed

+152
-21
lines changed

6 files changed

+152
-21
lines changed

src/CodeOfChaos.Types.UnitOfWork.Contracts/IUnitOfWork.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,27 @@ public interface IUnitOfWork : IAsyncDisposable {
1111
// -----------------------------------------------------------------------------------------------------------------
1212
// Methods
1313
// ----------------------------------------------------------------------------------------------------------------
14+
void SaveChanges();
1415
ValueTask SaveChangesAsync(CancellationToken ct = default);
16+
17+
bool TryCommitTransaction();
1518
ValueTask<bool> TryCommitTransactionAsync(CancellationToken ct = default);
19+
20+
bool TryCreateTransaction();
1621
ValueTask<bool> TryCreateTransactionAsync(CancellationToken ct = default);
22+
23+
bool TryRollbackTransaction();
1724
ValueTask<bool> TryRollbackTransactionAsync(CancellationToken ct = default);
25+
26+
bool TryRollbackToSavepoint(Guid id);
1827
ValueTask<bool> TryRollbackToSavepointAsync(Guid id, CancellationToken ct = default);
28+
29+
bool TryCreateSavepoint(Guid id);
1930
ValueTask<bool> TryCreateSavepointAsync(Guid id, CancellationToken ct = default);
2031

32+
TDbContext GetDbContext<TDbContext>() where TDbContext : DbContext;
2133
ValueTask<TDbContext> GetDbContextAsync<TDbContext>(CancellationToken ct = default) where TDbContext : DbContext;
2234

35+
TRepo GetRepository<TRepo>() where TRepo : class, IUnitOfWorkRepository;
2336
ValueTask<TRepo> GetRepositoryAsync<TRepo>(CancellationToken ct = default) where TRepo : class, IUnitOfWorkRepository;
2437
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
// ---------------------------------------------------------------------------------------------------------------------
22
// Imports
33
// ---------------------------------------------------------------------------------------------------------------------
4+
using System.Diagnostics.CodeAnalysis;
5+
46
namespace CodeOfChaos.Types.UnitOfWork;
57
// ---------------------------------------------------------------------------------------------------------------------
68
// Code
79
// ---------------------------------------------------------------------------------------------------------------------
810
public interface IUnitOfWorkFactory {
911
IUnitOfWork Create();
12+
13+
IUnitOfWork CreateWithTransaction();
1014
ValueTask<IUnitOfWork> CreateWithTransactionAsync(CancellationToken ct = default);
15+
16+
bool TryCreateWithTransaction([NotNullWhen(true)] out IUnitOfWork? unitOfWork);
1117
ValueTask<IUnitOfWork?> TryCreateWithTransactionAsync(CancellationToken ct = default);
1218
}

src/CodeOfChaos.Types.UnitOfWork/ServiceCollectionExtensions.cs

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ namespace CodeOfChaos.Types.UnitOfWork;
99
// Code
1010
// ---------------------------------------------------------------------------------------------------------------------
1111
public static class ServiceCollectionExtensions {
12+
1213
public static IServiceCollection AddUnitOfWork<TDbContext>(this IServiceCollection services) where TDbContext : DbContext {
1314
services.AddScoped<IUnitOfWorkFactory, UnitOfWorkFactory<TDbContext>>();
1415
services.AddScoped<IUnitOfWork>(static sp => sp.GetRequiredService<IUnitOfWorkFactory>().Create());
@@ -24,19 +25,13 @@ public static IServiceCollection AddUnitOfWork<TDbContext>(this IServiceCollecti
2425
}
2526

2627
public static IServiceCollection AddReadonlyUnitOfWork<TDbContext>(this IServiceCollection services) where TDbContext : DbContext, IReadonlyCapableDbContext {
27-
services.AddScoped<IUnitOfWorkFactory, UnitOfWorkFactory<TDbContext>>();
28-
services.AddScoped<IUnitOfWork>(static sp => sp.GetRequiredService<IUnitOfWorkFactory>().Create());
29-
3028
services.AddScoped<IReadonlyUnitOfWorkFactory, ReadonlyUnitOfWorkFactory<TDbContext>>();
3129
services.AddScoped<IReadonlyUnitOfWork>(static sp => sp.GetRequiredService<IReadonlyUnitOfWorkFactory>().Create());
3230

3331
return services;
3432
}
3533

3634
public static IServiceCollection AddReadonlyUnitOfWork<TDbContext>(this IServiceCollection services, string key) where TDbContext : DbContext, IReadonlyCapableDbContext {
37-
services.AddKeyedScoped<IUnitOfWorkFactory, UnitOfWorkFactory<TDbContext>>(key);
38-
services.AddKeyedScoped<IUnitOfWork>(key, implementationFactory: static (sp, k) => sp.GetRequiredKeyedService<IUnitOfWorkFactory>(k).Create());
39-
4035
services.AddKeyedScoped<IReadonlyUnitOfWorkFactory, ReadonlyUnitOfWorkFactory<TDbContext>>(key);
4136
services.AddKeyedScoped<IReadonlyUnitOfWork>(key, implementationFactory: static (sp, k) => sp.GetRequiredKeyedService<IReadonlyUnitOfWorkFactory>(k).Create());
4237

src/CodeOfChaos.Types.UnitOfWork/UnitOfWork.cs

Lines changed: 105 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using Microsoft.EntityFrameworkCore.Storage;
66
using Microsoft.Extensions.DependencyInjection;
77
using System.Collections.Concurrent;
8+
using System.Runtime.CompilerServices;
89

910
namespace CodeOfChaos.Types.UnitOfWork;
1011
// ---------------------------------------------------------------------------------------------------------------------
@@ -14,16 +15,18 @@ public class UnitOfWork<TDbContext>(IDbContextFactory<TDbContext> dbContextFacto
1415
private TDbContext? _dbContext;
1516
private IDbContextTransaction? _transaction;
1617
private readonly ConcurrentDictionary<Type, IUnitOfWorkRepository> AttachedRepositories = [];
17-
private readonly SemaphoreSlim _initLock = new(1, 1);
18-
19-
18+
private readonly SemaphoreSlim _initLockAsync = new(1, 1);
19+
private readonly SemaphoreSlim _transactionLockAsync = new(1, 1);
20+
private readonly Lock _transactionLock = new();
21+
private readonly Lazy<TDbContext> _lazyDb = new(dbContextFactory.CreateDbContext, LazyThreadSafetyMode.ExecutionAndPublication);
22+
2023
// -----------------------------------------------------------------------------------------------------------------
2124
// Methods
2225
// -----------------------------------------------------------------------------------------------------------------
2326
protected virtual async ValueTask<TDbContext> GetDbContextAsync(CancellationToken ct) {
2427
if (_dbContext != null) return _dbContext;
2528

26-
await _initLock.WaitAsync(ct);
29+
await _initLockAsync.WaitAsync(ct);
2730
try {
2831
// Double-check pattern
2932
if (_dbContext != null) return _dbContext;
@@ -32,7 +35,24 @@ protected virtual async ValueTask<TDbContext> GetDbContextAsync(CancellationToke
3235
return _dbContext;
3336
}
3437
finally {
35-
_initLock.Release();
38+
_initLockAsync.Release();
39+
}
40+
}
41+
42+
protected virtual TDbContext GetDbContext() => _lazyDb.Value;
43+
44+
public bool TryCreateTransaction() {
45+
if (_transaction != null) return false;
46+
47+
var dbContext = GetDbContext<TDbContext>();
48+
if (dbContext.Database.CurrentTransaction != null) {
49+
_transaction = dbContext.Database.CurrentTransaction;
50+
return true;
51+
}
52+
53+
lock (_transactionLock) {
54+
_transaction = dbContext.Database.BeginTransaction();
55+
return true;
3656
}
3757
}
3858

@@ -45,36 +65,61 @@ public virtual async ValueTask<bool> TryCreateTransactionAsync(CancellationToken
4565
return true;
4666
}
4767

48-
await _initLock.WaitAsync(ct);
68+
await _transactionLockAsync.WaitAsync(ct);
4969
try {
5070
_transaction = await dbContext.Database.BeginTransactionAsync(ct);
5171
return true;
5272
}
5373
finally {
54-
_initLock.Release();
74+
_transactionLockAsync.Release();
5575
}
5676
}
77+
78+
public void SaveChanges() {
79+
DbContext dbContext = GetDbContext<TDbContext>();
80+
dbContext.SaveChanges();
81+
}
5782

5883
public virtual async ValueTask SaveChangesAsync(CancellationToken ct = default) {
5984
DbContext dbContext = await GetDbContextAsync(ct);
6085
await dbContext.SaveChangesAsync(ct);
6186
}
87+
88+
public bool TryCommitTransaction() {
89+
if (_transaction == null) return false;
90+
91+
lock (_transactionLock) {
92+
_transaction.Commit();
93+
_transaction.Dispose();
94+
_transaction = null;
95+
return true;
96+
}
97+
}
6298

6399
public virtual async ValueTask<bool> TryCommitTransactionAsync(CancellationToken ct = default) {
64100
if (_transaction == null) return false;
65101

66-
await _initLock.WaitAsync(ct);
102+
await _transactionLockAsync.WaitAsync(ct);
67103
try {
68104
await _transaction.CommitAsync(ct);
69105
await _transaction.DisposeAsync();
70106
_transaction = null;
71107
return true;
72108
}
73109
finally {
74-
_initLock.Release();
110+
_transactionLockAsync.Release();
75111
}
76112
}
77113

114+
public bool TryRollbackTransaction() {
115+
if (_transaction == null) return false;
116+
117+
_transaction.Rollback();
118+
_transaction.Dispose();
119+
_transaction = null;
120+
121+
return true;
122+
}
78123

79124
public virtual async ValueTask<bool> TryRollbackTransactionAsync(CancellationToken ct = default) {
80125
if (_transaction == null) return false;
@@ -86,6 +131,14 @@ public virtual async ValueTask<bool> TryRollbackTransactionAsync(CancellationTok
86131
return true;
87132
}
88133

134+
public bool TryRollbackToSavepoint(Guid id) {
135+
if (_transaction == null) return false;
136+
if (!_transaction.SupportsSavepoints) return false;
137+
138+
_transaction.RollbackToSavepoint(id.ToString("N"));
139+
return true;
140+
}
141+
89142
public virtual async ValueTask<bool> TryRollbackToSavepointAsync(Guid id, CancellationToken ct = default) {
90143
if (_transaction == null) return false;
91144
if (!_transaction.SupportsSavepoints) return false;
@@ -94,6 +147,14 @@ public virtual async ValueTask<bool> TryRollbackToSavepointAsync(Guid id, Cancel
94147

95148
return true;
96149
}
150+
151+
public bool TryCreateSavepoint(Guid id) {
152+
if (_transaction == null) return false;
153+
if (!_transaction.SupportsSavepoints) return false;
154+
155+
_transaction.CreateSavepoint(id.ToString("N"));
156+
return true;
157+
}
97158

98159
public virtual async ValueTask<bool> TryCreateSavepointAsync(Guid id, CancellationToken ct = default) {
99160
if (_transaction == null) return false;
@@ -103,12 +164,34 @@ public virtual async ValueTask<bool> TryCreateSavepointAsync(Guid id, Cancellati
103164

104165
return true;
105166
}
167+
168+
public T GetDbContext<T>() where T : DbContext {
169+
if (typeof(T) != typeof(TDbContext)) throw new NotSupportedException($"DbContext type '{typeof(T)}' is not supported by this UnitOfWork.");
170+
171+
TDbContext dbContext = GetDbContext();
172+
173+
return Unsafe.As<TDbContext, T>(ref dbContext);
174+
}
106175

107176
public virtual async ValueTask<T> GetDbContextAsync<T>(CancellationToken ct = default) where T : DbContext {
108177
if (typeof(T) != typeof(TDbContext)) throw new NotSupportedException($"DbContext type '{typeof(T)}' is not supported by this UnitOfWork.");
109178

110179
TDbContext dbContext = await GetDbContextAsync(ct);
111-
return dbContext as T ?? throw new InvalidCastException($"Cannot cast DbContext of type '{dbContext.GetType()}' to '{typeof(T)}'");
180+
181+
return Unsafe.As<TDbContext, T>(ref dbContext);
182+
}
183+
184+
public TRepo GetRepository<TRepo>() where TRepo : class, IUnitOfWorkRepository {
185+
if (AttachedRepositories.TryGetValue(typeof(TRepo), out IUnitOfWorkRepository? cachedRepo) && cachedRepo is TRepo castedCachedRepo) return castedCachedRepo;
186+
187+
var repo = CreateAndAttachRepository<TRepo>();
188+
189+
AttachedRepositories.AddOrUpdate(
190+
typeof(TRepo),
191+
repo,
192+
(_, _) => repo
193+
);
194+
return repo;
112195
}
113196

114197
public virtual async ValueTask<TRepo> GetRepositoryAsync<TRepo>(CancellationToken ct = default) where TRepo : class, IUnitOfWorkRepository {
@@ -125,6 +208,14 @@ public virtual async ValueTask<TRepo> GetRepositoryAsync<TRepo>(CancellationToke
125208
return repo;
126209
}
127210

211+
private TRepo CreateAndAttachRepository<TRepo>() where TRepo : class, IUnitOfWorkRepository {
212+
var repo = serviceScope.ServiceProvider.GetRequiredService<TRepo>();
213+
if (repo is not UnitOfWorkRepository<TDbContext> castedRepo) throw new InvalidCastException($"Cannot cast repository of type '{repo.GetType()}' to '{typeof(TRepo)}'");
214+
215+
castedRepo.Attach(this);
216+
return repo;
217+
}
218+
128219
private async ValueTask<TRepo> CreateAndAttachRepositoryAsync<TRepo>(CancellationToken ct = default) where TRepo : class, IUnitOfWorkRepository {
129220
var repo = serviceScope.ServiceProvider.GetRequiredService<TRepo>();
130221
if (repo is not UnitOfWorkRepository<TDbContext> castedRepo) throw new InvalidCastException($"Cannot cast repository of type '{repo.GetType()}' to '{typeof(TRepo)}'");
@@ -135,12 +226,12 @@ private async ValueTask<TRepo> CreateAndAttachRepositoryAsync<TRepo>(Cancellatio
135226

136227
public virtual async ValueTask DisposeAsync() {
137228
// ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
138-
if (_initLock == null) {
229+
if (_initLockAsync == null) {
139230
GC.SuppressFinalize(this);
140231
return;
141232
}
142233

143-
await _initLock.WaitAsync();
234+
await _initLockAsync.WaitAsync();
144235
try {
145236
if (_transaction != null) {
146237
await TryRollbackTransactionAsync();
@@ -162,8 +253,8 @@ public virtual async ValueTask DisposeAsync() {
162253
serviceScope.Dispose();
163254
}
164255
finally {
165-
_initLock.Release();
166-
_initLock.Dispose();
256+
_initLockAsync.Release();
257+
_initLockAsync.Dispose();
167258
GC.SuppressFinalize(this);
168259
}
169260
}

src/CodeOfChaos.Types.UnitOfWork/UnitOfWorkFactory.cs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
using Microsoft.EntityFrameworkCore;
55
using Microsoft.Extensions.DependencyInjection;
66
using Microsoft.Extensions.Logging;
7+
using System.Diagnostics.CodeAnalysis;
78

89
namespace CodeOfChaos.Types.UnitOfWork;
910
// ---------------------------------------------------------------------------------------------------------------------
@@ -18,7 +19,19 @@ public IUnitOfWork Create() {
1819
// Because our factory doesn't create the actual dbcontext, yet we are safe, and we can just inject it downwards.
1920
return new UnitOfWork<TDbContext>(dbContextFactory, scope);
2021
}
21-
22+
23+
public IUnitOfWork CreateWithTransaction() {
24+
IUnitOfWork unitOfWork = Create();
25+
26+
// ReSharper disable once InvertIf
27+
if (!unitOfWork.TryCreateTransaction()) {
28+
logger.LogError("Failed to create transaction for new unit of work.");
29+
throw new Exception("Failed to create transaction");
30+
}
31+
32+
return unitOfWork;
33+
}
34+
2235
public async ValueTask<IUnitOfWork> CreateWithTransactionAsync(CancellationToken ct = default) {
2336
IUnitOfWork unitOfWork = Create();
2437

@@ -30,6 +43,17 @@ public async ValueTask<IUnitOfWork> CreateWithTransactionAsync(CancellationToken
3043

3144
return unitOfWork;
3245
}
46+
47+
public bool TryCreateWithTransaction([NotNullWhen(true)] out IUnitOfWork? unitOfWork) {
48+
unitOfWork = Create();
49+
50+
// ReSharper disable once InvertIf
51+
if (!unitOfWork.TryCreateTransaction()) {
52+
unitOfWork = null;
53+
return false;
54+
}
55+
return true;
56+
}
3357

3458
public async ValueTask<IUnitOfWork?> TryCreateWithTransactionAsync(CancellationToken ct = default) {
3559
IUnitOfWork unitOfWork = Create();

src/CodeOfChaos.Types.UnitOfWork/UnitOfWorkRepository.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ public abstract class UnitOfWorkRepository<TDbContext> : IUnitOfWorkRepository
1616
// -----------------------------------------------------------------------------------------------------------------
1717
// Methods
1818
// -----------------------------------------------------------------------------------------------------------------
19+
20+
internal void Attach(IUnitOfWork unitOfWork) => DbContext = unitOfWork.GetDbContext<TDbContext>();
1921
internal async ValueTask AttachAsync(IUnitOfWork unitOfWork, CancellationToken ct = default) => DbContext = await unitOfWork.GetDbContextAsync<TDbContext>(ct);
2022
internal void Detach() => DbContext = null;// Remove the reference to the DbContext
2123

0 commit comments

Comments
 (0)