Skip to content

Commit b1803d4

Browse files
committed
TempTables without owned entities can be used in queries with SplitQuery
1 parent edc9f9b commit b1803d4

File tree

16 files changed

+322
-55
lines changed

16 files changed

+322
-55
lines changed

src/Thinktecture.EntityFrameworkCore.BulkOperations/EntityFrameworkCore/Query/SqlExpressions/TempTableExpression.cs

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ namespace Thinktecture.EntityFrameworkCore.Query.SqlExpressions;
77
/// <summary>
88
/// An expression that represents a temp table.
99
/// </summary>
10-
public sealed class TempTableExpression : TableExpressionBase
10+
public sealed class TempTableExpression : TableExpressionBase, INotNullableSqlExpression
1111
{
1212
/// <summary>
1313
/// The name of the table or view.
@@ -38,12 +38,8 @@ protected override void Print(ExpressionPrinter expressionPrinter)
3838
/// <inheritdoc />
3939
public override bool Equals(object? obj)
4040
{
41-
return ReferenceEquals(this, obj) || Equals(obj as TempTableExpression);
42-
}
43-
44-
private bool Equals(TempTableExpression? tempTableExpression)
45-
{
46-
return base.Equals(tempTableExpression) && string.Equals(Name, tempTableExpression.Name);
41+
// This should be reference equal only.
42+
return obj != null && ReferenceEquals(this, obj);
4743
}
4844

4945
/// <inheritdoc />

src/Thinktecture.EntityFrameworkCore.BulkOperations/Extensions/BulkOperationsRelationalQueryableMethodTranslatingExpressionVisitorExtensions.cs

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,9 @@ public static class BulkOperationsRelationalQueryableMethodTranslatingExpression
2222
/// <param name="visitor">The visitor.</param>
2323
/// <param name="methodCallExpression">Method call to translate.</param>
2424
/// <param name="typeMappingSource">Type mapping source.</param>
25-
/// <param name="queryCompilationContext"></param>
25+
/// <param name="queryCompilationContext">Query compilation context.</param>
2626
/// <param name="tempTableQueryContextFactory"></param>
27+
/// <param name="sqlExpressionFactory">SQL expression factory.</param>
2728
/// <returns>Translated method call if a custom method is found; otherwise <c>null</c>.</returns>
2829
/// <exception cref="ArgumentNullException">
2930
/// <paramref name="visitor"/> or <paramref name="methodCallExpression"/> is <c>null</c>.
@@ -33,7 +34,8 @@ public static class BulkOperationsRelationalQueryableMethodTranslatingExpression
3334
MethodCallExpression methodCallExpression,
3435
IRelationalTypeMappingSource typeMappingSource,
3536
QueryCompilationContext queryCompilationContext,
36-
TempTableQueryContextFactory tempTableQueryContextFactory)
37+
TempTableQueryContextFactory tempTableQueryContextFactory,
38+
ISqlExpressionFactory sqlExpressionFactory)
3739
{
3840
ArgumentNullException.ThrowIfNull(visitor);
3941
ArgumentNullException.ThrowIfNull(methodCallExpression);
@@ -49,7 +51,16 @@ public static class BulkOperationsRelationalQueryableMethodTranslatingExpression
4951
if (methodCallExpression.Method.DeclaringType == typeof(BulkOperationsDbSetExtensions))
5052
{
5153
if (methodCallExpression.Method.Name == nameof(BulkOperationsDbSetExtensions.FromTempTable))
52-
return TranslateFromTempTable(GetShapedQueryExpression(visitor, methodCallExpression), methodCallExpression, queryCompilationContext, tempTableQueryContextFactory);
54+
{
55+
var tempTableInfo = ((TempTableInfoExpression)methodCallExpression.Arguments[1]).Value;
56+
57+
if (!tempTableInfo.HasOwnedEntities)
58+
return CreateShapedQueryExpressionForTempTable(sqlExpressionFactory, tempTableInfo);
59+
60+
var shapedQueryExpression = GetShapedQueryExpression(visitor, methodCallExpression);
61+
62+
return TranslateFromTempTable(shapedQueryExpression, tempTableInfo, queryCompilationContext, tempTableQueryContextFactory);
63+
}
5364

5465
throw new InvalidOperationException(CoreStrings.TranslationFailed(methodCallExpression.Print()));
5566
}
@@ -59,12 +70,13 @@ public static class BulkOperationsRelationalQueryableMethodTranslatingExpression
5970

6071
private static Expression TranslateFromTempTable(
6172
ShapedQueryExpression shapedQueryExpression,
62-
MethodCallExpression methodCallExpression,
73+
TempTableInfo tempTableInfo,
6374
QueryCompilationContext queryCompilationContext,
6475
TempTableQueryContextFactory tempTableQueryContextFactory)
6576
{
66-
var tableExpression = (TableExpression)((SelectExpression)shapedQueryExpression.QueryExpression).Tables.Single();
67-
var tempTableName = ((TempTableNameExpression)methodCallExpression.Arguments[1]).Value;
77+
var tempTableName = tempTableInfo.Name ?? throw new Exception("No temp table name provided.");
78+
var selectExpression = (SelectExpression)shapedQueryExpression.QueryExpression;
79+
var tableExpression = (TableExpression)selectExpression.Tables.Single();
6880

6981
var ctx = tempTableQueryContextFactory.Create(tableExpression, tempTableName ?? throw new Exception("No temp table name provided."));
7082
var extractor = Expression.Lambda<Func<QueryContext, TempTableQueryContext>>(Expression.Constant(ctx), QueryCompilationContext.QueryContextParameter);
@@ -74,6 +86,18 @@ private static Expression TranslateFromTempTable(
7486
return shapedQueryExpression;
7587
}
7688

89+
private static Expression CreateShapedQueryExpressionForTempTable(ISqlExpressionFactory sqlExpressionFactory, TempTableInfo tempTableInfo)
90+
{
91+
var tempTableName = tempTableInfo.Name ?? throw new Exception("No temp table name provided.");
92+
var tempTableExpression = new TempTableExpression(tempTableName, "#");
93+
var selectExpression = sqlExpressionFactory.Select(tempTableInfo.EntityType, tempTableExpression);
94+
95+
return new ShapedQueryExpression(selectExpression,
96+
new RelationalEntityShaperExpression(tempTableInfo.EntityType,
97+
new ProjectionBindingExpression(selectExpression, new ProjectionMember(), typeof(ValueBuffer)),
98+
false));
99+
}
100+
77101
private static Expression TranslateBulkDelete(ShapedQueryExpression shapedQueryExpression, IRelationalTypeMappingSource typeMappingSource)
78102
{
79103
var selectExpression = (SelectExpression)shapedQueryExpression.QueryExpression;

src/Thinktecture.EntityFrameworkCore.BulkOperations/Extensions/Internal/BulkOperationsDbSetExtensions.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ public static class BulkOperationsDbSetExtensions
1515
/// <summary>
1616
/// This is an internal API.
1717
/// </summary>
18-
public static IQueryable<T> FromTempTable<T>(this IQueryable<T> source, string name)
18+
public static IQueryable<T> FromTempTable<T>(
19+
this IQueryable<T> source,
20+
TempTableInfo info)
1921
{
2022
ArgumentNullException.ThrowIfNull(source);
21-
ArgumentNullException.ThrowIfNull(name);
23+
ArgumentNullException.ThrowIfNull(info);
2224

2325
var methodInfo = _fromTempTable.MakeGenericMethod(typeof(T));
24-
var expression = Expression.Call(null, methodInfo, source.Expression, new TempTableNameExpression(name));
26+
var expression = Expression.Call(null, methodInfo, source.Expression, new TempTableInfoExpression(info));
2527

2628
return source.Provider.CreateQuery<T>(expression);
2729
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using Microsoft.EntityFrameworkCore.Metadata;
2+
3+
namespace Thinktecture.Internal;
4+
5+
/// <summary>
6+
/// Temp table infos.
7+
/// </summary>
8+
/// <param name="Name">The name of the temp table.</param>
9+
/// <param name="HasOwnedEntities">Indication whether the temp table has owned entities.</param>
10+
/// <param name="EntityType">The entity type of the temp table.</param>
11+
public record TempTableInfo(string Name, bool HasOwnedEntities, IEntityType EntityType);
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,18 @@
11
namespace Thinktecture.Internal;
22

33
/// <summary>
4-
/// The name of the temp table.
4+
/// Temp table infos.
55
/// </summary>
6-
public class TempTableNameExpression : NonEvaluatableConstantExpression<string>
6+
public class TempTableInfoExpression : NonEvaluatableConstantExpression<TempTableInfo>
77
{
88
/// <inheritdoc />
9-
public TempTableNameExpression(string value)
9+
public TempTableInfoExpression(TempTableInfo value)
1010
: base(value)
1111
{
1212
}
1313

1414
/// <inheritdoc />
15-
public override bool Equals(string? otherTempTableName)
15+
public override bool Equals(TempTableInfo? otherTempTableName)
1616
{
1717
return Value.Equals(otherTempTableName);
1818
}

src/Thinktecture.EntityFrameworkCore.Relational/EntityFrameworkCore/Query/SqlExpressions/TableWithHintsExpression.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System.Linq.Expressions;
12
using Microsoft.EntityFrameworkCore.Query;
23
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
34

@@ -6,7 +7,7 @@ namespace Thinktecture.EntityFrameworkCore.Query.SqlExpressions;
67
/// <summary>
78
/// An expression that represents table hints.
89
/// </summary>
9-
public sealed class TableWithHintsExpression : TableExpressionBase
10+
public sealed class TableWithHintsExpression : TableExpressionBase, INotNullableSqlExpression
1011
{
1112
/// <summary>
1213
/// Table to apply hints to.
@@ -30,6 +31,16 @@ public TableWithHintsExpression(TableExpressionBase table, IReadOnlyList<ITableH
3031
TableHints = tableHints;
3132
}
3233

34+
/// <inheritdoc />
35+
protected override Expression VisitChildren(ExpressionVisitor visitor)
36+
{
37+
var visited = (TableExpressionBase)visitor.Visit(Table);
38+
39+
return Table != visited
40+
? new TableWithHintsExpression(visited, TableHints)
41+
: this;
42+
}
43+
3344
/// <inheritdoc />
3445
protected override void Print(ExpressionPrinter expressionPrinter)
3546
{

src/Thinktecture.EntityFrameworkCore.Relational/EntityFrameworkCore/Query/ThinktectureSqlNullabilityProcessor.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,15 @@ public ThinktectureSqlNullabilityProcessor(
1717
{
1818
}
1919

20+
/// <inheritdoc />
21+
protected override TableExpressionBase Visit(TableExpressionBase tableExpressionBase)
22+
{
23+
if (tableExpressionBase is INotNullableSqlExpression)
24+
return tableExpressionBase;
25+
26+
return base.Visit(tableExpressionBase);
27+
}
28+
2029
/// <inheritdoc />
2130
protected override SqlExpression VisitCustomSqlExpression(SqlExpression sqlExpression, bool allowOptimizedExpansion, out bool nullable)
2231
{
@@ -28,4 +37,4 @@ protected override SqlExpression VisitCustomSqlExpression(SqlExpression sqlExpre
2837

2938
return base.VisitCustomSqlExpression(sqlExpression, allowOptimizedExpansion, out nullable);
3039
}
31-
}
40+
}

src/Thinktecture.EntityFrameworkCore.Relational/Extensions/RelationalQueryableMethodTranslatingExpressionVisitorExtensions.cs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,20 @@ private static Expression TranslateTableHints(
6060
QueryCompilationContext queryCompilationContext,
6161
TableHintContextFactory tableHintContextFactory)
6262
{
63-
var tableHints = (TableHintsExpression)methodCallExpression.Arguments[1] ?? throw new InvalidOperationException("Table hints cannot be null.");
63+
var tableHintsExpression = (TableHintsExpression)methodCallExpression.Arguments[1] ?? throw new InvalidOperationException("Table hints cannot be null.");
6464
var tables = ((SelectExpression)shapedQueryExpression.QueryExpression).Tables;
6565

6666
if (tables.Count == 0)
67-
throw new InvalidOperationException($"No tables found to apply table hints '{String.Join(", ", tableHints)}' to.");
67+
throw new InvalidOperationException($"No tables found to apply table hints '{String.Join(", ", tableHintsExpression)}' to.");
6868

6969
if (tables.Count > 1)
70-
throw new InvalidOperationException($"Multiple tables found to apply table hints '{String.Join(", ", tableHints)}' to. Expression: {String.Join(", ", tables.Select(t => t.Print()))}");
70+
throw new InvalidOperationException($"Multiple tables found to apply table hints '{String.Join(", ", tableHintsExpression)}' to. Expression: {String.Join(", ", tables.Select(t => t.Print()))}");
7171

72-
var tableExpression = ((SelectExpression)shapedQueryExpression.QueryExpression).Tables[0];
72+
var selectExpression = (SelectExpression)shapedQueryExpression.QueryExpression;
73+
var tableExpression = selectExpression.Tables[0];
74+
var tableHints = tableHintsExpression.Value ?? throw new Exception("No table hints provided.");
7375

74-
var ctx = tableHintContextFactory.Create(tableExpression, tableHints.Value ?? throw new Exception("No table hints provided."));
76+
var ctx = tableHintContextFactory.Create(tableExpression, tableHints);
7577
var extractor = Expression.Lambda<Func<QueryContext, TableHintContext>>(Expression.Constant(ctx), QueryCompilationContext.QueryContextParameter);
7678

7779
queryCompilationContext.RegisterRuntimeParameter(ctx.ParameterName, extractor);

src/Thinktecture.EntityFrameworkCore.SqlServer/EntityFrameworkCore/BulkOperations/SqlServerBulkOperationExecutor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ private async Task<ITempTableQuery<T>> BulkInsertIntoTempTableAsync<T, TEntity>(
327327
? _ctx.Set<TEntity>(entityTypeName)
328328
: _ctx.Set<TEntity>();
329329

330-
var query = dbSet.FromTempTable(tempTableReference.Name);
330+
var query = dbSet.FromTempTable(new TempTableInfo(tempTableReference.Name, selectedProperties.Any(p => p.Navigations.Count != 0), entityType));
331331

332332
var pk = entityType.FindPrimaryKey();
333333

src/Thinktecture.EntityFrameworkCore.SqlServer/EntityFrameworkCore/Query/ThinktectureSqlServerQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ protected override QueryableMethodTranslatingExpressionVisitor CreateSubqueryVis
5454
protected override Expression VisitMethodCall(MethodCallExpression methodCallExpression)
5555
{
5656
return this.TranslateRelationalMethods(methodCallExpression, QueryCompilationContext, _tableHintContextFactory) ??
57-
this.TranslateBulkMethods(methodCallExpression, _typeMappingSource, QueryCompilationContext, _tempTableQueryContextFactory) ??
57+
this.TranslateBulkMethods(methodCallExpression, _typeMappingSource, QueryCompilationContext, _tempTableQueryContextFactory, RelationalDependencies.SqlExpressionFactory) ??
5858
base.VisitMethodCall(methodCallExpression);
5959
}
6060
}

0 commit comments

Comments
 (0)