@@ -14,6 +14,9 @@ namespace ServiceStack.OrmLite
14
14
{
15
15
public abstract partial class SqlExpression < T > : ISqlExpression , IHasUntypedSqlExpression
16
16
{
17
+ private const string TrueLiteral = "(1=1)" ;
18
+ private const string FalseLiteral = "(1=0)" ;
19
+
17
20
protected bool visitedExpressionIsTableColumn = false ;
18
21
protected bool skipParameterizationForThisExpression = false ;
19
22
@@ -421,10 +424,17 @@ protected void AppendToWhere(string condition, Expression predicate)
421
424
422
425
useFieldName = true ;
423
426
sep = " " ;
424
- var newExpr = Visit ( predicate ) . ToString ( ) ;
427
+ var newExpr = WhereExpressionToString ( Visit ( predicate ) ) ;
425
428
AppendToWhere ( condition , newExpr ) ;
426
429
}
427
430
431
+ private static string WhereExpressionToString ( object expression )
432
+ {
433
+ if ( expression is bool )
434
+ return ( bool ) expression ? TrueLiteral : FalseLiteral ;
435
+ return expression . ToString ( ) ;
436
+ }
437
+
428
438
protected void AppendToWhere ( string condition , string sqlExpression )
429
439
{
430
440
whereExpression = string . IsNullOrEmpty ( whereExpression )
@@ -1258,6 +1268,24 @@ protected virtual object VisitBinary(BinaryExpression b)
1258
1268
originalLeft = left = Visit ( b . Left ) ;
1259
1269
originalRight = right = Visit ( b . Right ) ;
1260
1270
1271
+ // Handle "expr = true/false", including with the constant on the left
1272
+
1273
+ if ( operand == "=" || operand == "<>" )
1274
+ {
1275
+ if ( left is bool )
1276
+ {
1277
+ Swap ( ref left , ref right ) ; // Should be safe to swap for equality/inequality checks
1278
+ }
1279
+
1280
+ if ( right is bool && ! IsFieldName ( left ) ) // Don't change anything when "expr" is a column name - then we really want "ColName = 1"
1281
+ {
1282
+ if ( operand == "=" )
1283
+ return ( bool ) right ? left : GetNotValue ( left ) ; // "expr == true" becomes "expr", "expr == false" becomes "not (expr)"
1284
+ if ( operand == "<>" )
1285
+ return ( bool ) right ? GetNotValue ( left ) : left ; // "expr != true" becomes "not (expr)", "expr != false" becomes "expr"
1286
+ }
1287
+ }
1288
+
1261
1289
var leftEnum = left as EnumMemberAccess ;
1262
1290
var rightEnum = right as EnumMemberAccess ;
1263
1291
@@ -1282,7 +1310,8 @@ protected virtual object VisitBinary(BinaryExpression b)
1282
1310
}
1283
1311
else if ( left as PartialSqlString == null && right as PartialSqlString == null )
1284
1312
{
1285
- var result = CachedExpressionCompiler . Evaluate ( b ) ;
1313
+ var evaluatedValue = CachedExpressionCompiler . Evaluate ( b ) ;
1314
+ var result = VisitConstant ( Expression . Constant ( evaluatedValue ) ) ;
1286
1315
return result ;
1287
1316
}
1288
1317
else if ( left as PartialSqlString == null )
@@ -1297,10 +1326,7 @@ protected virtual object VisitBinary(BinaryExpression b)
1297
1326
1298
1327
if ( left . ToString ( ) . Equals ( "null" , StringComparison . OrdinalIgnoreCase ) )
1299
1328
{
1300
- // "null is x" will not work, so swap the operands
1301
- var temp = right ;
1302
- right = left ;
1303
- left = temp ;
1329
+ Swap ( ref left , ref right ) ; // "null is x" will not work, so swap the operands
1304
1330
}
1305
1331
1306
1332
if ( operand == "=" && right . ToString ( ) . Equals ( "null" , StringComparison . OrdinalIgnoreCase ) )
@@ -1320,6 +1346,13 @@ protected virtual object VisitBinary(BinaryExpression b)
1320
1346
}
1321
1347
}
1322
1348
1349
+ private static void Swap ( ref object left , ref object right )
1350
+ {
1351
+ var temp = right ;
1352
+ right = left ;
1353
+ left = temp ;
1354
+ }
1355
+
1323
1356
protected virtual void VisitFilter ( string operand , object originalLeft , object originalRight , ref object left , ref object right )
1324
1357
{
1325
1358
if ( skipParameterizationForThisExpression || visitedExpressionIsTableColumn )
@@ -1441,14 +1474,7 @@ protected virtual object VisitUnary(UnaryExpression u)
1441
1474
{
1442
1475
case ExpressionType . Not :
1443
1476
var o = Visit ( u . Operand ) ;
1444
-
1445
- if ( o as PartialSqlString == null )
1446
- return ! ( ( bool ) o ) ;
1447
-
1448
- if ( IsFieldName ( o ) )
1449
- return new PartialSqlString ( o + "=" + GetQuotedFalseValue ( ) ) ;
1450
-
1451
- return new PartialSqlString ( "NOT (" + o + ")" ) ;
1477
+ return GetNotValue ( o ) ;
1452
1478
case ExpressionType . Convert :
1453
1479
if ( u . Method != null )
1454
1480
{
@@ -1459,6 +1485,17 @@ protected virtual object VisitUnary(UnaryExpression u)
1459
1485
return Visit ( u . Operand ) ;
1460
1486
}
1461
1487
1488
+ private object GetNotValue ( object o )
1489
+ {
1490
+ if ( o as PartialSqlString == null )
1491
+ return ! ( ( bool ) o ) ;
1492
+
1493
+ if ( IsFieldName ( o ) )
1494
+ return new PartialSqlString ( o + "=" + GetQuotedFalseValue ( ) ) ;
1495
+
1496
+ return new PartialSqlString ( "NOT (" + o + ")" ) ;
1497
+ }
1498
+
1462
1499
private bool IsColumnAccess ( MethodCallExpression m )
1463
1500
{
1464
1501
if ( m . Object != null && m . Object as MethodCallExpression != null )
@@ -1783,14 +1820,14 @@ protected string ConvertInExpressionToSql(MethodCallExpression m, object quotedC
1783
1820
var argValue = CachedExpressionCompiler . Evaluate ( m . Arguments [ 1 ] ) ;
1784
1821
1785
1822
if ( argValue == null )
1786
- return "(1=0)" ; // "column IN (NULL)" is always false
1823
+ return FalseLiteral ; // "column IN (NULL)" is always false
1787
1824
1788
1825
var enumerableArg = argValue as IEnumerable ;
1789
1826
if ( enumerableArg != null )
1790
1827
{
1791
1828
var inArgs = Sql . Flatten ( enumerableArg ) ;
1792
1829
if ( inArgs . Count == 0 )
1793
- return "(1=0)" ; // "column IN ([])" is always false
1830
+ return FalseLiteral ; // "column IN ([])" is always false
1794
1831
1795
1832
string sqlIn = CreateInParamSql ( inArgs ) ;
1796
1833
return string . Format ( "{0} {1} ({2})" , quotedColName , m . Method . Name , sqlIn ) ;
0 commit comments