Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions Casbin.Persist.Adapter.EFCore.UnitTest/DependencyInjectionTest.cs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
using Microsoft.Extensions.DependencyInjection;
using Casbin.Persist.Adapter.EFCore.UnitTest.Fixtures;
using Xunit;
using Casbin.Model;

namespace Casbin.Persist.Adapter.EFCore.UnitTest
{
public class DependencyInjectionTest : IClassFixture<TestHostFixture>
public class DependencyInjectionTest : IClassFixture<TestHostFixture>, IClassFixture<ModelProvideFixture>
{
private readonly TestHostFixture _testHostFixture;
private readonly ModelProvideFixture _modelProvideFixture;

public DependencyInjectionTest(TestHostFixture testHostFixture)
public DependencyInjectionTest(TestHostFixture testHostFixture, ModelProvideFixture modelProvideFixture)
{
_testHostFixture = testHostFixture;
_modelProvideFixture = modelProvideFixture;
}

[Fact]
Expand All @@ -27,5 +30,62 @@ public void ShouldResolveEfCoreAdapter()
var adapter = _testHostFixture.Services.GetService<IAdapter>();
Assert.NotNull(adapter);
}

[Fact]
public void ShouldUseAdapterAcrossMultipleScopesWithDbContextDirectly()
{
// Simulate the issue where an adapter is created in one scope
// but used in another scope (like with casbin-aspnetcore)
IAdapter adapter;

// Create adapter with DbContext in first scope
using (var scope1 = _testHostFixture.Services.CreateScope())
{
var dbContext = scope1.ServiceProvider.GetRequiredService<CasbinDbContext<int>>();
dbContext.Database.EnsureCreated();
adapter = new EFCoreAdapter<int>(dbContext);
}

// Try to use adapter after scope is disposed - this should throw ObjectDisposedException
var model = _modelProvideFixture.GetNewRbacModel();
Assert.Throws<System.ObjectDisposedException>(() => adapter.LoadPolicy(model));
}

[Fact]
public void ShouldUseAdapterAcrossMultipleScopesWithServiceProvider()
{
// Create adapter with IServiceProvider - this should work across multiple scopes
var adapter = new EFCoreAdapter<int>(_testHostFixture.Services);

// Ensure database is created in first scope
using (var scope1 = _testHostFixture.Services.CreateScope())
{
var dbContext = scope1.ServiceProvider.GetRequiredService<CasbinDbContext<int>>();
dbContext.Database.EnsureCreated();
}

// Use adapter after scope is disposed - this should work with IServiceProvider
var model = _modelProvideFixture.GetNewRbacModel();
adapter.LoadPolicy(model); // Should not throw
}

[Fact]
public void ShouldResolveAdapterRegisteredWithExtensionMethod()
{
// The adapter registered via AddEFCoreAdapter extension should be resolvable
var adapter = _testHostFixture.Services.GetService<IAdapter>();
Assert.NotNull(adapter);

// Create scope to ensure database exists
using (var scope = _testHostFixture.Services.CreateScope())
{
var dbContext = scope.ServiceProvider.GetRequiredService<CasbinDbContext<int>>();
dbContext.Database.EnsureCreated();
}

// Should be able to use the adapter
var model = _modelProvideFixture.GetNewRbacModel();
adapter.LoadPolicy(model); // Should not throw
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.DependencyInjection;
using System;
using Casbin.Persist.Adapter.EFCore.Extensions;

namespace Casbin.Persist.Adapter.EFCore.UnitTest.Fixtures
{
Expand All @@ -14,7 +15,7 @@ public TestHostFixture()
{
options.UseSqlite("Data Source=CasbinHostTest.db");
})
.AddScoped<IAdapter, EFCoreAdapter<int>>()
.AddEFCoreAdapter<int>()
.BuildServiceProvider();
Server = new TestServer(Services);
}
Expand Down
94 changes: 70 additions & 24 deletions Casbin.Persist.Adapter.EFCore/EFCoreAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ public EFCoreAdapter(CasbinDbContext<TKey> context) : base(context)
{

}

public EFCoreAdapter(IServiceProvider serviceProvider) : base(serviceProvider)
{

}
}

public class EFCoreAdapter<TKey, TPersistPolicy> : EFCoreAdapter<TKey, TPersistPolicy, CasbinDbContext<TKey>>
Expand All @@ -31,6 +36,11 @@ public EFCoreAdapter(CasbinDbContext<TKey> context) : base(context)
{

}

public EFCoreAdapter(IServiceProvider serviceProvider) : base(serviceProvider)
{

}
}

