55using Microsoft . EntityFrameworkCore . Storage ;
66using Microsoft . Extensions . DependencyInjection ;
77using System . Collections . Concurrent ;
8+ using System . Runtime . CompilerServices ;
89
910namespace CodeOfChaos . Types . UnitOfWork ;
1011// ---------------------------------------------------------------------------------------------------------------------
@@ -14,16 +15,18 @@ public class UnitOfWork<TDbContext>(IDbContextFactory<TDbContext> dbContextFacto
1415 private TDbContext ? _dbContext ;
1516 private IDbContextTransaction ? _transaction ;
1617 private readonly ConcurrentDictionary < Type , IUnitOfWorkRepository > AttachedRepositories = [ ] ;
17- private readonly SemaphoreSlim _initLock = new ( 1 , 1 ) ;
18-
19-
18+ private readonly SemaphoreSlim _initLockAsync = new ( 1 , 1 ) ;
19+ private readonly SemaphoreSlim _transactionLockAsync = new ( 1 , 1 ) ;
20+ private readonly Lock _transactionLock = new ( ) ;
21+ private readonly Lazy < TDbContext > _lazyDb = new ( dbContextFactory . CreateDbContext , LazyThreadSafetyMode . ExecutionAndPublication ) ;
22+
2023 // -----------------------------------------------------------------------------------------------------------------
2124 // Methods
2225 // -----------------------------------------------------------------------------------------------------------------
2326 protected virtual async ValueTask < TDbContext > GetDbContextAsync ( CancellationToken ct ) {
2427 if ( _dbContext != null ) return _dbContext ;
2528
26- await _initLock . WaitAsync ( ct ) ;
29+ await _initLockAsync . WaitAsync ( ct ) ;
2730 try {
2831 // Double-check pattern
2932 if ( _dbContext != null ) return _dbContext ;
@@ -32,7 +35,24 @@ protected virtual async ValueTask<TDbContext> GetDbContextAsync(CancellationToke
3235 return _dbContext ;
3336 }
3437 finally {
35- _initLock . Release ( ) ;
38+ _initLockAsync . Release ( ) ;
39+ }
40+ }
41+
42+ protected virtual TDbContext GetDbContext ( ) => _lazyDb . Value ;
43+
44+ public bool TryCreateTransaction ( ) {
45+ if ( _transaction != null ) return false ;
46+
47+ var dbContext = GetDbContext < TDbContext > ( ) ;
48+ if ( dbContext . Database . CurrentTransaction != null ) {
49+ _transaction = dbContext . Database . CurrentTransaction ;
50+ return true ;
51+ }
52+
53+ lock ( _transactionLock ) {
54+ _transaction = dbContext . Database . BeginTransaction ( ) ;
55+ return true ;
3656 }
3757 }
3858
@@ -45,36 +65,61 @@ public virtual async ValueTask<bool> TryCreateTransactionAsync(CancellationToken
4565 return true ;
4666 }
4767
48- await _initLock . WaitAsync ( ct ) ;
68+ await _transactionLockAsync . WaitAsync ( ct ) ;
4969 try {
5070 _transaction = await dbContext . Database . BeginTransactionAsync ( ct ) ;
5171 return true ;
5272 }
5373 finally {
54- _initLock . Release ( ) ;
74+ _transactionLockAsync . Release ( ) ;
5575 }
5676 }
77+
78+ public void SaveChanges ( ) {
79+ DbContext dbContext = GetDbContext < TDbContext > ( ) ;
80+ dbContext . SaveChanges ( ) ;
81+ }
5782
5883 public virtual async ValueTask SaveChangesAsync ( CancellationToken ct = default ) {
5984 DbContext dbContext = await GetDbContextAsync ( ct ) ;
6085 await dbContext . SaveChangesAsync ( ct ) ;
6186 }
87+
88+ public bool TryCommitTransaction ( ) {
89+ if ( _transaction == null ) return false ;
90+
91+ lock ( _transactionLock ) {
92+ _transaction . Commit ( ) ;
93+ _transaction . Dispose ( ) ;
94+ _transaction = null ;
95+ return true ;
96+ }
97+ }
6298
6399 public virtual async ValueTask < bool > TryCommitTransactionAsync ( CancellationToken ct = default ) {
64100 if ( _transaction == null ) return false ;
65101
66- await _initLock . WaitAsync ( ct ) ;
102+ await _transactionLockAsync . WaitAsync ( ct ) ;
67103 try {
68104 await _transaction . CommitAsync ( ct ) ;
69105 await _transaction . DisposeAsync ( ) ;
70106 _transaction = null ;
71107 return true ;
72108 }
73109 finally {
74- _initLock . Release ( ) ;
110+ _transactionLockAsync . Release ( ) ;
75111 }
76112 }
77113
114+ public bool TryRollbackTransaction ( ) {
115+ if ( _transaction == null ) return false ;
116+
117+ _transaction . Rollback ( ) ;
118+ _transaction . Dispose ( ) ;
119+ _transaction = null ;
120+
121+ return true ;
122+ }
78123
79124 public virtual async ValueTask < bool > TryRollbackTransactionAsync ( CancellationToken ct = default ) {
80125 if ( _transaction == null ) return false ;
@@ -86,6 +131,14 @@ public virtual async ValueTask<bool> TryRollbackTransactionAsync(CancellationTok
86131 return true ;
87132 }
88133
134+ public bool TryRollbackToSavepoint ( Guid id ) {
135+ if ( _transaction == null ) return false ;
136+ if ( ! _transaction . SupportsSavepoints ) return false ;
137+
138+ _transaction . RollbackToSavepoint ( id . ToString ( "N" ) ) ;
139+ return true ;
140+ }
141+
89142 public virtual async ValueTask < bool > TryRollbackToSavepointAsync ( Guid id , CancellationToken ct = default ) {
90143 if ( _transaction == null ) return false ;
91144 if ( ! _transaction . SupportsSavepoints ) return false ;
@@ -94,6 +147,14 @@ public virtual async ValueTask<bool> TryRollbackToSavepointAsync(Guid id, Cancel
94147
95148 return true ;
96149 }
150+
151+ public bool TryCreateSavepoint ( Guid id ) {
152+ if ( _transaction == null ) return false ;
153+ if ( ! _transaction . SupportsSavepoints ) return false ;
154+
155+ _transaction . CreateSavepoint ( id . ToString ( "N" ) ) ;
156+ return true ;
157+ }
97158
98159 public virtual async ValueTask < bool > TryCreateSavepointAsync ( Guid id , CancellationToken ct = default ) {
99160 if ( _transaction == null ) return false ;
@@ -103,12 +164,34 @@ public virtual async ValueTask<bool> TryCreateSavepointAsync(Guid id, Cancellati
103164
104165 return true ;
105166 }
167+
168+ public T GetDbContext < T > ( ) where T : DbContext {
169+ if ( typeof ( T ) != typeof ( TDbContext ) ) throw new NotSupportedException ( $ "DbContext type '{ typeof ( T ) } ' is not supported by this UnitOfWork.") ;
170+
171+ TDbContext dbContext = GetDbContext ( ) ;
172+
173+ return Unsafe . As < TDbContext , T > ( ref dbContext ) ;
174+ }
106175
107176 public virtual async ValueTask < T > GetDbContextAsync < T > ( CancellationToken ct = default ) where T : DbContext {
108177 if ( typeof ( T ) != typeof ( TDbContext ) ) throw new NotSupportedException ( $ "DbContext type '{ typeof ( T ) } ' is not supported by this UnitOfWork.") ;
109178
110179 TDbContext dbContext = await GetDbContextAsync ( ct ) ;
111- return dbContext as T ?? throw new InvalidCastException ( $ "Cannot cast DbContext of type '{ dbContext . GetType ( ) } ' to '{ typeof ( T ) } '") ;
180+
181+ return Unsafe . As < TDbContext , T > ( ref dbContext ) ;
182+ }
183+
184+ public TRepo GetRepository < TRepo > ( ) where TRepo : class , IUnitOfWorkRepository {
185+ if ( AttachedRepositories . TryGetValue ( typeof ( TRepo ) , out IUnitOfWorkRepository ? cachedRepo ) && cachedRepo is TRepo castedCachedRepo ) return castedCachedRepo ;
186+
187+ var repo = CreateAndAttachRepository < TRepo > ( ) ;
188+
189+ AttachedRepositories . AddOrUpdate (
190+ typeof ( TRepo ) ,
191+ repo ,
192+ ( _ , _ ) => repo
193+ ) ;
194+ return repo ;
112195 }
113196
114197 public virtual async ValueTask < TRepo > GetRepositoryAsync < TRepo > ( CancellationToken ct = default ) where TRepo : class , IUnitOfWorkRepository {
@@ -125,6 +208,14 @@ public virtual async ValueTask<TRepo> GetRepositoryAsync<TRepo>(CancellationToke
125208 return repo ;
126209 }
127210
211+ private TRepo CreateAndAttachRepository < TRepo > ( ) where TRepo : class , IUnitOfWorkRepository {
212+ var repo = serviceScope . ServiceProvider . GetRequiredService < TRepo > ( ) ;
213+ if ( repo is not UnitOfWorkRepository < TDbContext > castedRepo ) throw new InvalidCastException ( $ "Cannot cast repository of type '{ repo . GetType ( ) } ' to '{ typeof ( TRepo ) } '") ;
214+
215+ castedRepo . Attach ( this ) ;
216+ return repo ;
217+ }
218+
128219 private async ValueTask < TRepo > CreateAndAttachRepositoryAsync < TRepo > ( CancellationToken ct = default ) where TRepo : class , IUnitOfWorkRepository {
129220 var repo = serviceScope . ServiceProvider . GetRequiredService < TRepo > ( ) ;
130221 if ( repo is not UnitOfWorkRepository < TDbContext > castedRepo ) throw new InvalidCastException ( $ "Cannot cast repository of type '{ repo . GetType ( ) } ' to '{ typeof ( TRepo ) } '") ;
@@ -135,12 +226,12 @@ private async ValueTask<TRepo> CreateAndAttachRepositoryAsync<TRepo>(Cancellatio
135226
136227 public virtual async ValueTask DisposeAsync ( ) {
137228 // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
138- if ( _initLock == null ) {
229+ if ( _initLockAsync == null ) {
139230 GC . SuppressFinalize ( this ) ;
140231 return ;
141232 }
142233
143- await _initLock . WaitAsync ( ) ;
234+ await _initLockAsync . WaitAsync ( ) ;
144235 try {
145236 if ( _transaction != null ) {
146237 await TryRollbackTransactionAsync ( ) ;
@@ -162,8 +253,8 @@ public virtual async ValueTask DisposeAsync() {
162253 serviceScope . Dispose ( ) ;
163254 }
164255 finally {
165- _initLock . Release ( ) ;
166- _initLock . Dispose ( ) ;
256+ _initLockAsync . Release ( ) ;
257+ _initLockAsync . Dispose ( ) ;
167258 GC . SuppressFinalize ( this ) ;
168259 }
169260 }
0 commit comments