33// ---------------------------------------------------------------------------------------------------------------------
44using Microsoft . EntityFrameworkCore . Storage ;
55using Microsoft . Extensions . DependencyInjection ;
6+ using Microsoft . Extensions . Logging ;
67using System . Collections . Concurrent ;
78using System . Runtime . CompilerServices ;
89
910namespace Microsoft . EntityFrameworkCore ;
1011// ---------------------------------------------------------------------------------------------------------------------
1112// Code
1213// ---------------------------------------------------------------------------------------------------------------------
13- public class UnitOfWork < TDbContext > ( IDbContextFactory < TDbContext > dbContextFactory , AsyncServiceScope serviceScope ) : IUnitOfWork < TDbContext > where TDbContext : DbContext {
14+ public class UnitOfWork < TDbContext > ( IDbContextFactory < TDbContext > dbContextFactory , AsyncServiceScope serviceScope , ILogger < UnitOfWork < TDbContext > > logger ) : IUnitOfWork < TDbContext > where TDbContext : DbContext {
1415 private TDbContext ? _dbContext ;
1516 private IDbContextTransaction ? _transaction ;
1617 private readonly ConcurrentDictionary < Type , IUnitOfWorkRepository > AttachedRepositories = [ ] ;
1718 private readonly SemaphoreSlim _initLockAsync = new ( 1 , 1 ) ;
1819 private readonly SemaphoreSlim _transactionLockAsync = new ( 1 , 1 ) ;
1920 private readonly Lock _transactionLock = new ( ) ;
2021 private readonly Lazy < TDbContext > _lazyDb = new ( dbContextFactory . CreateDbContext , LazyThreadSafetyMode . ExecutionAndPublication ) ;
21-
22+
2223 internal bool IsDisposed { get ; private set ; }
23-
24+
2425 // -----------------------------------------------------------------------------------------------------------------
2526 // Methods
2627 // -----------------------------------------------------------------------------------------------------------------
27- public virtual async ValueTask < TDbContext > GetDbContextAsync ( CancellationToken ct ) {
28+ public async virtual ValueTask < TDbContext > GetDbContextAsync ( CancellationToken ct ) {
2829 if ( _dbContext != null ) return _dbContext ;
2930
3031 await _initLockAsync . WaitAsync ( ct ) ;
@@ -44,7 +45,7 @@ public virtual async ValueTask<TDbContext> GetDbContextAsync(CancellationToken c
4445
4546 public virtual bool TryCreateTransaction ( ) {
4647 if ( _transaction != null ) return false ;
47-
48+
4849 var dbContext = GetDbContext < TDbContext > ( ) ;
4950 if ( dbContext . Database . CurrentTransaction != null ) {
5051 _transaction = dbContext . Database . CurrentTransaction ;
@@ -56,8 +57,8 @@ public virtual bool TryCreateTransaction() {
5657 return true ;
5758 }
5859 }
59-
60- public virtual async ValueTask < bool > TryCreateTransactionAsync ( CancellationToken ct = default ) {
60+
61+ public async virtual ValueTask < bool > TryCreateTransactionAsync ( CancellationToken ct = default ) {
6162 if ( _transaction != null ) return false ;
6263
6364 TDbContext dbContext = await GetDbContextAsync ( ct ) ;
@@ -71,6 +72,11 @@ public virtual async ValueTask<bool> TryCreateTransactionAsync(CancellationToken
7172 _transaction = await dbContext . Database . BeginTransactionAsync ( ct ) ;
7273 return true ;
7374 }
75+ catch ( Exception ex ) {
76+ logger . LogError ( ex , "Failed to begin transaction." ) ;
77+ _transaction = null ;
78+ return false ;
79+ }
7480 finally {
7581 _transactionLockAsync . Release ( ) ;
7682 }
@@ -80,24 +86,33 @@ public virtual void SaveChanges() {
8086 DbContext dbContext = GetDbContext < TDbContext > ( ) ;
8187 dbContext . SaveChanges ( ) ;
8288 }
83-
84- public virtual async ValueTask SaveChangesAsync ( CancellationToken ct = default ) {
89+
90+ public async virtual ValueTask SaveChangesAsync ( CancellationToken ct = default ) {
8591 DbContext dbContext = await GetDbContextAsync ( ct ) ;
8692 await dbContext . SaveChangesAsync ( ct ) ;
8793 }
88-
94+
8995 public virtual bool TryCommitTransaction ( ) {
9096 if ( _transaction == null ) return false ;
9197
9298 lock ( _transactionLock ) {
93- _transaction . Commit ( ) ;
94- _transaction . Dispose ( ) ;
95- _transaction = null ;
96- return true ;
99+ try {
100+ _transaction . Commit ( ) ;
101+ _transaction . Dispose ( ) ;
102+ _transaction = null ;
103+ return true ;
104+ }
105+ catch ( Exception ex ) {
106+ logger . LogError ( ex , "Failed to commit transaction." ) ;
107+ lock ( _transactionLock ) {
108+ _transaction = null ;
109+ }
110+ return false ;
111+ }
97112 }
98113 }
99114
100- public virtual async ValueTask < bool > TryCommitTransactionAsync ( CancellationToken ct = default ) {
115+ public async virtual ValueTask < bool > TryCommitTransactionAsync ( CancellationToken ct = default ) {
101116 if ( _transaction == null ) return false ;
102117
103118 await _transactionLockAsync . WaitAsync ( ct ) ;
@@ -107,103 +122,162 @@ public virtual async ValueTask<bool> TryCommitTransactionAsync(CancellationToken
107122 _transaction = null ;
108123 return true ;
109124 }
125+ catch ( Exception ex ) {
126+ logger . LogError ( ex , "Failed to commit transaction." ) ;
127+ _transaction = null ;
128+ return false ;
129+ }
110130 finally {
111131 _transactionLockAsync . Release ( ) ;
112132 }
113133 }
114134
115135 public virtual bool TryRollbackTransaction ( ) {
116136 if ( _transaction == null ) return false ;
117-
118- _transaction . Rollback ( ) ;
119- _transaction . Dispose ( ) ;
120- _transaction = null ;
121-
122- return true ;
137+
138+ lock ( _transactionLock ) {
139+ try {
140+ _transaction . Rollback ( ) ;
141+ _transaction . Dispose ( ) ;
142+ _transaction = null ;
143+ return true ;
144+ }
145+ catch ( Exception ex ) {
146+ logger . LogError ( ex , "Failed to rollback transaction." ) ;
147+ lock ( _transactionLock ) {
148+ _transaction = null ;
149+ }
150+ return false ;
151+ }
152+ }
123153 }
124154
125- public virtual async ValueTask < bool > TryRollbackTransactionAsync ( CancellationToken ct = default ) {
155+ public async virtual ValueTask < bool > TryRollbackTransactionAsync ( CancellationToken ct = default ) {
126156 if ( _transaction == null ) return false ;
127157
128- await _transaction . RollbackAsync ( ct ) ;
129- await _transaction . DisposeAsync ( ) ;
130- _transaction = null ;
131-
132- return true ;
158+ await _transactionLockAsync . WaitAsync ( ct ) ;
159+ try {
160+ await _transaction . RollbackAsync ( ct ) ;
161+ await _transaction . DisposeAsync ( ) ;
162+ _transaction = null ;
163+ return true ;
164+ }
165+ catch ( Exception ex ) {
166+ logger . LogError ( ex , "Failed to rollback transaction." ) ;
167+ _transaction = null ;
168+ return false ;
169+ }
170+ finally {
171+ _transactionLockAsync . Release ( ) ;
172+ }
133173 }
134174
135175 public virtual bool TryRollbackToSavepoint ( Guid id ) {
136176 if ( _transaction == null ) return false ;
137177 if ( ! _transaction . SupportsSavepoints ) return false ;
138-
139- _transaction . RollbackToSavepoint ( id . ToString ( "N" ) ) ;
140- return true ;
178+
179+ lock ( _transactionLock ) {
180+ try {
181+ _transaction . RollbackToSavepoint ( id . ToString ( "N" ) ) ;
182+ return true ;
183+ }
184+ catch ( Exception ex ) {
185+ logger . LogError ( ex , "Failed to rollback to savepoint." ) ;
186+ return false ;
187+ }
188+ }
141189 }
142190
143- public virtual async ValueTask < bool > TryRollbackToSavepointAsync ( Guid id , CancellationToken ct = default ) {
191+ public async virtual ValueTask < bool > TryRollbackToSavepointAsync ( Guid id , CancellationToken ct = default ) {
144192 if ( _transaction == null ) return false ;
145193 if ( ! _transaction . SupportsSavepoints ) return false ;
146194
147- await _transaction . RollbackToSavepointAsync ( id . ToString ( "N" ) , ct ) ;
148-
149- return true ;
195+ await _transactionLockAsync . WaitAsync ( ct ) ;
196+ try {
197+ await _transaction . RollbackToSavepointAsync ( id . ToString ( "N" ) , ct ) ;
198+ return true ;
199+ }
200+ catch ( Exception ex ) {
201+ logger . LogError ( ex , "Failed to rollback to savepoint." ) ;
202+ return false ;
203+ }
204+ finally {
205+ _transactionLockAsync . Release ( ) ;
206+ }
150207 }
151-
208+
152209 public virtual bool TryCreateSavepoint ( Guid id ) {
153210 if ( _transaction == null ) return false ;
154211 if ( ! _transaction . SupportsSavepoints ) return false ;
155-
156- _transaction . CreateSavepoint ( id . ToString ( "N" ) ) ;
157- return true ;
212+
213+ lock ( _transactionLock ) {
214+ try {
215+ _transaction . CreateSavepoint ( id . ToString ( "N" ) ) ;
216+ return true ;
217+ }
218+ catch ( Exception ex ) {
219+ logger . LogError ( ex , "Failed to create savepoint." ) ;
220+ return false ;
221+ }
222+ }
158223 }
159224
160- public virtual async ValueTask < bool > TryCreateSavepointAsync ( Guid id , CancellationToken ct = default ) {
225+ public async virtual ValueTask < bool > TryCreateSavepointAsync ( Guid id , CancellationToken ct = default ) {
161226 if ( _transaction == null ) return false ;
162227 if ( ! _transaction . SupportsSavepoints ) return false ;
163228
164- await _transaction . CreateSavepointAsync ( id . ToString ( "N" ) , ct ) ;
165-
166- return true ;
229+ await _transactionLockAsync . WaitAsync ( ct ) ;
230+ try {
231+ await _transaction . CreateSavepointAsync ( id . ToString ( "N" ) , ct ) ;
232+ return true ;
233+ }
234+ catch ( Exception ex ) {
235+ logger . LogError ( ex , "Failed to create savepoint." ) ;
236+ return false ;
237+ }
238+ finally {
239+ _transactionLockAsync . Release ( ) ;
240+ }
167241 }
168-
242+
169243 public virtual T GetDbContext < T > ( ) where T : DbContext {
170244 if ( typeof ( T ) != typeof ( TDbContext ) ) throw new NotSupportedException ( $ "DbContext type '{ typeof ( T ) } ' is not supported by this UnitOfWork.") ;
171-
245+
172246 TDbContext dbContext = GetDbContext ( ) ;
173-
247+
174248 return Unsafe . As < TDbContext , T > ( ref dbContext ) ;
175249 }
176250
177- public virtual async ValueTask < T > GetDbContextAsync < T > ( CancellationToken ct = default ) where T : DbContext {
251+ public async virtual ValueTask < T > GetDbContextAsync < T > ( CancellationToken ct = default ) where T : DbContext {
178252 if ( typeof ( T ) != typeof ( TDbContext ) ) throw new NotSupportedException ( $ "DbContext type '{ typeof ( T ) } ' is not supported by this UnitOfWork.") ;
179253
180254 TDbContext dbContext = await GetDbContextAsync ( ct ) ;
181-
255+
182256 return Unsafe . As < TDbContext , T > ( ref dbContext ) ;
183257 }
184258
185259 public virtual TRepo GetRepository < TRepo > ( ) where TRepo : class , IUnitOfWorkRepository {
186260 if ( AttachedRepositories . TryGetValue ( typeof ( TRepo ) , out IUnitOfWorkRepository ? cachedRepo ) && cachedRepo is TRepo castedCachedRepo ) return castedCachedRepo ;
187-
261+
188262 var repo = CreateAndAttachRepository < TRepo > ( ) ;
189-
263+
190264 AttachedRepositories . AddOrUpdate (
191265 typeof ( TRepo ) ,
192- repo ,
266+ repo ,
193267 ( _ , _ ) => repo
194268 ) ;
195269 return repo ;
196270 }
197271
198- public virtual async ValueTask < TRepo > GetRepositoryAsync < TRepo > ( CancellationToken ct = default ) where TRepo : class , IUnitOfWorkRepository {
272+ public async virtual ValueTask < TRepo > GetRepositoryAsync < TRepo > ( CancellationToken ct = default ) where TRepo : class , IUnitOfWorkRepository {
199273 if ( AttachedRepositories . TryGetValue ( typeof ( TRepo ) , out IUnitOfWorkRepository ? cachedRepo ) && cachedRepo is TRepo castedCachedRepo ) return castedCachedRepo ;
200274
201275 // Cache miss so we create a new instance
202276 var repo = await CreateAndAttachRepositoryAsync < TRepo > ( ct ) ;
203277
204278 AttachedRepositories . AddOrUpdate (
205279 typeof ( TRepo ) ,
206- repo ,
280+ repo ,
207281 ( _ , _ ) => repo
208282 ) ;
209283 return repo ;
@@ -212,11 +286,11 @@ public virtual async ValueTask<TRepo> GetRepositoryAsync<TRepo>(CancellationToke
212286 private TRepo CreateAndAttachRepository < TRepo > ( ) where TRepo : class , IUnitOfWorkRepository {
213287 var repo = serviceScope . ServiceProvider . GetRequiredService < TRepo > ( ) ;
214288 if ( repo is not UnitOfWorkRepository < TDbContext > castedRepo ) throw new InvalidCastException ( $ "Cannot cast repository of type '{ repo . GetType ( ) } ' to '{ typeof ( TRepo ) } '") ;
215-
289+
216290 castedRepo . Attach ( this ) ;
217291 return repo ;
218292 }
219-
293+
220294 private async ValueTask < TRepo > CreateAndAttachRepositoryAsync < TRepo > ( CancellationToken ct = default ) where TRepo : class , IUnitOfWorkRepository {
221295 var repo = serviceScope . ServiceProvider . GetRequiredService < TRepo > ( ) ;
222296 if ( repo is not UnitOfWorkRepository < TDbContext > castedRepo ) throw new InvalidCastException ( $ "Cannot cast repository of type '{ repo . GetType ( ) } ' to '{ typeof ( TRepo ) } '") ;
@@ -225,7 +299,7 @@ private async ValueTask<TRepo> CreateAndAttachRepositoryAsync<TRepo>(Cancellatio
225299 return repo ;
226300 }
227301
228- public virtual async ValueTask DisposeAsync ( ) {
302+ public async virtual ValueTask DisposeAsync ( ) {
229303 // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
230304 if ( _initLockAsync == null ) {
231305 GC . SuppressFinalize ( this ) ;
0 commit comments