Skip to content

Commit 947a0c2

Browse files
committed
Translate nested CASE to simpler COALESCE
1 parent c1cf255 commit 947a0c2

File tree

5 files changed

+199
-1
lines changed

5 files changed

+199
-1
lines changed

src/EntityFramework6.Npgsql/EntityFramework6.Npgsql.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,9 @@
7676
<Compile Include="Properties\AssemblyInfo.cs" />
7777
<Compile Include="Spatial\PostgisDataReader.cs" />
7878
<Compile Include="Spatial\PostgisServices.cs" />
79+
<Compile Include="SqlGenerators\CaseIsNullToCoalesceReducer.cs" />
7980
<Compile Include="SqlGenerators\PendingProjectsNode.cs" />
81+
<Compile Include="SqlGenerators\DbExpressionDeepEqual.cs" />
8082
<Compile Include="SqlGenerators\SqlBaseGenerator.cs" />
8183
<Compile Include="SqlGenerators\SqlDeleteGenerator.cs" />
8284
<Compile Include="SqlGenerators\SqlInsertGenerator.cs" />

src/EntityFramework6.Npgsql/NpgsqlProviderManifest.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ public override ReadOnlyCollection<EdmFunction> GetStoreFunctions()
359359
.ToList()
360360
.AsReadOnly();
361361