public partial class EFCoreAdapter<TKey, TPersistPolicy, TDbContext> : IAdapter, IFilteredAdapter
Expand All @@ -39,12 +49,32 @@ public partial class EFCoreAdapter<TKey, TPersistPolicy, TDbContext> : IAdapter,
where TKey : IEquatable<TKey>
{
private DbSet<TPersistPolicy> _persistPolicies;
protected TDbContext DbContext { get; }
protected DbSet<TPersistPolicy> PersistPolicies => _persistPolicies ??= GetCasbinRuleDbSet(DbContext);
private readonly IServiceProvider _serviceProvider;
private readonly bool _useServiceProvider;

protected TDbContext DbContext { get; private set; }
protected DbSet<TPersistPolicy> PersistPolicies => _persistPolicies ??= GetCasbinRuleDbSet(GetOrResolveDbContext());

public EFCoreAdapter(TDbContext context)
{
DbContext = context ?? throw new ArgumentNullException(nameof(context));
_useServiceProvider = false;
}

public EFCoreAdapter(IServiceProvider serviceProvider)
{
_serviceProvider = serviceProvider ?? throw new ArgumentNullException(nameof(serviceProvider));
_useServiceProvider = true;
}

private TDbContext GetOrResolveDbContext()
{
if (_useServiceProvider)
{
return _serviceProvider.GetService(typeof(TDbContext)) as TDbContext
?? throw new InvalidOperationException($"Unable to resolve service for type '{typeof(TDbContext)}' from IServiceProvider.");
}
return DbContext;
}

#region Load policy
Expand All @@ -71,6 +101,7 @@ public virtual async Task LoadPolicyAsync(IPolicyStore store)

public virtual void SavePolicy(IPolicyStore store)
{
var dbContext = GetOrResolveDbContext();
var persistPolicies = new List<TPersistPolicy>();
persistPolicies.ReadPolicyFromCasbinModel(store);

Expand All @@ -81,15 +112,16 @@ public virtual void SavePolicy(IPolicyStore store)

var existRule = PersistPolicies.ToList();
PersistPolicies.RemoveRange(existRule);
DbContext.SaveChanges();
dbContext.SaveChanges();

var saveRules = OnSavePolicy(store, persistPolicies);
PersistPolicies.AddRange(saveRules);
DbContext.SaveChanges();
dbContext.SaveChanges();
}

public virtual async Task SavePolicyAsync(IPolicyStore store)
{
var dbContext = GetOrResolveDbContext();
var persistPolicies = new List<TPersistPolicy>();
persistPolicies.ReadPolicyFromCasbinModel(store);

Expand All @@ -100,11 +132,11 @@ public virtual async Task SavePolicyAsync(IPolicyStore store)

var existRule = PersistPolicies.ToList();
PersistPolicies.RemoveRange(existRule);
await DbContext.SaveChangesAsync();
await dbContext.SaveChangesAsync();

var saveRules = OnSavePolicy(store, persistPolicies);
await PersistPolicies.AddRangeAsync(saveRules);
await DbContext.SaveChangesAsync();
await dbContext.SaveChangesAsync();
}

#endregion
Expand All @@ -113,6 +145,7 @@ public virtual async Task SavePolicyAsync(IPolicyStore store)

public virtual void AddPolicy(string section, string policyType, IPolicyValues values)
{
var dbContext = GetOrResolveDbContext();
if (values.Count is 0)
{
return;
Expand All @@ -127,11 +160,12 @@ public virtual void AddPolicy(string section, string policyType, IPolicyValues v
}

InternalAddPolicy(section, policyType, values);
DbContext.SaveChanges();
dbContext.SaveChanges();
}

public virtual async Task AddPolicyAsync(string section, string policyType, IPolicyValues values)
{
var dbContext = GetOrResolveDbContext();
if (values.Count is 0)
{
return;
Expand All @@ -146,27 +180,29 @@ public virtual async Task AddPolicyAsync(string section, string policyType, IPol
}

await InternalAddPolicyAsync(section, policyType, values);
await DbContext.SaveChangesAsync();
await dbContext.SaveChangesAsync();
}

public virtual void AddPolicies(string section, string policyType, IReadOnlyList<IPolicyValues> valuesList)
{
var dbContext = GetOrResolveDbContext();
if (valuesList.Count is 0)
{
return;
}
InternalAddPolicies(section, policyType, valuesList);
DbContext.SaveChanges();
dbContext.SaveChanges();
}

public virtual async Task AddPoliciesAsync(string section, string policyType, IReadOnlyList<IPolicyValues> valuesList)
{
var dbContext = GetOrResolveDbContext();
if (valuesList.Count is 0)
{
return;
}
await InternalAddPoliciesAsync(section, policyType, valuesList);
await DbContext.SaveChangesAsync();
await dbContext.SaveChangesAsync();
}

#endregion
Expand All @@ -175,63 +211,69 @@ public virtual async Task AddPoliciesAsync(string section, string policyType, IR

public virtual void RemovePolicy(string section, string policyType, IPolicyValues values)
{
var dbContext = GetOrResolveDbContext();
if (values.Count is 0)
{
return;
}
InternalRemovePolicy(section, policyType, values);
DbContext.SaveChanges();
dbContext.SaveChanges();
}

public virtual async Task RemovePolicyAsync(string section, string policyType, IPolicyValues values)
{
var dbContext = GetOrResolveDbContext();
if (values.Count is 0)
{
return;
}
InternalRemovePolicy(section, policyType, values);
await DbContext.SaveChangesAsync();
await dbContext.SaveChangesAsync();
}

public virtual void RemoveFilteredPolicy(string section, string policyType, int fieldIndex, IPolicyValues fieldValues)
{
var dbContext = GetOrResolveDbContext();
if (fieldValues.Count is 0)
{
return;
}
InternalRemoveFilteredPolicy(section, policyType, fieldIndex, fieldValues);
DbContext.SaveChanges();
dbContext.SaveChanges();
}

public virtual async Task RemoveFilteredPolicyAsync(string section, string policyType, int fieldIndex, IPolicyValues fieldValues)
{
var dbContext = GetOrResolveDbContext();
if (fieldValues.Count is 0)
{
return;
}
InternalRemoveFilteredPolicy(section, policyType, fieldIndex, fieldValues);
await DbContext.SaveChangesAsync();
await dbContext.SaveChangesAsync();
}


public virtual void RemovePolicies(string section, string policyType, IReadOnlyList<IPolicyValues> valuesList)
{
var dbContext = GetOrResolveDbContext();
if (valuesList.Count is 0)
{
return;
}
InternalRemovePolicies(section, policyType, valuesList);
DbContext.SaveChanges();
dbContext.SaveChanges();
}

public virtual async Task RemovePoliciesAsync(string section, string policyType, IReadOnlyList<IPolicyValues> valuesList)
{
var dbContext = GetOrResolveDbContext();
if (valuesList.Count is 0)
{
return;
}
InternalRemovePolicies(section, policyType, valuesList);
await DbContext.SaveChangesAsync();
await dbContext.SaveChangesAsync();
}

#endregion
Expand All @@ -240,49 +282,53 @@ public virtual async Task RemovePoliciesAsync(string section, string policyType,

public void UpdatePolicy(string section, string policyType, IPolicyValues oldValues, IPolicyValues newValues)
{
var dbContext = GetOrResolveDbContext();
if (newValues.Count is 0)
{
return;
}
using var transaction = DbContext.Database.BeginTransaction();
using var transaction = dbContext.Database.BeginTransaction();
InternalUpdatePolicy(section, policyType, oldValues, newValues);
DbContext.SaveChanges();
dbContext.SaveChanges();
transaction.Commit();
}

public async Task UpdatePolicyAsync(string section, string policyType, IPolicyValues oldValues, IPolicyValues newValues)
{
var dbContext = GetOrResolveDbContext();
if (newValues.Count is 0)
{
return;
}
await using var transaction = await DbContext.Database.BeginTransactionAsync();
await using var transaction = await dbContext.Database.BeginTransactionAsync();
await InternalUpdatePolicyAsync(section, policyType, oldValues, newValues);
await DbContext.SaveChangesAsync();
await dbContext.SaveChangesAsync();
await transaction.CommitAsync();
}

public void UpdatePolicies(string section, string policyType, IReadOnlyList<IPolicyValues> oldValuesList, IReadOnlyList<IPolicyValues> newValuesList)
{
var dbContext = GetOrResolveDbContext();
if (newValuesList.Count is 0)
{
return;
}
using var transaction = DbContext.Database.BeginTransaction();
using var transaction = dbContext.Database.BeginTransaction();
InternalUpdatePolicies(section, policyType, oldValuesList, newValuesList);
DbContext.SaveChanges();
dbContext.SaveChanges();
transaction.Commit();
}

public async Task UpdatePoliciesAsync(string section, string policyType, IReadOnlyList<IPolicyValues> oldValuesList, IReadOnlyList<IPolicyValues> newValuesList)
{
var dbContext = GetOrResolveDbContext();
if (newValuesList.Count is 0)
{
return;
}
await using var transaction = await DbContext.Database.BeginTransactionAsync();
await using var transaction = await dbContext.Database.BeginTransactionAsync();
await InternalUpdatePoliciesAsync(section, policyType, oldValuesList, newValuesList);
await DbContext.SaveChangesAsync();
await dbContext.SaveChangesAsync();
await transaction.CommitAsync();
}

Expand Down
Loading
Loading