Skip to content
This repository was archived by the owner on Dec 24, 2022. It is now read-only.

Commit 3388ebb

Browse files
committed
Merge pull request #509 from shift-evgeny/WhereBooleanSupport
Handle expressions that evaluate to true/false/nulls in Where() expression
2 parents 0496164 + a880b0c commit 3388ebb

File tree

2 files changed

+119
-16
lines changed

2 files changed

+119
-16
lines changed

src/ServiceStack.OrmLite/Expressions/SqlExpression.cs

Lines changed: 53 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@ namespace ServiceStack.OrmLite
1414
{
1515
public abstract partial class SqlExpression<T> : ISqlExpression, IHasUntypedSqlExpression
1616
{
17+
private const string TrueLiteral = "(1=1)";
18+
private const string FalseLiteral = "(1=0)";
19+
1720
protected bool visitedExpressionIsTableColumn = false;
1821
protected bool skipParameterizationForThisExpression = false;
1922

@@ -421,10 +424,17 @@ protected void AppendToWhere(string condition, Expression predicate)
421424

422425
useFieldName = true;
423426
sep = " ";
424-
var newExpr = Visit(predicate).ToString();
427+
var newExpr = WhereExpressionToString(Visit(predicate));
425428
AppendToWhere(condition, newExpr);
426429
}
427430

431+
private static string WhereExpressionToString(object expression)
432+
{
433+
if (expression is bool)
434+
return (bool)expression ? TrueLiteral : FalseLiteral;
435+
return expression.ToString();
436+
}
437+
428438
protected void AppendToWhere(string condition, string sqlExpression)
429439
{
430440
whereExpression = string.IsNullOrEmpty(whereExpression)
@@ -1258,6 +1268,24 @@ protected virtual object VisitBinary(BinaryExpression b)
12581268
originalLeft = left = Visit(b.Left);
12591269
originalRight = right = Visit(b.Right);
12601270

1271+
// Handle "expr = true/false", including with the constant on the left
1272+
1273+
if (operand == "=" || operand == "<>")
1274+
{
1275+
if (left is bool)
1276+
{
1277+
Swap(ref left, ref right); // Should be safe to swap for equality/inequality checks
1278+
}
1279+
1280+
if (right is bool && !IsFieldName(left)) // Don't change anything when "expr" is a column name - then we really want "ColName = 1"
1281+
{
1282+
if (operand == "=")
1283+
return (bool)right ? left : GetNotValue(left); // "expr == true" becomes "expr", "expr == false" becomes "not (expr)"
1284+
if (operand == "<>")
1285+
return (bool)right ? GetNotValue(left) : left; // "expr != true" becomes "not (expr)", "expr != false" becomes "expr"
1286+
}
1287+
}
1288+
12611289
var leftEnum = left as EnumMemberAccess;
12621290
var rightEnum = right as EnumMemberAccess;
12631291

@@ -1282,7 +1310,8 @@ protected virtual object VisitBinary(BinaryExpression b)
12821310
}
12831311
else if (left as PartialSqlString == null && right as PartialSqlString == null)
12841312
{
1285-
var result = CachedExpressionCompiler.Evaluate(b);
1313+
var evaluatedValue = CachedExpressionCompiler.Evaluate(b);
1314+
var result = VisitConstant(Expression.Constant(evaluatedValue));
12861315
return result;
12871316
}
12881317
else if (left as PartialSqlString == null)
@@ -1297,10 +1326,7 @@ protected virtual object VisitBinary(BinaryExpression b)
12971326

12981327
if (left.ToString().Equals("null", StringComparison.OrdinalIgnoreCase))
12991328
{
1300-
// "null is x" will not work, so swap the operands
1301-
var temp = right;
1302-
right = left;
1303-
left = temp;
1329+
Swap(ref left, ref right); // "null is x" will not work, so swap the operands
13041330
}
13051331

13061332
if (operand == "=" && right.ToString().Equals("null", StringComparison.OrdinalIgnoreCase))
@@ -1320,6 +1346,13 @@ protected virtual object VisitBinary(BinaryExpression b)
13201346
}
13211347
}
13221348

