Skip to content

Commit a785663

Browse files
Allowed the ado.net context to get an inserted primary key value on a single transaction when possible
1 parent a2b88c1 commit a785663

File tree

3 files changed

+164
-65
lines changed

3 files changed

+164
-65
lines changed

src/DotNetToolkit.Repository.AdoNet/Internal/AdoNetRepositoryContext.cs

Lines changed: 92 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ internal class AdoNetRepositoryContext : IRepositoryContextAsync
3939
private readonly ConcurrentDictionary<Type, bool> _schemaValidationTypeMapping = new ConcurrentDictionary<Type, bool>();
4040
private readonly SchemaTableConfigurationHelper _schemaConfigHelper;
4141

42-
private readonly QueryBuilder _queryBuilder = new QueryBuilder();
42+
private readonly QueryBuilder _queryBuilder;
43+
44+
private readonly DataAccessProviderType _providerType;
4345

4446
#endregion
4547

@@ -63,25 +65,14 @@ public AdoNetRepositoryContext(string nameOrConnectionString)
6365
if (nameOrConnectionString == null)
6466
throw new ArgumentNullException(nameof(nameOrConnectionString));
6567

66-
var css = ConfigurationManager.ConnectionStrings[nameOrConnectionString];
67-
if (css == null)
68-
{
69-
for (var i = 0; i < ConfigurationManager.ConnectionStrings.Count; i++)
70-
{
71-
css = ConfigurationManager.ConnectionStrings[i];
72-
73-
if (css.ConnectionString.Equals(nameOrConnectionString))
74-
break;
75-
}
76-
}
77-
78-
if (css == null)
79-
throw new ArgumentException("The connection string does not exist in your configuration file.");
68+
var css = GetConnectionStringSettings(nameOrConnectionString);
8069

8170
_factory = DbProviderFactories.GetFactory(css.ProviderName);
8271
_connectionString = css.ConnectionString;
8372
_ownsConnection = true;
8473
_schemaConfigHelper = new SchemaTableConfigurationHelper(this);
74+
_providerType = DataAccessProvider.GetProviderType(css.ProviderName);
75+
_queryBuilder = new QueryBuilder(_providerType);
8576
}
8677

8778
/// <summary>
@@ -101,6 +92,8 @@ public AdoNetRepositoryContext(string providerName, string connectionString)
10192
_connectionString = connectionString;
10293
_ownsConnection = true;
10394
_schemaConfigHelper = new SchemaTableConfigurationHelper(this);
95+
_providerType = DataAccessProvider.GetProviderType(providerName);
96+
_queryBuilder = new QueryBuilder(_providerType);
10497
}
10598

10699
/// <summary>
@@ -118,6 +111,11 @@ public AdoNetRepositoryContext(DbConnection existingConnection)
118111
_connection = existingConnection;
119112
_ownsConnection = false;
120113
_schemaConfigHelper = new SchemaTableConfigurationHelper(this);
114+
115+
var css = GetConnectionStringSettings(existingConnection.ConnectionString);
116+
117+
_providerType = DataAccessProvider.GetProviderType(css.ProviderName);
118+
_queryBuilder = new QueryBuilder(_providerType);
121119
}
122120

