Skip to content

Commit ddee3a0

Browse files
Add the UseSqlServer() method to specify the database provider to use.
1 parent 61166fa commit ddee3a0

File tree

10 files changed

+195
-23
lines changed

10 files changed

+195
-23
lines changed

src/Database.Updater/DatabaseUpdaterBuilder.cs

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ public sealed class DatabaseUpdaterBuilder
4040

4141
private readonly IList<string> migrationsAssemblies = new List<string>();
4242

43+
private IDatabaseProvider? databaseProvider;
44+
4345
/// <summary>
4446
/// Initializes a new instance of the <see cref="DatabaseUpdaterBuilder"/> class.
4547
/// </summary>
@@ -82,8 +84,14 @@ public DatabaseUpdaterBuilder UseMigrationsAssembly(Assembly assembly)
8284
/// Builds an instance of the <see cref="IDatabaseUpdater"/> to perform the migration of the database.
8385
/// </summary>
8486
/// <returns>An instance of the <see cref="IDatabaseUpdater"/> to perform the migration of the database.</returns>
87+
/// <exception cref="InvalidOperationException">No database provider has been configured.</exception>
8588
public IDatabaseUpdater Build()
8689
{
90+
if (this.databaseProvider is null)
91+
{
92+
throw new InvalidOperationException("No database provider has been configured.");
93+
}
94+
8795
var rootCommand = new RootCommand($"Upgrade the {this.applicationName} database.")
8896
{
8997
new SqlServerConnectionStringArgument("connection-string")
@@ -113,7 +121,7 @@ public IDatabaseUpdater Build()
113121
migrationsAssemblies.Add(this.callingAssembly.GetName().Name!);
114122
}
115123

116-
var databaseUpdater = new EntityFrameworkDatabaseUpdater(migrationsAssemblies);
124+
var databaseUpdater = new EntityFrameworkDatabaseUpdater(this.databaseProvider, migrationsAssemblies);
117125

118126
rootCommand.Action = CommandHandler.Create<string, int, string, IHost, CancellationToken>(databaseUpdater.UpgradeAsync);
119127

@@ -130,6 +138,13 @@ public IDatabaseUpdater Build()
130138
return new CommandLineDatabaseUpdater(commandLine);
131139
}
132140

