diff --git a/Casbin.Persist.Adapter.EFCore.UnitTest/DependencyInjectionTest.cs b/Casbin.Persist.Adapter.EFCore.UnitTest/DependencyInjectionTest.cs index 83bd9b3..2f61846 100644 --- a/Casbin.Persist.Adapter.EFCore.UnitTest/DependencyInjectionTest.cs +++ b/Casbin.Persist.Adapter.EFCore.UnitTest/DependencyInjectionTest.cs @@ -1,16 +1,19 @@ using Microsoft.Extensions.DependencyInjection; using Casbin.Persist.Adapter.EFCore.UnitTest.Fixtures; using Xunit; +using Casbin.Model; namespace Casbin.Persist.Adapter.EFCore.UnitTest { - public class DependencyInjectionTest : IClassFixture + public class DependencyInjectionTest : IClassFixture, IClassFixture { private readonly TestHostFixture _testHostFixture; + private readonly ModelProvideFixture _modelProvideFixture; - public DependencyInjectionTest(TestHostFixture testHostFixture) + public DependencyInjectionTest(TestHostFixture testHostFixture, ModelProvideFixture modelProvideFixture) { _testHostFixture = testHostFixture; + _modelProvideFixture = modelProvideFixture; } [Fact] @@ -27,5 +30,62 @@ public void ShouldResolveEfCoreAdapter() var adapter = _testHostFixture.Services.GetService(); Assert.NotNull(adapter); } + + [Fact] + public void ShouldUseAdapterAcrossMultipleScopesWithDbContextDirectly() + { + // Simulate the issue where an adapter is created in one scope + // but used in another scope (like with casbin-aspnetcore) + IAdapter adapter; + + // Create adapter with DbContext in first scope + using (var scope1 = _testHostFixture.Services.CreateScope()) + { + var dbContext = scope1.ServiceProvider.GetRequiredService>(); + dbContext.Database.EnsureCreated(); + adapter = new EFCoreAdapter(dbContext); + } + + // Try to use adapter after scope is disposed - this should throw ObjectDisposedException + var model = _modelProvideFixture.GetNewRbacModel(); + Assert.Throws(() => adapter.LoadPolicy(model)); + } + + [Fact] + public void ShouldUseAdapterAcrossMultipleScopesWithServiceProvider() + { + // Create adapter with IServiceProvider - this should work across multiple scopes + var adapter = new EFCoreAdapter(_testHostFixture.Services); + + // Ensure database is created in first scope + using (var scope1 = _testHostFixture.Services.CreateScope()) + { + var dbContext = scope1.ServiceProvider.GetRequiredService>(); + dbContext.Database.EnsureCreated(); + } + + // Use adapter after scope is disposed - this should work with IServiceProvider + var model = _modelProvideFixture.GetNewRbacModel(); + adapter.LoadPolicy(model); // Should not throw + } + + [Fact] + public void ShouldResolveAdapterRegisteredWithExtensionMethod() + { + // The adapter registered via AddEFCoreAdapter extension should be resolvable + var adapter = _testHostFixture.Services.GetService(); + Assert.NotNull(adapter); + + // Create scope to ensure database exists + using (var scope = _testHostFixture.Services.CreateScope()) + { + var dbContext = scope.ServiceProvider.GetRequiredService>(); + dbContext.Database.EnsureCreated(); + } + + // Should be able to use the adapter + var model = _modelProvideFixture.GetNewRbacModel(); + adapter.LoadPolicy(model); // Should not throw + } } } \ No newline at end of file diff --git a/Casbin.Persist.Adapter.EFCore.UnitTest/Fixtures/TestHostFixture.cs b/Casbin.Persist.Adapter.EFCore.UnitTest/Fixtures/TestHostFixture.cs index 9b99d7d..97faf7d 100644 --- a/Casbin.Persist.Adapter.EFCore.UnitTest/Fixtures/TestHostFixture.cs +++ b/Casbin.Persist.Adapter.EFCore.UnitTest/Fixtures/TestHostFixture.cs @@ -2,6 +2,7 @@ using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.DependencyInjection; using System; +using Casbin.Persist.Adapter.EFCore.Extensions; namespace Casbin.Persist.Adapter.EFCore.UnitTest.Fixtures { @@ -14,7 +15,7 @@ public TestHostFixture() { options.UseSqlite("Data Source=CasbinHostTest.db"); }) - .AddScoped>() + .AddEFCoreAdapter() .BuildServiceProvider(); Server = new TestServer(Services); } diff --git a/Casbin.Persist.Adapter.EFCore/EFCoreAdapter.cs b/Casbin.Persist.Adapter.EFCore/EFCoreAdapter.cs index ef4b5af..3994b18 100644 --- a/Casbin.Persist.Adapter.EFCore/EFCoreAdapter.cs +++ b/Casbin.Persist.Adapter.EFCore/EFCoreAdapter.cs @@ -19,6 +19,11 @@ public EFCoreAdapter(CasbinDbContext context) : base(context) { } + + public EFCoreAdapter(IServiceProvider serviceProvider) : base(serviceProvider) + { + + } } public class EFCoreAdapter : EFCoreAdapter> @@ -31,6 +36,11 @@ public EFCoreAdapter(CasbinDbContext context) : base(context) { } + + public EFCoreAdapter(IServiceProvider serviceProvider) : base(serviceProvider) + { + + } } public partial class EFCoreAdapter : IAdapter, IFilteredAdapter @@ -39,12 +49,32 @@ public partial class EFCoreAdapter : IAdapter, where TKey : IEquatable { private DbSet _persistPolicies; - protected TDbContext DbContext { get; } - protected DbSet PersistPolicies => _persistPolicies ??= GetCasbinRuleDbSet(DbContext); + private readonly IServiceProvider _serviceProvider; + private readonly bool _useServiceProvider; + + protected TDbContext DbContext { get; private set; } + protected DbSet PersistPolicies => _persistPolicies ??= GetCasbinRuleDbSet(GetOrResolveDbContext()); public EFCoreAdapter(TDbContext context) { DbContext = context ?? throw new ArgumentNullException(nameof(context)); + _useServiceProvider = false; + } + + public EFCoreAdapter(IServiceProvider serviceProvider) + { + _serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider)); + _useServiceProvider = true; + } + + private TDbContext GetOrResolveDbContext() + { + if (_useServiceProvider) + { + return _serviceProvider.GetService(typeof(TDbContext)) as TDbContext + ?? throw new InvalidOperationException($"Unable to resolve service for type '{typeof(TDbContext)}' from IServiceProvider."); + } + return DbContext; } #region Load policy @@ -71,6 +101,7 @@ public virtual async Task LoadPolicyAsync(IPolicyStore store) public virtual void SavePolicy(IPolicyStore store) { + var dbContext = GetOrResolveDbContext(); var persistPolicies = new List(); persistPolicies.ReadPolicyFromCasbinModel(store); @@ -81,15 +112,16 @@ public virtual void SavePolicy(IPolicyStore store) var existRule = PersistPolicies.ToList(); PersistPolicies.RemoveRange(existRule); - DbContext.SaveChanges(); + dbContext.SaveChanges(); var saveRules = OnSavePolicy(store, persistPolicies); PersistPolicies.AddRange(saveRules); - DbContext.SaveChanges(); + dbContext.SaveChanges(); } public virtual async Task SavePolicyAsync(IPolicyStore store) { + var dbContext = GetOrResolveDbContext(); var persistPolicies = new List(); persistPolicies.ReadPolicyFromCasbinModel(store); @@ -100,11 +132,11 @@ public virtual async Task SavePolicyAsync(IPolicyStore store) var existRule = PersistPolicies.ToList(); PersistPolicies.RemoveRange(existRule); - await DbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); var saveRules = OnSavePolicy(store, persistPolicies); await PersistPolicies.AddRangeAsync(saveRules); - await DbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); } #endregion @@ -113,6 +145,7 @@ public virtual async Task SavePolicyAsync(IPolicyStore store) public virtual void AddPolicy(string section, string policyType, IPolicyValues values) { + var dbContext = GetOrResolveDbContext(); if (values.Count is 0) { return; @@ -127,11 +160,12 @@ public virtual void AddPolicy(string section, string policyType, IPolicyValues v } InternalAddPolicy(section, policyType, values); - DbContext.SaveChanges(); + dbContext.SaveChanges(); } public virtual async Task AddPolicyAsync(string section, string policyType, IPolicyValues values) { + var dbContext = GetOrResolveDbContext(); if (values.Count is 0) { return; @@ -146,27 +180,29 @@ public virtual async Task AddPolicyAsync(string section, string policyType, IPol } await InternalAddPolicyAsync(section, policyType, values); - await DbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); } public virtual void AddPolicies(string section, string policyType, IReadOnlyList valuesList) { + var dbContext = GetOrResolveDbContext(); if (valuesList.Count is 0) { return; } InternalAddPolicies(section, policyType, valuesList); - DbContext.SaveChanges(); + dbContext.SaveChanges(); } public virtual async Task AddPoliciesAsync(string section, string policyType, IReadOnlyList valuesList) { + var dbContext = GetOrResolveDbContext(); if (valuesList.Count is 0) { return; } await InternalAddPoliciesAsync(section, policyType, valuesList); - await DbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); } #endregion @@ -175,63 +211,69 @@ public virtual async Task AddPoliciesAsync(string section, string policyType, IR public virtual void RemovePolicy(string section, string policyType, IPolicyValues values) { + var dbContext = GetOrResolveDbContext(); if (values.Count is 0) { return; } InternalRemovePolicy(section, policyType, values); - DbContext.SaveChanges(); + dbContext.SaveChanges(); } public virtual async Task RemovePolicyAsync(string section, string policyType, IPolicyValues values) { + var dbContext = GetOrResolveDbContext(); if (values.Count is 0) { return; } InternalRemovePolicy(section, policyType, values); - await DbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); } public virtual void RemoveFilteredPolicy(string section, string policyType, int fieldIndex, IPolicyValues fieldValues) { + var dbContext = GetOrResolveDbContext(); if (fieldValues.Count is 0) { return; } InternalRemoveFilteredPolicy(section, policyType, fieldIndex, fieldValues); - DbContext.SaveChanges(); + dbContext.SaveChanges(); } public virtual async Task RemoveFilteredPolicyAsync(string section, string policyType, int fieldIndex, IPolicyValues fieldValues) { + var dbContext = GetOrResolveDbContext(); if (fieldValues.Count is 0) { return; } InternalRemoveFilteredPolicy(section, policyType, fieldIndex, fieldValues); - await DbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); } public virtual void RemovePolicies(string section, string policyType, IReadOnlyList valuesList) { + var dbContext = GetOrResolveDbContext(); if (valuesList.Count is 0) { return; } InternalRemovePolicies(section, policyType, valuesList); - DbContext.SaveChanges(); + dbContext.SaveChanges(); } public virtual async Task RemovePoliciesAsync(string section, string policyType, IReadOnlyList valuesList) { + var dbContext = GetOrResolveDbContext(); if (valuesList.Count is 0) { return; } InternalRemovePolicies(section, policyType, valuesList); - await DbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); } #endregion @@ -240,49 +282,53 @@ public virtual async Task RemovePoliciesAsync(string section, string policyType, public void UpdatePolicy(string section, string policyType, IPolicyValues oldValues, IPolicyValues newValues) { + var dbContext = GetOrResolveDbContext(); if (newValues.Count is 0) { return; } - using var transaction = DbContext.Database.BeginTransaction(); + using var transaction = dbContext.Database.BeginTransaction(); InternalUpdatePolicy(section, policyType, oldValues, newValues); - DbContext.SaveChanges(); + dbContext.SaveChanges(); transaction.Commit(); } public async Task UpdatePolicyAsync(string section, string policyType, IPolicyValues oldValues, IPolicyValues newValues) { + var dbContext = GetOrResolveDbContext(); if (newValues.Count is 0) { return; } - await using var transaction = await DbContext.Database.BeginTransactionAsync(); + await using var transaction = await dbContext.Database.BeginTransactionAsync(); await InternalUpdatePolicyAsync(section, policyType, oldValues, newValues); - await DbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); await transaction.CommitAsync(); } public void UpdatePolicies(string section, string policyType, IReadOnlyList oldValuesList, IReadOnlyList newValuesList) { + var dbContext = GetOrResolveDbContext(); if (newValuesList.Count is 0) { return; } - using var transaction = DbContext.Database.BeginTransaction(); + using var transaction = dbContext.Database.BeginTransaction(); InternalUpdatePolicies(section, policyType, oldValuesList, newValuesList); - DbContext.SaveChanges(); + dbContext.SaveChanges(); transaction.Commit(); } public async Task UpdatePoliciesAsync(string section, string policyType, IReadOnlyList oldValuesList, IReadOnlyList newValuesList) { + var dbContext = GetOrResolveDbContext(); if (newValuesList.Count is 0) { return; } - await using var transaction = await DbContext.Database.BeginTransactionAsync(); + await using var transaction = await dbContext.Database.BeginTransactionAsync(); await InternalUpdatePoliciesAsync(section, policyType, oldValuesList, newValuesList); - await DbContext.SaveChangesAsync(); + await dbContext.SaveChangesAsync(); await transaction.CommitAsync(); } diff --git a/Casbin.Persist.Adapter.EFCore/Extensions/ServiceCollectionExtensions.cs b/Casbin.Persist.Adapter.EFCore/Extensions/ServiceCollectionExtensions.cs new file mode 100644 index 0000000..beaac1a --- /dev/null +++ b/Casbin.Persist.Adapter.EFCore/Extensions/ServiceCollectionExtensions.cs @@ -0,0 +1,59 @@ +using System; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.DependencyInjection.Extensions; + +namespace Casbin.Persist.Adapter.EFCore.Extensions +{ + /// + /// Extension methods for registering EFCoreAdapter with dependency injection. + /// + public static class ServiceCollectionExtensions + { + /// + /// Adds the EFCoreAdapter to the service collection. + /// The adapter will resolve the DbContext from the service provider on each operation, + /// preventing issues with disposed contexts when used with long-lived services. + /// + /// The type of the primary key for the policy entities. + /// The service collection. + /// The service lifetime for the adapter. Default is Scoped. + /// The service collection for chaining. + public static IServiceCollection AddEFCoreAdapter( + this IServiceCollection services, + ServiceLifetime lifetime = ServiceLifetime.Scoped) where TKey : IEquatable + { + var descriptor = new ServiceDescriptor( + typeof(IAdapter), + sp => new EFCoreAdapter(sp), + lifetime); + + services.TryAdd(descriptor); + return services; + } + + /// + /// Adds the EFCoreAdapter with custom policy type to the service collection. + /// The adapter will resolve the DbContext from the service provider on each operation, + /// preventing issues with disposed contexts when used with long-lived services. + /// + /// The type of the primary key for the policy entities. + /// The type of the persist policy entity. + /// The service collection. + /// The service lifetime for the adapter. Default is Scoped. + /// The service collection for chaining. + public static IServiceCollection AddEFCoreAdapter( + this IServiceCollection services, + ServiceLifetime lifetime = ServiceLifetime.Scoped) + where TKey : IEquatable + where TPersistPolicy : class, IEFCorePersistPolicy, new() + { + var descriptor = new ServiceDescriptor( + typeof(IAdapter), + sp => new EFCoreAdapter(sp), + lifetime); + + services.TryAdd(descriptor); + return services; + } + } +} diff --git a/README.md b/README.md index b12d493..80bca69 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,51 @@ namespace ConsoleAppExample } ``` +## Using with Dependency Injection + +When using the adapter with dependency injection (e.g., in ASP.NET Core), you should use the `IServiceProvider` constructor or the extension method to avoid issues with disposed DbContext instances. + +### Recommended Approach (Using Extension Method) + +```csharp +using Casbin.Persist.Adapter.EFCore; +using Casbin.Persist.Adapter.EFCore.Extensions; +using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.DependencyInjection; + +// Register services +services.AddDbContext>(options => + options.UseSqlServer(connectionString)); + +// Register the adapter using the extension method +services.AddEFCoreAdapter(); + +// The adapter will resolve the DbContext from the service provider on each operation, +// preventing issues with disposed contexts when used with long-lived services. +``` + +### Alternative Approach (Using IServiceProvider Constructor) + +```csharp +// In your startup configuration +services.AddDbContext>(options => + options.UseSqlServer(connectionString)); + +services.AddCasbinAuthorization(options => +{ + options.DefaultModelPath = "model.conf"; + + // Use the IServiceProvider constructor + options.DefaultEnforcerFactory = (sp, model) => + new Enforcer(model, new EFCoreAdapter(sp)); +}); +``` + +This approach resolves the DbContext from the service provider on each database operation, ensuring that: +- The adapter works correctly with scoped DbContext instances +- No `ObjectDisposedException` is thrown when the adapter outlives the scope that created it +- The adapter can be used in long-lived services like singletons + ## Getting Help - [Casbin.NET](https://github.com/casbin/Casbin.NET)