Skip to content

Commit 692d22c

Browse files
committed
Add logging support to UnitOfWork and ReadonlyUnitOfWork classes
Integrated `ILogger` to `UnitOfWork` and `ReadonlyUnitOfWork` for improved error handling and debugging. Updated related factories and tests accordingly.
1 parent 23229f9 commit 692d22c

File tree

8 files changed

+158
-66
lines changed

8 files changed

+158
-66
lines changed

src/CodeOfChaos.Extensions.EntityFrameworkCore/UnitOfWork/ReadonlyUnitOfWork.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22
// Imports
33
// ---------------------------------------------------------------------------------------------------------------------
44
using Microsoft.Extensions.DependencyInjection;
5+
using Microsoft.Extensions.Logging;
56

67
namespace Microsoft.EntityFrameworkCore;
78
// ---------------------------------------------------------------------------------------------------------------------
89
// Code
910
// ---------------------------------------------------------------------------------------------------------------------
1011
public class ReadonlyUnitOfWork<TDbContext>(
1112
IDbContextFactory<TDbContext> dbContextFactory,
12-
AsyncServiceScope serviceScope
13-
) : UnitOfWork<TDbContext>(dbContextFactory, serviceScope), IReadonlyUnitOfWork<TDbContext>
13+
AsyncServiceScope serviceScope,
14+
ILogger<ReadonlyUnitOfWork<TDbContext>> logger
15+
) : UnitOfWork<TDbContext>(dbContextFactory, serviceScope, logger), IReadonlyUnitOfWork<TDbContext>
1416
where TDbContext : DbContext, IReadonlyCapableDbContext {
1517

1618
public async override ValueTask<TDbContext> GetDbContextAsync(CancellationToken ct) {

src/CodeOfChaos.Extensions.EntityFrameworkCore/UnitOfWork/ReadonlyUnitOfWorkFactory.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Imports
33
// ---------------------------------------------------------------------------------------------------------------------
44
using Microsoft.Extensions.DependencyInjection;
5+
using Microsoft.Extensions.Logging;
56

67
namespace Microsoft.EntityFrameworkCore;
78

@@ -10,11 +11,12 @@ namespace Microsoft.EntityFrameworkCore;
1011
// ---------------------------------------------------------------------------------------------------------------------
1112
public class ReadonlyUnitOfWorkFactory<TDbContext>(
1213
IDbContextFactory<TDbContext> dbContextFactory,
13-
IServiceProvider provider
14+
IServiceProvider provider,
15+
ILoggerFactory loggerFactory
1416
) : IReadonlyUnitOfWorkFactory<TDbContext> where TDbContext : DbContext, IReadonlyCapableDbContext {
1517

1618
public IReadonlyUnitOfWork<TDbContext> Create() {
1719
AsyncServiceScope scope = provider.CreateAsyncScope();
18-
return new ReadonlyUnitOfWork<TDbContext>(dbContextFactory, scope);
20+
return new ReadonlyUnitOfWork<TDbContext>(dbContextFactory, scope, loggerFactory.CreateLogger<ReadonlyUnitOfWork<TDbContext>>());
1921
}
2022
}

src/CodeOfChaos.Extensions.EntityFrameworkCore/UnitOfWork/UnitOfWork.cs

Lines changed: 129 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,29 @@
33
// ---------------------------------------------------------------------------------------------------------------------
44
using Microsoft.EntityFrameworkCore.Storage;
55
using Microsoft.Extensions.DependencyInjection;
6+
using Microsoft.Extensions.Logging;
67
using System.Collections.Concurrent;
78
using System.Runtime.CompilerServices;
89

910
namespace Microsoft.EntityFrameworkCore;
1011
// ---------------------------------------------------------------------------------------------------------------------
1112
// Code
1213
// ---------------------------------------------------------------------------------------------------------------------
13-
public class UnitOfWork<TDbContext>(IDbContextFactory<TDbContext> dbContextFactory, AsyncServiceScope serviceScope) : IUnitOfWork<TDbContext> where TDbContext : DbContext {
14+
public class UnitOfWork<TDbContext>(IDbContextFactory<TDbContext> dbContextFactory, AsyncServiceScope serviceScope, ILogger<UnitOfWork<TDbContext>> logger) : IUnitOfWork<TDbContext> where TDbContext : DbContext {
1415
private TDbContext? _dbContext;
1516
private IDbContextTransaction? _transaction;
1617
private readonly ConcurrentDictionary<Type, IUnitOfWorkRepository> AttachedRepositories = [];
1718
private readonly SemaphoreSlim _initLockAsync = new(1, 1);
1819
private readonly SemaphoreSlim _transactionLockAsync = new(1, 1);
1920
private readonly Lock _transactionLock = new();
2021
private readonly Lazy<TDbContext> _lazyDb = new(dbContextFactory.CreateDbContext, LazyThreadSafetyMode.ExecutionAndPublication);
21-
22+
2223
internal bool IsDisposed { get; private set; }
23-
24+
2425
// -----------------------------------------------------------------------------------------------------------------
2526
// Methods
2627
// -----------------------------------------------------------------------------------------------------------------
27-
public virtual async ValueTask<TDbContext> GetDbContextAsync(CancellationToken ct) {
28+
public async virtual ValueTask<TDbContext> GetDbContextAsync(CancellationToken ct) {
2829
if (_dbContext != null) return _dbContext;
2930

3031
await _initLockAsync.WaitAsync(ct);
@@ -44,7 +45,7 @@ public virtual async ValueTask<TDbContext> GetDbContextAsync(CancellationToken c
4445

4546
public virtual bool TryCreateTransaction() {
4647
if (_transaction != null) return false;
47-
48+
4849
var dbContext = GetDbContext<TDbContext>();
4950
if (dbContext.Database.CurrentTransaction != null) {
5051
_transaction = dbContext.Database.CurrentTransaction;
@@ -56,8 +57,8 @@ public virtual bool TryCreateTransaction() {
5657
return true;
5758
}
5859
}
59-
60-
public virtual async ValueTask<bool> TryCreateTransactionAsync(CancellationToken ct = default) {
60+
61+
public async virtual ValueTask<bool> TryCreateTransactionAsync(CancellationToken ct = default) {
6162
if (_transaction != null) return false;
6263

6364
TDbContext dbContext = await GetDbContextAsync(ct);
@@ -71,6 +72,11 @@ public virtual async ValueTask<bool> TryCreateTransactionAsync(CancellationToken
7172
_transaction = await dbContext.Database.BeginTransactionAsync(ct);
7273
return true;
7374
}
75+
catch (Exception ex) {
76+
logger.LogError(ex, "Failed to begin transaction.");
77+
_transaction = null;
78+
return false;
79+
}
7480
finally {
7581
_transactionLockAsync.Release();
7682
}
@@ -80,24 +86,33 @@ public virtual void SaveChanges() {
8086
DbContext dbContext = GetDbContext<TDbContext>();
8187
dbContext.SaveChanges();
8288
}
83-
84-
public virtual async ValueTask SaveChangesAsync(CancellationToken ct = default) {
89+
90+
public async virtual ValueTask SaveChangesAsync(CancellationToken ct = default) {
8591
DbContext dbContext = await GetDbContextAsync(ct);
8692
await dbContext.SaveChangesAsync(ct);
8793
}
88-
94+
8995
public virtual bool TryCommitTransaction() {
9096
if (_transaction == null) return false;
9197

9298
lock (_transactionLock) {
93-
_transaction.Commit();
94-
_transaction.Dispose();
95-
_transaction = null;
96-
return true;
99+
try {
100+
_transaction.Commit();
101+
_transaction.Dispose();
102+
_transaction = null;
103+
return true;
104+
}
105+
catch (Exception ex) {
106+
logger.LogError(ex, "Failed to commit transaction.");
107+
lock (_transactionLock) {
108+
_transaction = null;
109+
}
110+
return false;
111+
}
97112
}
98113
}
99114

100-
public virtual async ValueTask<bool> TryCommitTransactionAsync(CancellationToken ct = default) {
115+
public async virtual ValueTask<bool> TryCommitTransactionAsync(CancellationToken ct = default) {
101116
if (_transaction == null) return false;
102117

103118
await _transactionLockAsync.WaitAsync(ct);
@@ -107,103 +122,162 @@ public virtual async ValueTask<bool> TryCommitTransactionAsync(CancellationToken
107122
_transaction = null;
108123
return true;
109124
}
125+
catch (Exception ex) {
126+
logger.LogError(ex, "Failed to commit transaction.");
127+
_transaction = null;
128+
return false;
129+
}
110130
finally {
111131
_transactionLockAsync.Release();
112132
}
113133
}
114134

115135
public virtual bool TryRollbackTransaction() {
116136
if (_transaction == null) return false;
117-
118-
_transaction.Rollback();
119-
_transaction.Dispose();
120-
_transaction = null;
121-
122-
return true;
137+
138+
lock (_transactionLock) {
139+
try {
140+
_transaction.Rollback();
141+
_transaction.Dispose();
142+
_transaction = null;
143+
return true;
144+
}
145+
catch (Exception ex) {
146+
logger.LogError(ex, "Failed to rollback transaction.");
147+
lock (_transactionLock) {
148+
_transaction = null;
149+
}
150+
return false;
151+
}
152+
}
123153
}
124154

125-
public virtual async ValueTask<bool> TryRollbackTransactionAsync(CancellationToken ct = default) {
155+
public async virtual ValueTask<bool> TryRollbackTransactionAsync(CancellationToken ct = default) {
126156
if (_transaction == null) return false;
127157

128-
await _transaction.RollbackAsync(ct);
129-
await _transaction.DisposeAsync();
130-
_transaction = null;
131-
132-
return true;
158+
await _transactionLockAsync.WaitAsync(ct);
159+
try {
160+
await _transaction.RollbackAsync(ct);
161+
await _transaction.DisposeAsync();
162+
_transaction = null;
163+
return true;
164+
}
165+
catch (Exception ex) {
166+
logger.LogError(ex, "Failed to rollback transaction.");
167+
_transaction = null;
168+
return false;
169+
}
170+
finally {
171+
_transactionLockAsync.Release();
172+
}
133173
}
134174

135175
public virtual bool TryRollbackToSavepoint(Guid id) {
136176
if (_transaction == null) return false;
137177
if (!_transaction.SupportsSavepoints) return false;
138-
139-
_transaction.RollbackToSavepoint(id.ToString("N"));
140-
return true;
178+
179+
lock (_transactionLock) {
180+
try {
181+
_transaction.RollbackToSavepoint(id.ToString("N"));
182+
return true;
183+
}
184+
catch (Exception ex) {
185+
logger.LogError(ex, "Failed to rollback to savepoint.");
186+
return false;
187+
}
188+
}
141189
}
142190

143-
public virtual async ValueTask<bool> TryRollbackToSavepointAsync(Guid id, CancellationToken ct = default) {
191+
public async virtual ValueTask<bool> TryRollbackToSavepointAsync(Guid id, CancellationToken ct = default) {
144192
if (_transaction == null) return false;
145193
if (!_transaction.SupportsSavepoints) return false;
146194

147-
await _transaction.RollbackToSavepointAsync(id.ToString("N"), ct);
148-
149-
return true;
195+
await _transactionLockAsync.WaitAsync(ct);
196+
try {
197+
await _transaction.RollbackToSavepointAsync(id.ToString("N"), ct);
198+
return true;
199+
}
200+
catch (Exception ex) {
201+
logger.LogError(ex, "Failed to rollback to savepoint.");
202+
return false;
203+
}
204+
finally {
205+
_transactionLockAsync.Release();
206+
}
150207
}
151-
208+
152209
public virtual bool TryCreateSavepoint(Guid id) {
153210
if (_transaction == null) return false;
154211
if (!_transaction.SupportsSavepoints) return false;
155-
156-
_transaction.CreateSavepoint(id.ToString("N"));
157-
return true;
212+
213+
lock (_transactionLock) {
214+
try {
215+
_transaction.CreateSavepoint(id.ToString("N"));
216+
return true;
217+
}
218+
catch (Exception ex) {
219+
logger.LogError(ex, "Failed to create savepoint.");
220+
return false;
221+
}
222+
}
158223
}
159224

160-
public virtual async ValueTask<bool> TryCreateSavepointAsync(Guid id, CancellationToken ct = default) {
225+
public async virtual ValueTask<bool> TryCreateSavepointAsync(Guid id, CancellationToken ct = default) {
161226
if (_transaction == null) return false;
162227
if (!_transaction.SupportsSavepoints) return false;
163228

164-
await _transaction.CreateSavepointAsync(id.ToString("N"), ct);
165-
166-
return true;
229+
await _transactionLockAsync.WaitAsync(ct);
230+
try {
231+
await _transaction.CreateSavepointAsync(id.ToString("N"), ct);
232+
return true;
233+
}
234+
catch (Exception ex) {
235+
logger.LogError(ex, "Failed to create savepoint.");
236+
return false;
237+
}
238+
finally {
239+
_transactionLockAsync.Release();
240+
}
167241
}
168-
242+
169243
public virtual T GetDbContext<T>() where T : DbContext {
170244
if (typeof(T) != typeof(TDbContext)) throw new NotSupportedException($"DbContext type '{typeof(T)}' is not supported by this UnitOfWork.");
171-
245+
172246
TDbContext dbContext = GetDbContext();
173-
247+
174248
return Unsafe.As<TDbContext, T>(ref dbContext);
175249
}
176250

177-
public virtual async ValueTask<T> GetDbContextAsync<T>(CancellationToken ct = default) where T : DbContext {
251+
public async virtual ValueTask<T> GetDbContextAsync<T>(CancellationToken ct = default) where T : DbContext {
178252
if (typeof(T) != typeof(TDbContext)) throw new NotSupportedException($"DbContext type '{typeof(T)}' is not supported by this UnitOfWork.");
179253

180254
TDbContext dbContext = await GetDbContextAsync(ct);
181-
255+
182256
return Unsafe.As<TDbContext, T>(ref dbContext);
183257
}
184258

185259
public virtual TRepo GetRepository<TRepo>() where TRepo : class, IUnitOfWorkRepository {
186260
if (AttachedRepositories.TryGetValue(typeof(TRepo), out IUnitOfWorkRepository? cachedRepo) && cachedRepo is TRepo castedCachedRepo) return castedCachedRepo;
187-
261+
188262
var repo = CreateAndAttachRepository<TRepo>();
189-
263+
190264
AttachedRepositories.AddOrUpdate(
191265
typeof(TRepo),
192-
repo,
266+
repo,
193267
(_, _) => repo
194268
);
195269
return repo;
196270
}
197271

198-
public virtual async ValueTask<TRepo> GetRepositoryAsync<TRepo>(CancellationToken ct = default) where TRepo : class, IUnitOfWorkRepository {
272+
public async virtual ValueTask<TRepo> GetRepositoryAsync<TRepo>(CancellationToken ct = default) where TRepo : class, IUnitOfWorkRepository {
199273
if (AttachedRepositories.TryGetValue(typeof(TRepo), out IUnitOfWorkRepository? cachedRepo) && cachedRepo is TRepo castedCachedRepo) return castedCachedRepo;
200274

201275
// Cache miss so we create a new instance
202276
var repo = await CreateAndAttachRepositoryAsync<TRepo>(ct);
203277

204278
AttachedRepositories.AddOrUpdate(
205279
typeof(TRepo),
206-
repo,
280+
repo,
207281
(_, _) => repo
208282
);
209283
return repo;
@@ -212,11 +286,11 @@ public virtual async ValueTask<TRepo> GetRepositoryAsync<TRepo>(CancellationToke
212286
private TRepo CreateAndAttachRepository<TRepo>() where TRepo : class, IUnitOfWorkRepository {
213287
var repo = serviceScope.ServiceProvider.GetRequiredService<TRepo>();
214288
if (repo is not UnitOfWorkRepository<TDbContext> castedRepo) throw new InvalidCastException($"Cannot cast repository of type '{repo.GetType()}' to '{typeof(TRepo)}'");
215-
289+
216290
castedRepo.Attach(this);
217291
return repo;
218292
}
219-
293+
220294
private async ValueTask<TRepo> CreateAndAttachRepositoryAsync<TRepo>(CancellationToken ct = default) where TRepo : class, IUnitOfWorkRepository {
221295
var repo = serviceScope.ServiceProvider.GetRequiredService<TRepo>();
222296
if (repo is not UnitOfWorkRepository<TDbContext> castedRepo) throw new InvalidCastException($"Cannot cast repository of type '{repo.GetType()}' to '{typeof(TRepo)}'");
@@ -225,7 +299,7 @@ private async ValueTask<TRepo> CreateAndAttachRepositoryAsync<TRepo>(Cancellatio
225299
return repo;
226300
}
227301

228-
public virtual async ValueTask DisposeAsync() {
302+
public async virtual ValueTask DisposeAsync() {
229303
// ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
230304
if (_initLockAsync == null) {
231305
GC.SuppressFinalize(this);

src/CodeOfChaos.Extensions.EntityFrameworkCore/UnitOfWork/UnitOfWorkFactory.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,14 @@ namespace Microsoft.EntityFrameworkCore;
99
// ---------------------------------------------------------------------------------------------------------------------
1010
// Code
1111
// ---------------------------------------------------------------------------------------------------------------------
12-
public class UnitOfWorkFactory<TDbContext>(IDbContextFactory<TDbContext> dbContextFactory, IServiceProvider provider, ILogger<UnitOfWorkFactory<TDbContext>> logger) : IUnitOfWorkFactory<TDbContext> where TDbContext : DbContext {
12+
public class UnitOfWorkFactory<TDbContext>(IDbContextFactory<TDbContext> dbContextFactory, IServiceProvider provider, ILogger<UnitOfWorkFactory<TDbContext>> logger, ILoggerFactory loggerFactory) : IUnitOfWorkFactory<TDbContext> where TDbContext : DbContext {
1313
public IUnitOfWork<TDbContext> Create() {
1414
// Each unit of work should have their own scope which they pull their repositories from
1515
// This, if the factory is used correctly, should enforce correct usage and limit dbcontext concurrency issues.
1616
AsyncServiceScope scope = provider.CreateAsyncScope();
1717

1818
// Because our factory doesn't create the actual dbcontext, yet we are safe, and we can just inject it downwards.
19-
return new UnitOfWork<TDbContext>(dbContextFactory, scope);
19+
return new UnitOfWork<TDbContext>(dbContextFactory, scope, loggerFactory.CreateLogger<UnitOfWork<TDbContext>>());
2020
}
2121

2222
public IUnitOfWork<TDbContext> CreateWithTransaction() {

0 commit comments

Comments
 (0)