Skip to content
This repository was archived by the owner on Dec 24, 2022. It is now read-only.

Commit 40e11ec

Browse files
committed
SqlExpression: Add support for ConditionalExpression
1 parent 6ea8087 commit 40e11ec

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

src/ServiceStack.OrmLite/Expressions/SqlExpression.cs

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1241,6 +1241,8 @@ protected internal virtual object Visit(Expression exp)
12411241
return VisitMemberInit(exp as MemberInitExpression);
12421242
case ExpressionType.Index:
12431243
return VisitIndexExpression(exp as IndexExpression);
1244+
case ExpressionType.Conditional:
1245+
return VisitConditional(exp as ConditionalExpression);
12441246
default:
12451247
return exp.ToString();
12461248
}
@@ -1496,6 +1498,19 @@ protected bool CheckExpressionForTypes(Expression e, ExpressionType[] types)
14961498
return true;
14971499
}
14981500

1501+
var condExpr = e as ConditionalExpression;
1502+
if (condExpr != null)
1503+
{
1504+
if (CheckExpressionForTypes(condExpr.Test, types))
1505+
return true;
1506+
1507+
if (CheckExpressionForTypes(condExpr.IfTrue, types))
1508+
return true;
1509+
1510+
if (CheckExpressionForTypes(condExpr.IfFalse, types))
1511+
return true;
1512+
}
1513+
14991514
var memberExpr = e as MemberExpression;
15001515
e = memberExpr?.Expression;
15011516
}
@@ -1746,6 +1761,32 @@ protected virtual object VisitIndexExpression(IndexExpression e)
17461761
throw new NotImplementedException("Unknown Expression: " + e);
17471762
}
17481763

1764+
protected virtual object VisitConditional(ConditionalExpression e)
1765+
{
1766+
var test = IsBooleanComparison(e.Test)
1767+
? new PartialSqlString($"{VisitMemberAccess((MemberExpression) e.Test)}={GetQuotedTrueValue()}")
1768+
: Visit(e.Test);
1769+
1770+
if (test is bool)
1771+
{
1772+
if ((bool) test)
1773+
{
1774+
var ifTrue = Visit(e.IfTrue);
1775+
return ifTrue;
1776+
}
1777+
1778+
var ifFalse = Visit(e.IfFalse);
1779+
return ifFalse;
1780+
}
1781+
else
1782+
{
1783+
var ifTrue = Visit(e.IfTrue);
1784+
var ifFalse = Visit(e.IfFalse);
1785+
1786+
return new PartialSqlString($"(CASE WHEN {test} THEN {ifTrue} ELSE {ifFalse} END)");
1787+
}
1788+
}
1789+
17491790
private object GetNotValue(object o)
17501791
{
17511792
if (!(o is PartialSqlString))
@@ -1766,6 +1807,10 @@ protected virtual bool IsColumnAccess(MethodCallExpression m)
17661807
if (methCallExp != null)
17671808
return IsColumnAccess(methCallExp);
17681809

1810+
var condExp = m.Object as ConditionalExpression;
1811+
if (condExp != null)
1812+
return IsParameterAccess(condExp);
1813+
17691814
var exp = m.Object as MemberExpression;
17701815
return IsParameterAccess(exp)
17711816
&& IsJoinedTable(exp.Expression.Type);

tests/ServiceStack.OrmLite.Tests/ExpressionVisitorTests.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,57 @@ public void Can_Where_using_constant_filter()
499499
var q = Db.From<TestType>().Where(filter);//todo: here Where: null is NULL. May be need to change to 1=1 ?
500500
var target = Db.Select(q);
501501
Assert.That(target.Count, Is.EqualTo(4));
502+
}
503+
504+
[Test]
505+
public void Can_Where_using_Conditional_filter()
506+
{
507+
System.Linq.Expressions.Expression<Func<TestType, bool>> filter = x => (x.NullableIntCol == null ? 0 : x.NullableIntCol) == 10;
508+
var q = Db.From<TestType>().Where(filter);
509+
var target = Db.Select(q);
510+
Assert.That(target.Count, Is.EqualTo(1));
511+
}
512+
513+
[Test]
514+
public void Can_Where_using_Bool_Conditional_filter()
515+
{
516+
System.Linq.Expressions.Expression<Func<TestType, bool>> filter = x => (x.BoolCol ? x.NullableIntCol : 0) == 10;
517+
var q = Db.From<TestType>().Where(filter);
518+
var target = Db.Select(q);
519+
Assert.That(target.Count, Is.EqualTo(1));
520+
}
521+
522+
[Test]
523+
public void Can_Where_using_Method_with_Conditional_filter()
524+
{
525+
System.Linq.Expressions.Expression<Func<TestType, bool>> filter = x => (x.TextCol == null ? null : x.TextCol).StartsWith("asdf");
526+
var q = Db.From<TestType>().Where(filter);
527+
var target = Db.Select(q);
528+
Assert.That(target.Count, Is.EqualTo(2));
529+
}
530+
531+
[Test]
532+
public void Can_Where_using_Constant_Conditional_filter()
533+
{
534+
var filterConditional = 10;
535+
System.Linq.Expressions.Expression<Func<TestType, bool>> filter = x => (filterConditional > 50 ? 123456789 : x.NullableIntCol) == 10;
536+
var q = Db.From<TestType>().Where(filter);
537+
Assert.That(q.ToSelectStatement(), Does.Not.Contain("123456789"));
538+
539+
var target = Db.Select(q);
540+
Assert.That(target.Count, Is.EqualTo(1));
541+
}
502542

543+
[Test]
544+
public void Can_Where_using_Bool_Constant_Conditional_filter()
545+
{
546+
var filterConditional = true;
547+
System.Linq.Expressions.Expression<Func<TestType, bool>> filter = x => (filterConditional ? x.NullableIntCol : 123456789) == 10;
548+
var q = Db.From<TestType>().Where(filter);
549+
Assert.That(q.ToSelectStatement(), Does.Not.Contain("123456789"));
550+
551+
var target = Db.Select(q);
552+
Assert.That(target.Count, Is.EqualTo(1));
503553
}
504554

505555
private int MethodReturningInt(int val)

0 commit comments

Comments
 (0)