Skip to content

Commit 11d4547

Browse files
committed
(#26) Single Merge
1 parent cd26087 commit 11d4547

File tree

2 files changed

+257
-1
lines changed

2 files changed

+257
-1
lines changed

src/EntityFrameworkCore.SqlServer.SimpleBulks/BulkMerge/BulkMergeBuilder.cs

Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@ private string GetDbColumnName(string columnName)
137137

138138
public BulkMergeResult Execute(IEnumerable<T> data)
139139
{
140+
if (data.Count() == 1)
141+
{
142+
return SingleMerge(data.First());
143+
}
144+
140145
if (!_updateColumnNames.Any() && !_insertColumnNames.Any())
141146
{
142147
return new BulkMergeResult();
@@ -268,6 +273,114 @@ public BulkMergeResult Execute(IEnumerable<T> data)
268273
return result;
269274
}
270275

276+
public BulkMergeResult SingleMerge(T data)
277+
{
278+
if (!_updateColumnNames.Any() && !_insertColumnNames.Any())
279+
{
280+
return new BulkMergeResult();
281+
}
282+
283+
bool returnDbGeneratedId = _options.ReturnDbGeneratedId && !string.IsNullOrEmpty(_outputIdColumn) && _insertColumnNames.Any();
284+
285+
var propertyNames = _updateColumnNames.Select(RemoveOperator).ToList();
286+
propertyNames.AddRange(_idColumns);
287+
propertyNames.AddRange(_insertColumnNames);
288+
propertyNames = propertyNames.Distinct().ToList();
289+
290+
var clrTypes = typeof(T).GetProviderClrTypes(propertyNames, _valueConverters);
291+
292+
var mergeStatementBuilder = new StringBuilder();
293+
294+
var joinCondition = string.Join(" and ", _idColumns.Select(x =>
295+
{
296+
string collation = !string.IsNullOrEmpty(_options.Collation) && clrTypes[x] == typeof(string) ?
297+
$" collate {_options.Collation}" : string.Empty;
298+
return $"s.[{x}]{collation} = t.[{GetDbColumnName(x)}]{collation}";
299+
}));
300+
301+
var parameterNames = string.Join(", ", propertyNames.Select(x => "@" + x));
302+
var columnNames = string.Join(", ", propertyNames.Select(x => "[" + x + "]"));
303+
304+
var hint = _options.WithHoldLock ? " WITH (HOLDLOCK)" : string.Empty;
305+
306+
mergeStatementBuilder.AppendLine($"MERGE {_table.SchemaQualifiedTableName}{hint} t");
307+
mergeStatementBuilder.AppendLine($" USING (values ({parameterNames})) s({columnNames}) ");
308+
mergeStatementBuilder.AppendLine($"ON ({joinCondition})");
309+
310+
if (_updateColumnNames.Any())
311+
{
312+
mergeStatementBuilder.AppendLine($"WHEN MATCHED");
313+
mergeStatementBuilder.AppendLine($" THEN UPDATE SET");
314+
mergeStatementBuilder.AppendLine(string.Join("," + Environment.NewLine, _updateColumnNames.Select(x => " " + CreateSetStatement(x, "t", "s"))));
315+
}
316+
317+
if (_insertColumnNames.Any())
318+
{
319+
mergeStatementBuilder.AppendLine($"WHEN NOT MATCHED BY TARGET");
320+
mergeStatementBuilder.AppendLine($" THEN INSERT ({string.Join(", ", _insertColumnNames.Select(x => $"[{GetDbColumnName(x)}]"))})");
321+
mergeStatementBuilder.AppendLine($" VALUES ({string.Join(", ", _insertColumnNames.Select(x => $"s.[{x}]"))})");
322+
}
323+
324+
if (returnDbGeneratedId)
325+
{
326+
mergeStatementBuilder.AppendLine($"OUTPUT $action, inserted.[{GetDbColumnName(_outputIdColumn)}]");
327+
}
328+
else
329+
{
330+
mergeStatementBuilder.AppendLine($"OUTPUT $action");
331+
}
332+
333+
mergeStatementBuilder.AppendLine(";");
334+
335+
_connection.EnsureOpen();
336+
337+
var sqlMergeStatement = mergeStatementBuilder.ToString();
338+
339+
Log($"Begin merging temp table:{Environment.NewLine}{sqlMergeStatement}");
340+
341+
BulkMergeResult result = new();
342+
string outputIdDbColumnName = null;
343+
344+
if (returnDbGeneratedId)
345+
{
346+
outputIdDbColumnName = GetDbColumnName(_outputIdColumn);
347+
}
348+
349+
using (var updateCommand = _connection.CreateTextCommand(_transaction, sqlMergeStatement, _options))
350+
{
351+
data.ToSqlParameters(propertyNames, valueConverters: _valueConverters)
352+
.ForEach(x => updateCommand.Parameters.Add(x));
353+
354+
using var reader = updateCommand.ExecuteReader();
355+
356+
while (reader.Read())
357+
{
358+
var action = reader["$action"] as string;
359+
360+
if (action == "INSERT")
361+
{
362+
if (returnDbGeneratedId)
363+
{
364+
var idProperty = typeof(T).GetProperty(_outputIdColumn);
365+
idProperty.SetValue(data, reader[outputIdDbColumnName]);
366+
}
367+
368+
result.InsertedRows++;
369+
}
370+
else if (action == "UPDATE")
371+
{
372+
result.UpdatedRows++;
373+
}
374+
375+
result.AffectedRows++;
376+
}
377+
}
378+
379+
Log("End merging temp table.");
380+
381+
return result;
382+
}
383+
271384
private string CreateSetStatement(string prop, string leftTable, string rightTable)
272385
{
273386
string sqlOperator = "=";
@@ -294,6 +407,11 @@ private void Log(string message)
294407

295408
public async Task<BulkMergeResult> ExecuteAsync(IEnumerable<T> data, CancellationToken cancellationToken = default)
296409
{
410+
if (data.Count() == 1)
411+
{
412+
return await SingleMergeAsync(data.First(), cancellationToken);
413+
}
414+
297415
if (!_updateColumnNames.Any() && !_insertColumnNames.Any())
298416
{
299417
return new BulkMergeResult();
@@ -424,4 +542,113 @@ public async Task<BulkMergeResult> ExecuteAsync(IEnumerable<T> data, Cancellatio
424542

425543
return result;
426544
}
545+
546+
public async Task<BulkMergeResult> SingleMergeAsync(T data, CancellationToken cancellationToken = default)
547+
{
548+
if (!_updateColumnNames.Any() && !_insertColumnNames.Any())
549+
{
550+
return new BulkMergeResult();
551+
}
552+
553+
bool returnDbGeneratedId = _options.ReturnDbGeneratedId && !string.IsNullOrEmpty(_outputIdColumn) && _insertColumnNames.Any();
554+
555+
var propertyNames = _updateColumnNames.Select(RemoveOperator).ToList();
556+
propertyNames.AddRange(_idColumns);
557+
propertyNames.AddRange(_insertColumnNames);
558+
propertyNames = propertyNames.Distinct().ToList();
559+
560+
var clrTypes = typeof(T).GetProviderClrTypes(propertyNames, _valueConverters);
561+
562+
var mergeStatementBuilder = new StringBuilder();
563+
564+
var joinCondition = string.Join(" and ", _idColumns.Select(x =>
565+
{
566+
string collation = !string.IsNullOrEmpty(_options.Collation) && clrTypes[x] == typeof(string) ?
567+
$" collate {_options.Collation}" : string.Empty;
568+
return $"s.[{x}]{collation} = t.[{GetDbColumnName(x)}]{collation}";
569+
}));
570+
571+
var parameterNames = string.Join(", ", propertyNames.Select(x => "@" + x));
572+
var columnNames = string.Join(", ", propertyNames.Select(x => "[" + x + "]"));
573+
574+
var hint = _options.WithHoldLock ? " WITH (HOLDLOCK)" : string.Empty;
575+
576+
mergeStatementBuilder.AppendLine($"MERGE {_table.SchemaQualifiedTableName}{hint} t");
577+
mergeStatementBuilder.AppendLine($" USING (values ({parameterNames})) s({columnNames}) ");
578+
mergeStatementBuilder.AppendLine($"ON ({joinCondition})");
579+
580+
if (_updateColumnNames.Any())
581+
{
582+
mergeStatementBuilder.AppendLine($"WHEN MATCHED");
583+
mergeStatementBuilder.AppendLine($" THEN UPDATE SET");
584+
mergeStatementBuilder.AppendLine(string.Join("," + Environment.NewLine, _updateColumnNames.Select(x => " " + CreateSetStatement(x, "t", "s"))));
585+
}
586+
587+
if (_insertColumnNames.Any())
588+
{
589+
mergeStatementBuilder.AppendLine($"WHEN NOT MATCHED BY TARGET");
590+
mergeStatementBuilder.AppendLine($" THEN INSERT ({string.Join(", ", _insertColumnNames.Select(x => $"[{GetDbColumnName(x)}]"))})");
591+
mergeStatementBuilder.AppendLine($" VALUES ({string.Join(", ", _insertColumnNames.Select(x => $"s.[{x}]"))})");
592+
}
593+
594+
if (returnDbGeneratedId)
595+
{
596+
mergeStatementBuilder.AppendLine($"OUTPUT $action, inserted.[{GetDbColumnName(_outputIdColumn)}]");
597+
}
598+
else
599+
{
600+
mergeStatementBuilder.AppendLine($"OUTPUT $action");
601+
}
602+
603+
mergeStatementBuilder.AppendLine(";");
604+
605+
await _connection.EnsureOpenAsync(cancellationToken);
606+
607+
var sqlMergeStatement = mergeStatementBuilder.ToString();
608+
609+
Log($"Begin merging temp table:{Environment.NewLine}{sqlMergeStatement}");
610+
611+
BulkMergeResult result = new();
612+
string outputIdDbColumnName = null;
613+
614+
if (returnDbGeneratedId)
615+
{
616+
outputIdDbColumnName = GetDbColumnName(_outputIdColumn);
617+
}
618+
619+
using (var updateCommand = _connection.CreateTextCommand(_transaction, sqlMergeStatement, _options))
620+
{
621+
data.ToSqlParameters(propertyNames, valueConverters: _valueConverters)
622+
.ForEach(x => updateCommand.Parameters.Add(x));
623+
624+
using var reader = await updateCommand.ExecuteReaderAsync(cancellationToken);
625+
626+
while (await reader.ReadAsync(cancellationToken))
627+
{
628+
var action = reader["$action"] as string;
629+
630+
if (action == "INSERT")
631+
{
632+
if (returnDbGeneratedId)
633+
{
634+
var idProperty = typeof(T).GetProperty(_outputIdColumn);
635+
idProperty.SetValue(data, reader[outputIdDbColumnName]);
636+
}
637+
638+
result.InsertedRows++;
639+
}
640+
else if (action == "UPDATE")
641+
{
642+
result.UpdatedRows++;
643+
}
644+
645+
result.AffectedRows++;
646+
}
647+
}
648+
649+
Log("End merging temp table.");
650+
651+
return result;
652+
}
653+
427654
}

src/EntityFrameworkCore.SqlServer.SimpleBulks/Extensions/TypeExtensions.cs

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
using System;
1+
using Microsoft.EntityFrameworkCore.Storage.ValueConversion;
2+
using System;
23
using System.Collections.Generic;
4+
using System.ComponentModel;
5+
using System.Linq;
36

47
namespace EntityFrameworkCore.SqlServer.SimpleBulks.Extensions;
58

@@ -25,4 +28,30 @@ public static string ToSqlType(this Type type)
2528
var sqlType = _mappings.TryGetValue(type, out string value) ? value : "nvarchar(max)";
2629
return sqlType;
2730
}
31+
32+
public static Dictionary<string, Type> GetProviderClrTypes(this Type type, IEnumerable<string> propertyNames, IReadOnlyDictionary<string, ValueConverter> valueConverters)
33+
{
34+
var properties = TypeDescriptor.GetProperties(type);
35+
36+
var updatablePros = new List<PropertyDescriptor>();
37+
foreach (PropertyDescriptor prop in properties)
38+
{
39+
if (propertyNames.Contains(prop.Name))
40+
{
41+
updatablePros.Add(prop);
42+
}
43+
}
44+
45+
return updatablePros.ToDictionary(x => x.Name, x => GetProviderClrType(x, valueConverters));
46+
}
47+
48+
private static Type GetProviderClrType(PropertyDescriptor property, IReadOnlyDictionary<string, ValueConverter> valueConverters)
49+
{
50+
if (valueConverters != null && valueConverters.TryGetValue(property.Name, out var converter))
51+
{
52+
return converter.ProviderClrType;
53+
}
54+
55+
return Nullable.GetUnderlyingType(property.PropertyType) ?? property.PropertyType;
56+
}
2857
}

0 commit comments

Comments
 (0)