362-
static EdmFunction CreateComposableEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo)
362+
internal static EdmFunction CreateComposableEdmFunction([NotNull] MethodInfo method, [NotNull] DbFunctionAttribute dbFunctionInfo)
363363
{
364364
if (method == null)
365365
throw new ArgumentNullException(nameof(method));
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
using System.Collections.Generic;
2+
using System.Data.Entity.Core.Common.CommandTrees;
3+
using System.Data.Entity.Core.Common.CommandTrees.ExpressionBuilder;
4+
using System.Data.Entity.Core.Metadata.Edm;
5+
using System.Linq;
6+
7+
namespace Npgsql.SqlGenerators
8+
{
9+
public class CaseIsNullToCoalesceReducer
10+
{
11+
public static DbFunctionExpression InvokeCoalesceExpression(params DbExpression[] argumentExpressions)
12+
{
13+
var fromClrType = PrimitiveType
14+
.GetEdmPrimitiveTypes()
15+
.FirstOrDefault(t => t.ClrEquivalentType == typeof(string));
16+
17+
int i=0;
18+
var func = EdmFunction.Create(
19+
"coalesce",
20+
"Npgsql",
21+
DataSpace.SSpace,
22+
new EdmFunctionPayload
23+
{
24+
ParameterTypeSemantics = ParameterTypeSemantics.AllowImplicitConversion,
25+
Schema = string.Empty,
26+
IsBuiltIn = true,
27+
IsAggregate = false,
28+
IsFromProviderManifest = true,
29+
StoreFunctionName = "coalesce",
30+
IsComposable = true,
31+
ReturnParameters = new[]
32+
{
33+
FunctionParameter.Create("ReturnType", fromClrType,ParameterMode.ReturnValue)
34+
},
35+
Parameters = argumentExpressions.Select(
36+
x => FunctionParameter.Create(
37+
"p" + (i++).ToString(),fromClrType,ParameterMode.In)).ToList()
38+
},
39+
new List<MetadataProperty>());
40+
41+
return func.Invoke(argumentExpressions);
42+
}
43+
44+
public static DbFunctionExpression UnnestCoalesceInvocations(DbFunctionExpression dbFunctionExpression)
45+
{
46+
var args = new List<DbExpression>();
47+
foreach (var arg in dbFunctionExpression.Arguments)
48+
{
49+
if(arg is DbFunctionExpression funcCall
50+
&& funcCall.Function.NamespaceName=="Npgsql"
51+
&& funcCall.Function.Name=="coalesce")
52+
{
53+
args.AddRange(funcCall.Arguments);
54+
}
55+
else
56+
{
57+
args.Add(arg);
58+
}
59+
}
60+
return InvokeCoalesceExpression(args.ToArray());
61+
}
62+
63+
public static DbExpression TransformCoalesce(DbExpression expression)
64+
{
65+
if (expression is DbCaseExpression case2)
66+
{
67+
return TransformCoalesce(case2);
68+
}
69+
70+
if (expression is DbIsNullExpression nullExp)
71+
{
72+
return TransformCoalesce(nullExp.Argument).IsNull();
73+
}
74+
return expression;
75+
}
76+
77+
public static DbExpression TransformCoalesce(DbCaseExpression expression)
78+
{
79+
expression = DbExpressionBuilder.Case(
80+
expression.When.Select(TransformCoalesce),
81+
expression.Then.Select(TransformCoalesce),
82+
expression.Else);
83+
84+
var lastWhen = expression.When.Count-1;
85+
if (expression.When[lastWhen].ExpressionKind == DbExpressionKind.IsNull)
86+
{
87+
var is_null = expression.When[lastWhen] as DbIsNullExpression;
88+
if (DbExpressionDeepEqual.DeepEqual(is_null.Argument,expression.Else))
89+
{
90+
var coalesceInvocation = InvokeCoalesceExpression(is_null.Argument, expression.Then[lastWhen]);
91+
coalesceInvocation = UnnestCoalesceInvocations(coalesceInvocation);
92+
93+
if (expression.When.Count == 1)
94+
{
95+
return coalesceInvocation;
96+
}
97+
98+
var simplifiendCase = DbExpressionBuilder.Case(
99+
expression.When.Take(lastWhen),
100+
expression.Then.Take(lastWhen),
101+
coalesceInvocation);
102+
103+
return TransformCoalesce(simplifiendCase);
104+
}
105+
return expression;
106+
}
107+
return expression;
108+
}
109+
}
110+
}
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
using System.Data.Entity.Core.Common.CommandTrees;
2+
using System.Linq;
3+
4+
namespace Npgsql.SqlGenerators
5+
{
6+
public class DbExpressionDeepEqual
7+
{
8+
public static bool DeepEqual(DbExpression e1, DbExpression e2)
9+
{
10+
if (e1.Equals(e2)) return true;
11+
if (e1.GetType() != e2.GetType()) return false;
12+
if (!e1.ExpressionKind.Equals(e2.ExpressionKind)) return false;
13+
if (!e1.ResultType.Equals(e2.ResultType)) return false;
14+
15+
if (e1 is DbFunctionExpression f1 && e2 is DbFunctionExpression f2)
16+
{
17+
return DeepEqual(f1,f2);
18+
}
19+
if (e1 is DbConstantExpression c1 && e2 is DbConstantExpression c2)
20+
{
21+
return c1.Value.Equals(c2.Value);
22+
}
23+
if (e1 is DbBinaryExpression b1 && e2 is DbBinaryExpression b2)
24+
{
25+
return DeepEqual(b1,b2);
26+
}
27+
if (e1 is DbUnaryExpression u1 && e2 is DbUnaryExpression u2)
28+
{
29+
return DeepEqual(u1,u2);
30+
}
31+
if (e1 is DbVariableReferenceExpression v1 && e2 is DbVariableReferenceExpression v2)
32+
{
33+
return DeepEqual(v1,v2);
34+
}
35+
36+
return false;
37+
}
38+
39+
private static bool DeepEqual(DbFunctionExpression f1, DbFunctionExpression f2)
40+
{
41+
if (!f1.Function.Name.Equals(f2.Function.Name)) return false;
42+
if (!f1.Function.NamespaceName.Equals(f2.Function.NamespaceName)) return false;
43+
if (!f1.Arguments.Count.Equals(f2.Arguments.Count)) return false;
44+
45+
var argumenst_equals = f1.Arguments
46+
.Zip(f2.Arguments, (a, b) => DeepEqual(a, b))
47+
.All(areEquals => areEquals);
48+
49+
return argumenst_equals;
50+
}
51+
52+
private static bool DeepEqual(DbBinaryExpression b1, DbBinaryExpression b2)
53+
{
54+
if (!DeepEqual(b1.Left,b2.Left)) return false;
55+
if (!DeepEqual(b1.Right,b2.Right)) return false;
56+
57+
return true;
58+
}
59+
60+
private static bool DeepEqual(DbUnaryExpression u1, DbUnaryExpression u2)
61+
{
62+
return DeepEqual(u1.Argument,u2.Argument);
63+
}
64+
65+
private static bool DeepEqual(DbVariableReferenceExpression v1, DbVariableReferenceExpression v2)
66+
{
67+
return DeepEqual(v1.VariableName,v1.VariableName);
68+
}
69+
}
70+
}

src/EntityFramework6.Npgsql/SqlGenerators/SqlBaseGenerator.cs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,16 @@ protected string GetDbType(EdmType edmType)
829829

830830
public override VisitedExpression Visit([NotNull] DbCaseExpression expression)
831831
{
832+
var result = CaseIsNullToCoalesceReducer.TransformCoalesce(expression);
833+
if (result is DbCaseExpression case2)
834+
{
835+
expression = case2;
836+
}
837+
else
838+
{
839+
return result.Accept(this);
840+
}
841+
832842
var caseExpression = new LiteralExpression(" CASE ");
833843
for (var i = 0; i < expression.When.Count && i < expression.Then.Count; ++i)
834844
{
@@ -1191,6 +1201,12 @@ VisitedExpression VisitFunction(EdmFunction function, IList<DbExpression> args,
11911201
throw new NotSupportedException("cast type name argument must be a constant expression.");
11921202

11931203
return new CastExpression(args[0].Accept(this), typeNameExpression.Value.ToString());
1204+
}else if (functionName == "coalesce")
1205+
{
1206+
var coalesceFuncCall = new FunctionExpression("coalesce");
1207+
foreach (var a in args)
1208+
coalesceFuncCall.AddArgument(a.Accept(this));
1209+
return coalesceFuncCall;
11941210
}
11951211
}
11961212

0 commit comments

Comments
 (0)