141+
internal DatabaseUpdaterBuilder UseDatabaseProvider(IDatabaseProvider databaseProvider)
142+
{
143+
this.databaseProvider = databaseProvider;
144+
145+
return this;
146+
}
147+
133148
private sealed class CommandLineDatabaseUpdater : IDatabaseUpdater
134149
{
135150
private readonly CommandLineConfiguration commandLine;

src/Database.Updater/EntityFrameworkDatabaseUpdater.cs

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,45 +6,34 @@
66

77
namespace PosInformatique.Database.Updater
88
{
9-
using Microsoft.Data.SqlClient;
109
using Microsoft.EntityFrameworkCore;
1110
using Microsoft.Extensions.DependencyInjection;
1211
using Microsoft.Extensions.Hosting;
1312
using Microsoft.Extensions.Logging;
1413

1514
internal sealed class EntityFrameworkDatabaseUpdater
1615
{
16+
private readonly IDatabaseProvider databaseProvider;
17+
1718
private readonly IReadOnlyList<string> migrationsAssemblies;
1819

19-
public EntityFrameworkDatabaseUpdater(IReadOnlyList<string> migrationsAssemblies)
20+
public EntityFrameworkDatabaseUpdater(IDatabaseProvider databaseProvider, IReadOnlyList<string> migrationsAssemblies)
2021
{
22+
this.databaseProvider = databaseProvider;
2123
this.migrationsAssemblies = migrationsAssemblies;
2224
}
2325

24-
public async Task<int> UpgradeAsync(string connectionString, int commandTimeout, string accessToken, IHost host, CancellationToken cancellationToken)
26+
public async Task<int> UpgradeAsync(string connectionString, int commandTimeout, string? accessToken, IHost host, CancellationToken cancellationToken)
2527
{
2628
var loggerFactory = host.Services.GetRequiredService<ILoggerFactory>();
2729
var logger = loggerFactory.CreateLogger<EntityFrameworkDatabaseUpdater>();
2830

29-
var connectionStringBuilder = new SqlConnectionStringBuilder(connectionString);
30-
connectionStringBuilder.CommandTimeout = commandTimeout;
31-
32-
using (var connection = new SqlConnection(connectionStringBuilder.ToString()))
31+
using (var connection = this.databaseProvider.CreateConnection(connectionString, commandTimeout, accessToken))
3332
{
34-
connection.AccessToken = accessToken;
35-
36-
var builder = new DbContextOptionsBuilder<DbContext>();
37-
builder.UseSqlServer(
33+
var builder = this.databaseProvider.CreateDbContextOptionsBuilder(
3834
connection,
39-
opt =>
40-
{
41-
foreach (var assembly in this.migrationsAssemblies)
42-
{
43-
opt.MigrationsAssembly(assembly);
44-
}
45-
46-
opt.CommandTimeout(commandTimeout);
47-
});
35+
this.migrationsAssemblies,
36+
commandTimeout);
4837

4938
builder.UseLoggerFactory(loggerFactory);
5039

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//-----------------------------------------------------------------------
2+
// <copyright file="IDatabaseProvider.cs" company="P.O.S Informatique">
3+
// Copyright (c) P.O.S Informatique. All rights reserved.
4+
// </copyright>
5+
//-----------------------------------------------------------------------
6+
7+
namespace PosInformatique.Database.Updater
8+
{
9+
using System.Data.Common;
10+
using Microsoft.EntityFrameworkCore;
11+
12+
/// <summary>
13+
/// Represents a database provider.
14+
/// </summary>
15+
internal interface IDatabaseProvider
16+
{
17+
/// <summary>
18+
/// Creates a <see cref="DbConnection"/> to the database.
19+
/// </summary>
20+
/// <param name="connectionString">Connection string to the database.</param>
21+
/// <param name="commandTimeout">Timeout for the command execution.</param>
22+
/// <param name="accessToken">Access token for authentication if need.</param>
23+
/// <returns>The <see cref="DbConnection"/> which allows to connect to the database.</returns>
24+
DbConnection CreateConnection(string connectionString, int commandTimeout, string? accessToken);
25+
26+
/// <summary>
27+
/// Creates an instance of the <see cref="DbContextOptionsBuilder"/> to create a <see cref="DbContext"/>
28+
/// which will be used for the Entity Framework migrations.
29+
/// </summary>
30+
/// <param name="connection"><see cref="DbConnection"/> to the database.</param>
31+
/// <param name="migrationsAssemblies">List of the assemblies that contains the migrations to execute.</param>
32+
/// <param name="commandTimeout">Timeout for the command execution.</param>
33+
/// <returns>An instance of the <see cref="DbContextOptionsBuilder"/> to create a <see cref="DbContext"/>
34+
/// which will be used for the Entity Framework migrations.</returns>
35+
DbContextOptionsBuilder CreateDbContextOptionsBuilder(DbConnection connection, IReadOnlyList<string> migrationsAssemblies, int commandTimeout);
36+
}
37+
}
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
//-----------------------------------------------------------------------
2+
// <copyright file="SqlServerDatabaseProvider.cs" company="P.O.S Informatique">
3+
// Copyright (c) P.O.S Informatique. All rights reserved.
4+
// </copyright>
5+
//-----------------------------------------------------------------------
6+
7+
namespace PosInformatique.Database.Updater.SqlServer
8+
{
9+
using System.Data.Common;
10+
using Microsoft.Data.SqlClient;
11+
using Microsoft.EntityFrameworkCore;
12+
13+
internal sealed class SqlServerDatabaseProvider : IDatabaseProvider
14+
{
15+
public SqlServerDatabaseProvider()
16+
{
17+
}
18+
19+
public DbConnection CreateConnection(string connectionString, int commandTimeout, string? accessToken)
20+
{
21+
var connectionStringBuilder = new SqlConnectionStringBuilder(connectionString);
22+
connectionStringBuilder.CommandTimeout = commandTimeout;
23+
24+
return new SqlConnection(connectionStringBuilder.ToString())
25+
{
26+
AccessToken = accessToken,
27+
};
28+
}
29+
30+
public DbContextOptionsBuilder CreateDbContextOptionsBuilder(DbConnection connection, IReadOnlyList<string> migrationsAssemblies, int commandTimeout)
31+
{
32+
return new DbContextOptionsBuilder().UseSqlServer(
33+
connection,
34+
opt =>
35+
{
36+
foreach (var assembly in migrationsAssemblies)
37+
{
38+
opt.MigrationsAssembly(assembly);
39+
}
40+
41+
opt.CommandTimeout(commandTimeout);
42+
});
43+
}
44+
}
45+
}
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//-----------------------------------------------------------------------
2+
// <copyright file="SqlServerDatabaseUpdaterBuilderExtensions.cs" company="P.O.S Informatique">
3+
// Copyright (c) P.O.S Informatique. All rights reserved.
4+
// </copyright>
5+
//-----------------------------------------------------------------------
6+
7+
namespace PosInformatique.Database.Updater
8+
{
9+
using PosInformatique.Database.Updater.SqlServer;
10+
11+
/// <summary>
12+
/// Contains extensions methods for the <see cref="DatabaseUpdaterBuilder"/> class to use SQL Server database provider.
13+
/// </summary>
14+
public static class SqlServerDatabaseUpdaterBuilderExtensions
15+
{
16+
/// <summary>
17+
/// Configures the <see cref="DatabaseUpdaterBuilder"/> to use SQL Server database provider.
18+
/// </summary>
19+
/// <param name="builder"><see cref="DatabaseUpdaterBuilder"/> to configure.</param>
20+
/// <returns>The <paramref name="builder"/> instance to continue the configuration.</returns>
21+
public static DatabaseUpdaterBuilder UseSqlServer(this DatabaseUpdaterBuilder builder)
22+
{
23+
return builder.UseDatabaseProvider(new SqlServerDatabaseProvider());
24+
}
25+
}
26+
}

tests/Database.Updater.IntegrationTests/Program.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ public static async Task Main(string[] args)
1313
var databaseUpdaterBuilder = new DatabaseUpdaterBuilder("MyApplication");
1414

1515
var updater = databaseUpdaterBuilder
16+
.UseSqlServer()
1617
.UseMigrationsAssembly(typeof(MigrationsAssembly.PersonDbContext).Assembly)
1718
.Build();
1819

tests/Database.Updater.Tests/DatabaseUpdaterBuilderTest.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public void Constructor_NoApplicationName()
1313
{
1414
var action = () => new DatabaseUpdaterBuilder(null);
1515

16-
action.Should().Throw<ArgumentNullException>()
16+
action.Should().ThrowExactly<ArgumentNullException>()
1717
.WithParameterName("applicationName")
1818
.WithMessage("Value cannot be null. (Parameter 'applicationName')");
1919
}
@@ -25,9 +25,19 @@ public void Constructor_ApplicationName_EmptyOrWhitespace(string applicationName
2525
{
2626
var action = () => new DatabaseUpdaterBuilder(applicationName);
2727

28-
action.Should().Throw<ArgumentException>()
28+
action.Should().ThrowExactly<ArgumentException>()
2929
.WithParameterName("applicationName")
3030
.WithMessage("The value cannot be an empty string or composed entirely of whitespace. (Parameter 'applicationName')");
3131
}
32+
33+
[Fact]
34+
public void Build_NoDatabaseProvider()
35+
{
36+
var builder = new DatabaseUpdaterBuilder("MyApplication");
37+
38+
builder.Invoking(b => b.Build())
39+
.Should().ThrowExactly<InvalidOperationException>()
40+
.WithMessage("No database provider has been configured.");
41+
}
3242
}
3343
}

tests/Database.Updater.Tests/DatabaseUpdaterTest.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public async Task UpgradeAsync_WithExplicitMigrationsAssembly()
2121
var database = await server.CreateEmptyDatabaseAsync("DatabaseUpdaterTest_UpgradeAsync_WithExplicitMigrationsAssembly");
2222

2323
var databaseUpdaterBuilder = new DatabaseUpdaterBuilder("MyApplication")
24+
.UseSqlServer()
2425
.UseMigrationsAssembly(typeof(MigrationsAssembly.Version1).Assembly);
2526
var databaseUpdater = databaseUpdaterBuilder
2627
.Build();
@@ -50,6 +51,7 @@ public async Task UpgradeAsync_WithErrorMigrationsAssembly()
5051
var database = await server.CreateEmptyDatabaseAsync("DatabaseUpdaterTest_UpgradeAsync_WithErrorMigrationsAssembly");
5152

5253
var databaseUpdaterBuilder = new DatabaseUpdaterBuilder("MyApplication")
54+
.UseSqlServer()
5355
.UseMigrationsAssembly(typeof(MigrationsErrorAssembly.Version1).Assembly);
5456
var databaseUpdater = databaseUpdaterBuilder
5557
.Build();
@@ -64,6 +66,7 @@ public async Task UpgradeAsync_NoArguments()
6466
{
6567
var databaseUpdaterBuilder = new DatabaseUpdaterBuilder("MyApplication");
6668
var databaseUpdater = databaseUpdaterBuilder
69+
.UseSqlServer()
6770
.Build();
6871

6972
var result = await databaseUpdater.UpgradeAsync([]);
@@ -78,6 +81,7 @@ public async Task UpgradeAsync_WrongArguments(params string[] args)
7881
{
7982
var databaseUpdaterBuilder = new DatabaseUpdaterBuilder("MyApplication");
8083
var databaseUpdater = databaseUpdaterBuilder
84+
.UseSqlServer()
8185
.Build();
8286

8387
var result = await databaseUpdater.UpgradeAsync(args);
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
//-----------------------------------------------------------------------
2+
// <copyright file="SqlServerDatabaseUpdaterBuilderExtensionsTest.cs" company="P.O.S Informatique">
3+
// Copyright (c) P.O.S Informatique. All rights reserved.
4+
// </copyright>
5+
//-----------------------------------------------------------------------
6+
7+
namespace PosInformatique.Database.Updater.Tests
8+
{
9+
using PosInformatique.Database.Updater.SqlServer;
10+
11+
public class SqlServerDatabaseUpdaterBuilderExtensionsTest
12+
{
13+
[Fact]
14+
public void Constructor()
15+
{
16+
var builder = new DatabaseUpdaterBuilder("MyApplication");
17+
18+
builder.UseSqlServer().Should().BeSameAs(builder);
19+
20+
builder.GetFieldValue<IDatabaseProvider>("databaseProvider").Should().BeOfType<SqlServerDatabaseProvider>();
21+
}
22+
}
23+
}
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
//-----------------------------------------------------------------------
2+
// <copyright file="TestTools.cs" company="P.O.S Informatique">
3+
// Copyright (c) P.O.S Informatique. All rights reserved.
4+
// </copyright>
5+
//-----------------------------------------------------------------------
6+
7+
namespace PosInformatique.Database.Updater.Tests
8+
{
9+
using System.Reflection;
10+
11+
public static class TestTools
12+
{
13+
public static T GetFieldValue<T>(this object obj, string fieldName)
14+
{
15+
var field = obj.GetType().GetField(fieldName, BindingFlags.NonPublic | BindingFlags.Instance);
16+
17+
var value = field!.GetValue(obj);
18+
19+
return (T)value!;
20+
}
21+
}
22+
}

0 commit comments

Comments
 (0)