@@ -19,6 +19,11 @@ public EFCoreAdapter(CasbinDbContext<TKey> context) : base(context)
1919 {
2020
2121 }
22+
23+ public EFCoreAdapter ( IServiceProvider serviceProvider ) : base ( serviceProvider )
24+ {
25+
26+ }
2227 }
2328
2429 public class EFCoreAdapter < TKey , TPersistPolicy > : EFCoreAdapter < TKey , TPersistPolicy , CasbinDbContext < TKey > >
@@ -31,6 +36,11 @@ public EFCoreAdapter(CasbinDbContext<TKey> context) : base(context)
3136 {
3237
3338 }
39+
40+ public EFCoreAdapter ( IServiceProvider serviceProvider ) : base ( serviceProvider )
41+ {
42+
43+ }
3444 }
3545
3646 public partial class EFCoreAdapter < TKey , TPersistPolicy , TDbContext > : IAdapter , IFilteredAdapter
@@ -39,12 +49,32 @@ public partial class EFCoreAdapter<TKey, TPersistPolicy, TDbContext> : IAdapter,
3949 where TKey : IEquatable < TKey >
4050 {
4151 private DbSet < TPersistPolicy > _persistPolicies ;
42- protected TDbContext DbContext { get ; }
43- protected DbSet < TPersistPolicy > PersistPolicies => _persistPolicies ??= GetCasbinRuleDbSet ( DbContext ) ;
52+ private readonly IServiceProvider _serviceProvider ;
53+ private readonly bool _useServiceProvider ;
54+
55+ protected TDbContext DbContext { get ; private set ; }
56+ protected DbSet < TPersistPolicy > PersistPolicies => _persistPolicies ??= GetCasbinRuleDbSet ( GetOrResolveDbContext ( ) ) ;
4457
4558 public EFCoreAdapter ( TDbContext context )
4659 {
4760 DbContext = context ?? throw new ArgumentNullException ( nameof ( context ) ) ;
61+ _useServiceProvider = false ;
62+ }
63+
64+ public EFCoreAdapter ( IServiceProvider serviceProvider )
65+ {
66+ _serviceProvider = serviceProvider ?? throw new ArgumentNullException ( nameof ( serviceProvider ) ) ;
67+ _useServiceProvider = true ;
68+ }
69+
70+ private TDbContext GetOrResolveDbContext ( )
71+ {
72+ if ( _useServiceProvider )
73+ {
74+ return _serviceProvider . GetService ( typeof ( TDbContext ) ) as TDbContext
75+ ?? throw new InvalidOperationException ( $ "Unable to resolve service for type '{ typeof ( TDbContext ) } ' from IServiceProvider.") ;
76+ }
77+ return DbContext ;
4878 }
4979
5080 #region Load policy
@@ -71,6 +101,7 @@ public virtual async Task LoadPolicyAsync(IPolicyStore store)
71101
72102 public virtual void SavePolicy ( IPolicyStore store )
73103 {
104+ var dbContext = GetOrResolveDbContext ( ) ;
74105 var persistPolicies = new List < TPersistPolicy > ( ) ;
75106 persistPolicies . ReadPolicyFromCasbinModel ( store ) ;
76107
@@ -81,15 +112,16 @@ public virtual void SavePolicy(IPolicyStore store)
81112
82113 var existRule = PersistPolicies . ToList ( ) ;
83114 PersistPolicies . RemoveRange ( existRule ) ;
84- DbContext . SaveChanges ( ) ;
115+ dbContext . SaveChanges ( ) ;
85116
86117 var saveRules = OnSavePolicy ( store , persistPolicies ) ;
87118 PersistPolicies . AddRange ( saveRules ) ;
88- DbContext . SaveChanges ( ) ;
119+ dbContext . SaveChanges ( ) ;
89120 }
90121
91122 public virtual async Task SavePolicyAsync ( IPolicyStore store )
92123 {
124+ var dbContext = GetOrResolveDbContext ( ) ;
93125 var persistPolicies = new List < TPersistPolicy > ( ) ;
94126 persistPolicies . ReadPolicyFromCasbinModel ( store ) ;
95127
@@ -100,11 +132,11 @@ public virtual async Task SavePolicyAsync(IPolicyStore store)
100132
101133 var existRule = PersistPolicies . ToList ( ) ;
102134 PersistPolicies . RemoveRange ( existRule ) ;
103- await DbContext . SaveChangesAsync ( ) ;
135+ await dbContext . SaveChangesAsync ( ) ;
104136
105137 var saveRules = OnSavePolicy ( store , persistPolicies ) ;
106138 await PersistPolicies . AddRangeAsync ( saveRules ) ;
107- await DbContext . SaveChangesAsync ( ) ;
139+ await dbContext . SaveChangesAsync ( ) ;
108140 }
109141
110142 #endregion
@@ -113,6 +145,7 @@ public virtual async Task SavePolicyAsync(IPolicyStore store)
113145
114146 public virtual void AddPolicy ( string section , string policyType , IPolicyValues values )
115147 {
148+ var dbContext = GetOrResolveDbContext ( ) ;
116149 if ( values . Count is 0 )
117150 {
118151 return ;
@@ -127,11 +160,12 @@ public virtual void AddPolicy(string section, string policyType, IPolicyValues v
127160 }
128161
129162 InternalAddPolicy ( section , policyType , values ) ;
130- DbContext . SaveChanges ( ) ;
163+ dbContext . SaveChanges ( ) ;
131164 }
132165
133166 public virtual async Task AddPolicyAsync ( string section , string policyType , IPolicyValues values )
134167 {
168+ var dbContext = GetOrResolveDbContext ( ) ;
135169 if ( values . Count is 0 )
136170 {
137171 return ;
@@ -146,27 +180,29 @@ public virtual async Task AddPolicyAsync(string section, string policyType, IPol
146180 }
147181
148182 await InternalAddPolicyAsync ( section , policyType , values ) ;
149- await DbContext . SaveChangesAsync ( ) ;
183+ await dbContext . SaveChangesAsync ( ) ;
150184 }
151185
152186 public virtual void AddPolicies ( string section , string policyType , IReadOnlyList < IPolicyValues > valuesList )
153187 {
188+ var dbContext = GetOrResolveDbContext ( ) ;
154189 if ( valuesList . Count is 0 )
155190 {
156191 return ;
157192 }
158193 InternalAddPolicies ( section , policyType , valuesList ) ;
159- DbContext . SaveChanges ( ) ;
194+ dbContext . SaveChanges ( ) ;
160195 }
161196
162197 public virtual async Task AddPoliciesAsync ( string section , string policyType , IReadOnlyList < IPolicyValues > valuesList )
163198 {
199+ var dbContext = GetOrResolveDbContext ( ) ;
164200 if ( valuesList . Count is 0 )
165201 {
166202 return ;
167203 }
168204 await InternalAddPoliciesAsync ( section , policyType , valuesList ) ;
169- await DbContext . SaveChangesAsync ( ) ;
205+ await dbContext . SaveChangesAsync ( ) ;
170206 }
171207
172208 #endregion
@@ -175,63 +211,69 @@ public virtual async Task AddPoliciesAsync(string section, string policyType, IR
175211
176212 public virtual void RemovePolicy ( string section , string policyType , IPolicyValues values )
177213 {
214+ var dbContext = GetOrResolveDbContext ( ) ;
178215 if ( values . Count is 0 )
179216 {
180217 return ;
181218 }
182219 InternalRemovePolicy ( section , policyType , values ) ;
183- DbContext . SaveChanges ( ) ;
220+ dbContext . SaveChanges ( ) ;
184221 }
185222
186223 public virtual async Task RemovePolicyAsync ( string section , string policyType , IPolicyValues values )
187224 {
225+ var dbContext = GetOrResolveDbContext ( ) ;
188226 if ( values . Count is 0 )
189227 {
190228 return ;
191229 }
192230 InternalRemovePolicy ( section , policyType , values ) ;
193- await DbContext . SaveChangesAsync ( ) ;
231+ await dbContext . SaveChangesAsync ( ) ;
194232 }
195233
196234 public virtual void RemoveFilteredPolicy ( string section , string policyType , int fieldIndex , IPolicyValues fieldValues )
197235 {
236+ var dbContext = GetOrResolveDbContext ( ) ;
198237 if ( fieldValues . Count is 0 )
199238 {
200239 return ;
201240 }
202241 InternalRemoveFilteredPolicy ( section , policyType , fieldIndex , fieldValues ) ;
203- DbContext . SaveChanges ( ) ;
242+ dbContext . SaveChanges ( ) ;
204243 }
205244
206245 public virtual async Task RemoveFilteredPolicyAsync ( string section , string policyType , int fieldIndex , IPolicyValues fieldValues )
207246 {
247+ var dbContext = GetOrResolveDbContext ( ) ;
208248 if ( fieldValues . Count is 0 )
209249 {
210250 return ;
211251 }
212252 InternalRemoveFilteredPolicy ( section , policyType , fieldIndex , fieldValues ) ;
213- await DbContext . SaveChangesAsync ( ) ;
253+ await dbContext . SaveChangesAsync ( ) ;
214254 }
215255
216256
217257 public virtual void RemovePolicies ( string section , string policyType , IReadOnlyList < IPolicyValues > valuesList )
218258 {
259+ var dbContext = GetOrResolveDbContext ( ) ;
219260 if ( valuesList . Count is 0 )
220261 {
221262 return ;
222263 }
223264 InternalRemovePolicies ( section , policyType , valuesList ) ;
224- DbContext . SaveChanges ( ) ;
265+ dbContext . SaveChanges ( ) ;
225266 }
226267
227268 public virtual async Task RemovePoliciesAsync ( string section , string policyType , IReadOnlyList < IPolicyValues > valuesList )
228269 {
270+ var dbContext = GetOrResolveDbContext ( ) ;
229271 if ( valuesList . Count is 0 )
230272 {
231273 return ;
232274 }
233275 InternalRemovePolicies ( section , policyType , valuesList ) ;
234- await DbContext . SaveChangesAsync ( ) ;
276+ await dbContext . SaveChangesAsync ( ) ;
235277 }
236278
237279 #endregion
@@ -240,49 +282,53 @@ public virtual async Task RemovePoliciesAsync(string section, string policyType,
240282
241283 public void UpdatePolicy ( string section , string policyType , IPolicyValues oldValues , IPolicyValues newValues )
242284 {
285+ var dbContext = GetOrResolveDbContext ( ) ;
243286 if ( newValues . Count is 0 )
244287 {
245288 return ;
246289 }
247- using var transaction = DbContext . Database . BeginTransaction ( ) ;
290+ using var transaction = dbContext . Database . BeginTransaction ( ) ;
248291 InternalUpdatePolicy ( section , policyType , oldValues , newValues ) ;
249- DbContext . SaveChanges ( ) ;
292+ dbContext . SaveChanges ( ) ;
250293 transaction . Commit ( ) ;
251294 }
252295
253296 public async Task UpdatePolicyAsync ( string section , string policyType , IPolicyValues oldValues , IPolicyValues newValues )
254297 {
298+ var dbContext = GetOrResolveDbContext ( ) ;
255299 if ( newValues . Count is 0 )
256300 {
257301 return ;
258302 }
259- await using var transaction = await DbContext . Database . BeginTransactionAsync ( ) ;
303+ await using var transaction = await dbContext . Database . BeginTransactionAsync ( ) ;
260304 await InternalUpdatePolicyAsync ( section , policyType , oldValues , newValues ) ;
261- await DbContext . SaveChangesAsync ( ) ;
305+ await dbContext . SaveChangesAsync ( ) ;
262306 await transaction . CommitAsync ( ) ;
263307 }
264308
265309 public void UpdatePolicies ( string section , string policyType , IReadOnlyList < IPolicyValues > oldValuesList , IReadOnlyList < IPolicyValues > newValuesList )
266310 {
311+ var dbContext = GetOrResolveDbContext ( ) ;
267312 if ( newValuesList . Count is 0 )
268313 {
269314 return ;
270315 }
271- using var transaction = DbContext . Database . BeginTransaction ( ) ;
316+ using var transaction = dbContext . Database . BeginTransaction ( ) ;
272317 InternalUpdatePolicies ( section , policyType , oldValuesList , newValuesList ) ;
273- DbContext . SaveChanges ( ) ;
318+ dbContext . SaveChanges ( ) ;
274319 transaction . Commit ( ) ;
275320 }
276321
277322 public async Task UpdatePoliciesAsync ( string section , string policyType , IReadOnlyList < IPolicyValues > oldValuesList , IReadOnlyList < IPolicyValues > newValuesList )
278323 {
324+ var dbContext = GetOrResolveDbContext ( ) ;
279325 if ( newValuesList . Count is 0 )
280326 {
281327 return ;
282328 }
283- await using var transaction = await DbContext . Database . BeginTransactionAsync ( ) ;
329+ await using var transaction = await dbContext . Database . BeginTransactionAsync ( ) ;
284330 await InternalUpdatePoliciesAsync ( section , policyType , oldValuesList , newValuesList ) ;
285- await DbContext . SaveChangesAsync ( ) ;
331+ await dbContext . SaveChangesAsync ( ) ;
286332 await transaction . CommitAsync ( ) ;
287333 }
288334
0 commit comments