123121
#endregion
@@ -1142,6 +1140,27 @@ protected DbCommand CreateCommand(string cmdText, CommandType cmdType, Dictionar
11421140

11431141
#region Private Methods
11441142

1143+
private static ConnectionStringSettings GetConnectionStringSettings(string nameOrConnectionString)
1144+
{
1145+
var css = ConfigurationManager.ConnectionStrings[nameOrConnectionString];
1146+
1147+
if (css == null)
1148+
{
1149+
for (var i = 0; i < ConfigurationManager.ConnectionStrings.Count; i++)
1150+
{
1151+
css = ConfigurationManager.ConnectionStrings[i];
1152+
1153+
if (css.ConnectionString.Equals(nameOrConnectionString))
1154+
break;
1155+
}
1156+
}
1157+
1158+
if (css == null)
1159+
throw new ArgumentException("The connection string does not exist in your configuration file.");
1160+
1161+
return css;
1162+
}
1163+
11451164
private static T ConvertValue<T>(object value)
11461165
{
11471166
if (value == null || value is DBNull)
@@ -1346,9 +1365,8 @@ public int SaveChanges()
13461365

13471366
OnSchemaValidation(entityType);
13481367

1349-
var primeryKeyPropertyInfo =
1350-
PrimaryKeyConventionHelper.GetPrimaryKeyPropertyInfos(entityType).First();
1351-
var isIdentity = primeryKeyPropertyInfo.IsColumnIdentity();
1368+
var primaryKeyPropertyInfo = PrimaryKeyConventionHelper.GetPrimaryKeyPropertyInfos(entityType).First();
1369+
var isIdentity = primaryKeyPropertyInfo.IsColumnIdentity();
13521370

13531371
// Checks if the entity exist in the database
13541372
var existInDb = command.ExecuteObjectExist(entitySet.Entity);
@@ -1358,7 +1376,7 @@ public int SaveChanges()
13581376
entitySet,
13591377
existInDb,
13601378
isIdentity,
1361-
primeryKeyPropertyInfo,
1379+
primaryKeyPropertyInfo,
13621380
out string sql,
13631381
out Dictionary<string, object> parameters);
13641382

@@ -1368,24 +1386,39 @@ public int SaveChanges()
13681386
command.Parameters.Clear();
13691387
command.AddParameters(parameters);
13701388

1371-
rows += command.ExecuteNonQuery();
1372-
1373-
if (Logger.IsEnabled(LogLevel.Debug))
1374-
Logger.Debug(FormatExecutingDebugQuery("ExecuteNonQuery", parameters, sql));
1375-
1376-
// Checks to see if the model needs to be updated with the new key returned from the database
13771389
if (entitySet.State == EntityState.Added && isIdentity)
13781390
{
1379-
command.CommandText = "SELECT @@IDENTITY";
1380-
command.Parameters.Clear();
1391+
#if NETFULL
1392+
if (_providerType == DataAccessProviderType.SqlServerCompact)
1393+
{
1394+
if (Logger.IsEnabled(LogLevel.Debug))
1395+
Logger.Debug(FormatExecutingDebugQuery("ExecuteNonQuery", parameters, sql));
13811396

1397+
command.ExecuteNonQuery();
1398+
1399+
sql = "SELECT @@IDENTITY";
1400+
parameters.Clear();
1401+
1402+
command.CommandText = sql;
1403+
command.Parameters.Clear();
1404+
}
1405+
#endif
13821406
if (Logger.IsEnabled(LogLevel.Debug))
1383-
Logger.Debug(FormatExecutingDebugQuery("ExecuteScalar", null, command.CommandText));
1407+
Logger.Debug(FormatExecutingDebugQuery("ExecuteScalar", parameters, sql));
13841408

13851409
var newKey = command.ExecuteScalar();
1386-
var convertedKeyValue = Convert.ChangeType(newKey, primeryKeyPropertyInfo.PropertyType);
1410+
var convertedKeyValue = Convert.ChangeType(newKey, primaryKeyPropertyInfo.PropertyType);
1411+
1412+
primaryKeyPropertyInfo.SetValue(entitySet.Entity, convertedKeyValue, null);
13871413

1388-
primeryKeyPropertyInfo.SetValue(entitySet.Entity, convertedKeyValue, null);
1414+
rows++;
1415+
}
1416+
else
1417+
{
1418+
if (Logger.IsEnabled(LogLevel.Debug))
1419+
Logger.Debug(FormatExecutingDebugQuery("ExecuteNonQuery", parameters, sql));
1420+
1421+
rows += command.ExecuteNonQuery();
13891422
}
13901423
}
13911424
}
@@ -1700,8 +1733,8 @@ public QueryResult<IEnumerable<TResult>> GroupBy<TEntity, TGroupKey, TResult>(IQ
17001733

17011734
await OnSchemaValidationAsync(entityType, cancellationToken);
17021735

1703-
var primeryKeyPropertyInfo = PrimaryKeyConventionHelper.GetPrimaryKeyPropertyInfos(entityType).First();
1704-
var isIdentity = primeryKeyPropertyInfo.IsColumnIdentity();
1736+
var primaryKeyPropertyInfo = PrimaryKeyConventionHelper.GetPrimaryKeyPropertyInfos(entityType).First();
1737+
var isIdentity = primaryKeyPropertyInfo.IsColumnIdentity();
17051738

17061739
// Checks if the entity exist in the database
17071740
var existInDb = await command.ExecuteObjectExistAsync(entitySet.Entity, cancellationToken);
@@ -1711,7 +1744,7 @@ public QueryResult<IEnumerable<TResult>> GroupBy<TEntity, TGroupKey, TResult>(IQ
17111744
entitySet,
17121745
existInDb,
17131746
isIdentity,
1714-
primeryKeyPropertyInfo,
1747+
primaryKeyPropertyInfo,
17151748
out string sql,
17161749
out Dictionary<string, object> parameters);
17171750

@@ -1721,24 +1754,39 @@ public QueryResult<IEnumerable<TResult>> GroupBy<TEntity, TGroupKey, TResult>(IQ
17211754
command.Parameters.Clear();
17221755
command.AddParameters(parameters);
17231756

1724-
if (Logger.IsEnabled(LogLevel.Debug))
1725-
Logger.Debug(FormatExecutingDebugQuery("ExecuteNonQueryAsync", parameters, sql));
1726-
1727-
rows += await command.ExecuteNonQueryAsync(cancellationToken);
1728-
1729-
// Checks to see if the model needs to be updated with the new key returned from the database
17301757
if (entitySet.State == EntityState.Added && isIdentity)
17311758
{
1732-
command.CommandText = "SELECT @@IDENTITY";
1733-
command.Parameters.Clear();
1759+
#if NETFULL
1760+
if (_providerType == DataAccessProviderType.SqlServerCompact)
1761+
{
1762+
if (Logger.IsEnabled(LogLevel.Debug))
1763+
Logger.Debug(FormatExecutingDebugQuery("ExecuteNonQueryAsync", parameters, sql));
1764+
1765+
await command.ExecuteNonQueryAsync(cancellationToken);
17341766

1767+
sql = "SELECT @@IDENTITY";
1768+
parameters.Clear();
1769+
1770+
command.CommandText = sql;
1771+
command.Parameters.Clear();
1772+
}
1773+
#endif
17351774
if (Logger.IsEnabled(LogLevel.Debug))
1736-
Logger.Debug(FormatExecutingDebugQuery("ExecuteScalarAsync", null, command.CommandText));
1775+
Logger.Debug(FormatExecutingDebugQuery("ExecuteScalarAsync", parameters, sql));
17371776

17381777
var newKey = await command.ExecuteScalarAsync(cancellationToken);
1739-
var convertedKeyValue = Convert.ChangeType(newKey, primeryKeyPropertyInfo.PropertyType);
1778+
var convertedKeyValue = Convert.ChangeType(newKey, primaryKeyPropertyInfo.PropertyType);
1779+
1780+
primaryKeyPropertyInfo.SetValue(entitySet.Entity, convertedKeyValue, null);
1781+
1782+
rows++;
1783+
}
1784+
else
1785+
{
1786+
if (Logger.IsEnabled(LogLevel.Debug))
1787+
Logger.Debug(FormatExecutingDebugQuery("ExecuteNonQueryAsync", parameters, sql));
17401788

1741-
primeryKeyPropertyInfo.SetValue(entitySet.Entity, convertedKeyValue, null);
1789+
rows += await command.ExecuteNonQueryAsync(cancellationToken);
17421790
}
17431791
}
17441792
}