1349+
private static void Swap(ref object left, ref object right)
1350+
{
1351+
var temp = right;
1352+
right = left;
1353+
left = temp;
1354+
}
1355+
13231356
protected virtual void VisitFilter(string operand, object originalLeft, object originalRight, ref object left, ref object right)
13241357
{
13251358
if (skipParameterizationForThisExpression || visitedExpressionIsTableColumn)
@@ -1441,14 +1474,7 @@ protected virtual object VisitUnary(UnaryExpression u)
14411474
{
14421475
case ExpressionType.Not:
14431476
var o = Visit(u.Operand);
1444-
1445-
if (o as PartialSqlString == null)
1446-
return !((bool)o);
1447-
1448-
if (IsFieldName(o))
1449-
return new PartialSqlString(o + "=" + GetQuotedFalseValue());
1450-
1451-
return new PartialSqlString("NOT (" + o + ")");
1477+
return GetNotValue(o);
14521478
case ExpressionType.Convert:
14531479
if (u.Method != null)
14541480
{
@@ -1459,6 +1485,17 @@ protected virtual object VisitUnary(UnaryExpression u)
14591485
return Visit(u.Operand);
14601486
}
14611487

1488+
private object GetNotValue(object o)
1489+
{
1490+
if (o as PartialSqlString == null)
1491+
return !((bool) o);
1492+
1493+
if (IsFieldName(o))
1494+
return new PartialSqlString(o + "=" + GetQuotedFalseValue());
1495+
1496+
return new PartialSqlString("NOT (" + o + ")");
1497+
}
1498+
14621499
private bool IsColumnAccess(MethodCallExpression m)
14631500
{
14641501
if (m.Object != null && m.Object as MethodCallExpression != null)
@@ -1783,14 +1820,14 @@ protected string ConvertInExpressionToSql(MethodCallExpression m, object quotedC
17831820
var argValue = CachedExpressionCompiler.Evaluate(m.Arguments[1]);
17841821

17851822
if (argValue == null)
1786-
return "(1=0)"; // "column IN (NULL)" is always false
1823+
return FalseLiteral; // "column IN (NULL)" is always false
17871824

17881825
var enumerableArg = argValue as IEnumerable;
17891826
if (enumerableArg != null)
17901827
{
17911828
var inArgs = Sql.Flatten(enumerableArg);
17921829
if (inArgs.Count == 0)
1793-
return "(1=0)"; // "column IN ([])" is always false
1830+
return FalseLiteral; // "column IN ([])" is always false
17941831

17951832
string sqlIn = CreateInParamSql(inArgs);
17961833
return string.Format("{0} {1} ({2})", quotedColName, m.Method.Name, sqlIn);

tests/ServiceStack.OrmLite.Tests/ExpressionVisitorTests.cs

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,72 @@ public void Can_Select_using_int_Array_Contains()
272272
Assert.AreEqual(sql1, sql2);
273273
}
274274

275+
[Test]
276+
public void Can_Select_using_boolean_constant()
277+
{
278+
var q = Db.From<TestType>().Where(x => true);
279+
var target = Db.Select(q);
280+
Assert.AreEqual(4, target.Count);
281+
282+
q = Db.From<TestType>().Where(x => false);
283+
target = Db.Select(q);
284+
Assert.AreEqual(0, target.Count);
285+
}
286+
287+
[Test]
288+
public void Can_Select_using_expression_evaluated_to_constant()
289+
{
290+
var a = 5;
291+
var b = 6;
292+
int? nullableInt = null;
293+
294+
var q = Db.From<TestType>().Where(x => a < b); // "a < b" is evaluated by SqlExpression (not at compile time!) to ConstantExpression (true)
295+
var target = Db.Select(q);
296+
Assert.AreEqual(4, target.Count);
297+
298+
q = Db.From<TestType>().Where(x => x.NullableIntCol == nullableInt); // Expression evaluated to "null" in SqlExpression
299+
target = Db.Select(q);
300+
CollectionAssert.AreEquivalent(new[] { 2 }, target.Select(t => t.Id).ToArray());
301+
302+
q = Db.From<TestType>().Where(x => nullableInt == x.NullableIntCol); // Same with the null on the left
303+
target = Db.Select(q);
304+
CollectionAssert.AreEquivalent(new[] { 2 }, target.Select(t => t.Id).ToArray());
305+
306+
// Expression = or <> true or false
307+
308+
q = Db.From<TestType>().Where(x => x.NullableIntCol.HasValue == 5 < 6); // Evaluated to "true" at compile time: equivalent to "x.NullableIntCol != null"
309+
target = Db.Select(q);
310+
CollectionAssert.AreEquivalent(new[] { 1, 3, 4 }, target.Select(t => t.Id).ToArray());
311+
312+
q = Db.From<TestType>().Where(x => x.NullableIntCol.HasValue == 5 > 6); // Evaluated to "false" at compile time: equivalent to "x.NullableIntCol == null"
313+
target = Db.Select(q);
314+
CollectionAssert.AreEquivalent(new[] { 2 }, target.Select(t => t.Id).ToArray());
315+
316+
q = Db.From<TestType>().Where(x => x.NullableIntCol.HasValue != 5 > 6); // != false
317+
target = Db.Select(q);
318+
CollectionAssert.AreEquivalent(new[] { 1, 3, 4 }, target.Select(t => t.Id).ToArray());
319+
320+
q = Db.From<TestType>().Where(x => x.NullableIntCol.HasValue != 5 < 6); // != true
321+
target = Db.Select(q);
322+
CollectionAssert.AreEquivalent(new[] { 2 }, target.Select(t => t.Id).ToArray());
323+
324+
// Same, but with the constant on the left
325+
326+
q = Db.From<TestType>().Where(x => 5 < 6 == x.NullableIntCol.HasValue);
327+
target = Db.Select(q);
328+
CollectionAssert.AreEquivalent(new[] { 1, 3, 4 }, target.Select(t => t.Id).ToArray());
329+
330+
q = Db.From<TestType>().Where(x => 5 > 6 != x.NullableIntCol.HasValue);
331+
target = Db.Select(q);
332+
CollectionAssert.AreEquivalent(new[] { 1, 3, 4 }, target.Select(t => t.Id).ToArray());
333+
334+
// Same, but with the expression evaluated inside SqlExpression (not at compile time)
335+
336+
q = Db.From<TestType>().Where(x => x.NullableIntCol.HasValue == a < b);
337+
target = Db.Select(q);
338+
CollectionAssert.AreEquivalent(new[] { 1, 3, 4 }, target.Select(t => t.Id).ToArray());
339+
}
340+
275341
private int MethodReturningInt(int val)
276342
{
277343
return val;

0 commit comments

Comments
 (0)