src/DotNetToolkit.Repository.AdoNet/Internal/DbProviderFactories.cs

Lines changed: 51 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,21 @@ public static DbProviderFactory GetFactory(string providerName)
2727
#if NETFULL
2828
return System.Data.Common.DbProviderFactories.GetFactory(providerName);
2929
#else
30-
var providername = providerName.ToLower();
31-
32-
if (providerName == "system.data.sqlclient")
33-
return GetFactory(DataAccessProviderTypes.SqlServer);
34-
if (providerName == "system.data.sqlite" || providerName == "microsoft.data.sqlite")
35-
return GetFactory(DataAccessProviderTypes.SqLite);
36-
if (providerName == "mysql.data.mysqlclient" || providername == "mysql.data")
37-
return GetFactory(DataAccessProviderTypes.MySql);
38-
if (providerName == "npgsql")
39-
return GetFactory(DataAccessProviderTypes.PostgreSql);
40-
41-
throw new NotSupportedException($"Unsupported Provider Factory specified: {providerName}");
30+
switch (providerName.ToLower())
31+
{
32+
case "system.data.sqlclient":
33+
return GetFactory(DataAccessProviderType.SqlServer);
34+
case "system.data.sqlite":
35+
case "microsoft.data.sqlite":
36+
return GetFactory(DataAccessProviderType.SqLite);
37+
case "mysql.data.mysqlclient":
38+
case "mysql.data":
39+
return GetFactory(DataAccessProviderType.MySql);
40+
case "npgsql":
41+
return GetFactory(DataAccessProviderType.PostgreSql);
42+
default:
43+
throw new NotSupportedException($"Unsupported Provider Factory specified: {providerName}");
44+
}
4245
#endif
4346
}
4447

@@ -155,27 +158,27 @@ private static DbProviderFactory GetFactory(string dbProviderFactoryTypename, st
155158
return instance as DbProviderFactory;
156159
}
157160

158-
private static DbProviderFactory GetFactory(DataAccessProviderTypes type)
161+
private static DbProviderFactory GetFactory(DataAccessProviderType type)
159162
{
160-
if (type == DataAccessProviderTypes.SqlServer)
163+
if (type == DataAccessProviderType.SqlServer)
161164
return SqlClientFactory.Instance; // this library has a ref to SqlClient so this works
162165

163-
if (type == DataAccessProviderTypes.SqLite)
166+
if (type == DataAccessProviderType.SqLite)
164167
{
165168
#if NETFULL
166169
return GetFactory("System.Data.SQLite.SQLiteFactory", "System.Data.SQLite");
167170
#else
168171
return GetFactory("Microsoft.Data.Sqlite.SqliteFactory", "Microsoft.Data.Sqlite");
169172
#endif
170173
}
171-
if (type == DataAccessProviderTypes.MySql)
174+
if (type == DataAccessProviderType.MySql)
172175
return GetFactory("MySql.Data.MySqlClient.MySqlClientFactory", "MySql.Data");
173-
if (type == DataAccessProviderTypes.PostgreSql)
176+
if (type == DataAccessProviderType.PostgreSql)
174177
return GetFactory("Npgsql.NpgsqlFactory", "Npgsql");
175178
#if NETFULL
176-
if (type == DataAccessProviderTypes.OleDb)
179+
if (type == DataAccessProviderType.OleDb)
177180
return System.Data.OleDb.OleDbFactory.Instance;
178-
if (type == DataAccessProviderTypes.SqlServerCompact)
181+
if (type == DataAccessProviderType.SqlServerCompact)
179182
return System.Data.Common.DbProviderFactories.GetFactory("System.Data.SqlServerCe.4.0");
180183
#endif
181184

@@ -185,7 +188,7 @@ private static DbProviderFactory GetFactory(DataAccessProviderTypes type)
185188
#endregion
186189
}
187190

188-
internal enum DataAccessProviderTypes
191+
internal enum DataAccessProviderType
189192
{
190193
SqlServer,
191194
SqLite,
@@ -197,4 +200,32 @@ internal enum DataAccessProviderTypes
197200
SqlServerCompact
198201
#endif
199202
}
203+
204+
internal static class DataAccessProvider
205+
{
206+
public static DataAccessProviderType GetProviderType(string providerName)
207+
{
208+
switch (providerName.ToLower())
209+
{
210+
#if NETFULL
211+
case "system.data.sqlserverce.4.0":
212+
return DataAccessProviderType.SqlServerCompact;
213+
case "microsoft.jet.oledb.4.0":
214+
return DataAccessProviderType.OleDb;
215+
#endif
216+
case "system.data.sqlclient":
217+
return DataAccessProviderType.SqlServer;
218+
case "system.data.sqlite":
219+
case "microsoft.data.sqlite":
220+
return DataAccessProviderType.SqLite;
221+
case "mysql.data.mysqlclient":
222+
case "mysql.data":
223+
return DataAccessProviderType.MySql;
224+
case "npgsql":
225+
return DataAccessProviderType.PostgreSql;
226+
default:
227+
throw new NotSupportedException($"Unsupported Provider Factory specified: {providerName}");
228+
}
229+
}
230+
}
200231
}

src/DotNetToolkit.Repository.AdoNet/Internal/QueryBuilder.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ internal class QueryBuilder
2222

2323
// TODO: NEEDS TO FIGURE OUT A BETTER WAY TO DO THIS
2424
private static readonly Lazy<string> _crossJoinCountColumnName = new Lazy<string>(() => "Counter_" + Guid.NewGuid().ToString("N"));
25-
25+
private readonly DataAccessProviderType _providerType;
26+
2627
#endregion
2728

2829
#region Preperties
@@ -31,6 +32,15 @@ internal class QueryBuilder
3132

3233
#endregion
3334

35+
#region Constructors
36+
37+
public QueryBuilder(DataAccessProviderType providerType)
38+
{
39+
_providerType = providerType;
40+
}
41+
42+
#endregion
43+
3444
#region Public Methods
3545

3646
public void PrepareCountQuery<T>(IQueryOptions<T> options, out Mapper mapper) where T : class
@@ -418,6 +428,16 @@ public void PrepareEntitySetQuery(EntitySet entitySet, bool existInDb, bool isId
418428

419429
sql = $"INSERT INTO [{tableName}] ({columnNames}){Environment.NewLine}VALUES ({values})";
420430

431+
var canGetScopeIdentity = true;
432+
433+
#if NETFULL
434+
if (_providerType == DataAccessProviderType.SqlServerCompact)
435+
canGetScopeIdentity = false;
436+
#endif
437+
438+
if (canGetScopeIdentity)
439+
sql += $"{Environment.NewLine}SELECT SCOPE_IDENTITY()";
440+
421441
foreach (var pi in properties)
422442
{
423443
parameters.Add($"@{pi.Value.GetColumnName()}", pi.Value.GetValue(entitySet.Entity, null));

0 commit comments

Comments
 (0)