Skip to content
This repository was archived by the owner on Dec 24, 2022. It is now read-only.

Commit 11aab39

Browse files
committed
Add support for parameterized statements for Oracle
SqlExpression visting logic can now convert (parameterize) parts of expression to sql placeholder, with corresponding parameter whose value is set to the unquoted value derived from expression. i tried to put the parameterization logic on top of existing code, to avoid changing any existing logic. the only place where i did change existing code is here in SqlExpression.VisitBinary: if (operand == "AND" || operand == "OR") { ... if (left as PartialSqlString == null && right as PartialSqlString == null) { var result = Expression.Lambda(b).Compile().DynamicInvoke(); return result; } ... where the result object is returned without being cast to a PartialSqlString. this is to avoid having the result being quoted twice when parameterization is disabled. this change did not seem to cause any unit test failure. parameterization is disabled by default; OracleOrmLiteDialectProvider constructor accepts input to enable the parameterization. updated as many OrmLite apis as was necessary to get Oracle unit tests to pass, using overloads and optional arguments to maintain backward compatibility. did not update any JoinSqlBuilder api as that class is marked obsolete, so there is a JoinSqlBuilderTest unit test that fails with "not all variables bound" error when parameterization is enabled.
1 parent aa83d74 commit 11aab39

15 files changed

+272
-43
lines changed

src/ServiceStack.OrmLite.Oracle.Tests/ServiceStack.OrmLite.Oracle.Tests.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@
148148
<Compile Include="..\..\tests\ServiceStack.OrmLite.Tests\Expression\SelectExpressionTests.cs">
149149
<Link>Expression\SelectExpressionTests.cs</Link>
150150
</Compile>
151+
<Compile Include="..\..\tests\ServiceStack.OrmLite.Tests\Expression\SqlExpressionParamTests.cs">
152+
<Link>Expression\SqlExpressionParamTests.cs</Link>
153+
</Compile>
151154
<Compile Include="..\..\tests\ServiceStack.OrmLite.Tests\Expression\SqlExpressionTests.cs">
152155
<Link>Expression\SqlExpressionTests.cs</Link>
153156
</Compile>

src/ServiceStack.OrmLite.Oracle/OracleOrmLiteDialectProvider.cs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,18 @@ public static OracleOrmLiteDialectProvider Instance
5959
private readonly DbProviderFactory _factory;
6060
private readonly OracleTimestampConverter _timestampConverter;
6161

62+
public bool ParameterizeStatement;
63+
6264
public OracleOrmLiteDialectProvider()
6365
: this(false, false)
6466
{
6567
}
6668

67-
public OracleOrmLiteDialectProvider(bool compactGuid, bool quoteNames, string clientProvider = OdpProvider)
69+
public OracleOrmLiteDialectProvider(bool compactGuid, bool quoteNames, string clientProvider = OdpProvider, bool parameterizeStatement = false)
6870
{
6971
ClientProvider = clientProvider;
7072
CompactGuid = compactGuid;
73+
ParameterizeStatement = parameterizeStatement;
7174
QuoteNames = quoteNames;
7275
BoolColumnDefinition = "NUMBER(1)";
7376
GuidColumnDefinition = CompactGuid ? CompactGuidDefinition : StringGuidDefinition;
@@ -298,6 +301,73 @@ public override string GetQuotedValue(object value, Type fieldType)
298301
return base.GetQuotedValue(value, fieldType);
299302
}
300303

304+
public override object GetValue(object value, Type fieldType)
305+
{
306+
if (!ParameterizeStatement)
307+
return GetQuotedValue(value, fieldType);
308+
309+
if (value == null) return DBNull.Value;
310+
311+
if (fieldType == typeof(Guid))
312+
{
313+
var guid = (Guid)value;
314+
315+
if (CompactGuid)
316+
return guid.ToByteArray();
317+
318+
return guid.ToString();
319+
}
320+
321+
if (fieldType == typeof(DateTimeOffset) || fieldType == typeof(DateTimeOffset?))
322+
{
323+
return GetQuotedDateTimeOffsetValue((DateTimeOffset)value);
324+
}
325+
326+
if ((value is TimeSpan) && (fieldType == typeof(Int64) || fieldType == typeof(Int64?)))
327+
{
328+
var longValue = ((TimeSpan)value).Ticks;
329+
return base.GetQuotedValue(longValue, fieldType);
330+
}
331+
332+
if (fieldType == typeof(TimeSpan))
333+
return ((TimeSpan)value).Ticks;
334+
335+
if (fieldType == typeof(bool?) || fieldType == typeof(bool))
336+
{
337+
var boolValue = (bool)value;
338+
return boolValue ? 1 : 0;
339+
}
340+
341+
if (fieldType.IsEnum)
342+
{
343+
if (value is int && !fieldType.IsEnumFlags())
344+
{
345+
value = fieldType.GetEnumName(value);
346+
}
347+
348+
var enumValue = StringSerializer.SerializeToString(value);
349+
// Oracle stores empty strings in varchar columns as null so match that behavior here
350+
if (enumValue == null)
351+
return null;
352+
enumValue = enumValue.Trim('"');
353+
return enumValue == ""
354+
? "null"
355+
: enumValue;
356+
}
357+
358+
if (fieldType == typeof(byte[]))
359+
{
360+
return "hextoraw('" + BitConverter.ToString((byte[])value).Replace("-", "") + "')";
361+
}
362+
363+
if (fieldType.IsRefType())
364+
{
365+
return StringSerializer.SerializeToString(value);
366+
}
367+
368+
return value;
369+
}
370+
301371
const string IsoDateFormat = "yyyy-MM-dd";
302372
const string IsoTimeFormat = "HH:mm:ss";
303373
const string IsoMillisecondFormat = "fffffff";

src/ServiceStack.OrmLite.Oracle/OracleSqlExpression.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,20 @@ protected override object VisitColumnAccessMethod(MethodCallExpression m)
3131
}
3232
return base.VisitColumnAccessMethod(m);
3333
}
34+
35+
protected override void ConvertToPlaceholderAndParameter(ref object right, Expression rightExpression)
36+
{
37+
if (!((OracleOrmLiteDialectProvider)DialectProvider).ParameterizeStatement)
38+
return;
39+
40+
var paramName = Params.Count.ToString();
41+
var paramValue = right;
42+
43+
var parameter = CreateParam(paramName, paramValue);
44+
Params.Add(parameter);
45+
46+
right = parameter.ParameterName;
47+
}
3448
}
3549
}
3650

src/ServiceStack.OrmLite/Expressions/ReadExpressionCommandExtensions.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@ internal static List<T> Select<T>(this IDbCommand dbCmd, Func<SqlExpression<T>,
1414
var expr = dbCmd.GetDialectProvider().SqlExpression<T>();
1515
string sql = expression(expr).SelectInto<T>();
1616

17-
return dbCmd.ExprConvertToList<T>(sql);
17+
return dbCmd.ExprConvertToList<T>(sql, expr.Params);
1818
}
1919

2020
internal static List<Into> Select<Into, From>(this IDbCommand dbCmd, Func<SqlExpression<From>, SqlExpression<From>> expression)
2121
{
2222
var expr = dbCmd.GetDialectProvider().SqlExpression<From>();
2323
string sql = expression(expr).SelectInto<Into>();
2424

25-
return dbCmd.ExprConvertToList<Into>(sql);
25+
return dbCmd.ExprConvertToList<Into>(sql, expr.Params);
2626
}
2727

2828
internal static List<Into> Select<Into, From>(this IDbCommand dbCmd, SqlExpression<From> expression)
2929
{
3030
string sql = expression.SelectInto<Into>();
31-
return dbCmd.ExprConvertToList<Into>(sql);
31+
return dbCmd.ExprConvertToList<Into>(sql, expression.Params);
3232
}
3333

3434
internal static List<T> Select<T>(this IDbCommand dbCmd, SqlExpression<T> expression)
@@ -43,7 +43,7 @@ internal static List<T> Select<T>(this IDbCommand dbCmd, Expression<Func<T, bool
4343
var expr = dbCmd.GetDialectProvider().SqlExpression<T>();
4444
string sql = expr.Where(predicate).SelectInto<T>();
4545

46-
return dbCmd.ExprConvertToList<T>(sql);
46+
return dbCmd.ExprConvertToList<T>(sql, expr.Params);
4747
}
4848

4949
internal static T Single<T>(this IDbCommand dbCmd, Func<SqlExpression<T>, SqlExpression<T>> expression)
@@ -63,7 +63,7 @@ internal static T Single<T>(this IDbCommand dbCmd, SqlExpression<T> expression)
6363
{
6464
string sql = expression.Limit(1).SelectInto<T>();
6565

66-
return dbCmd.ExprConvertTo<T>(sql);
66+
return dbCmd.ExprConvertTo<T>(sql, expression.Params);
6767
}
6868

6969
public static TKey Scalar<T, TKey>(this IDbCommand dbCmd, Expression<Func<T, TKey>> field)
@@ -80,7 +80,7 @@ internal static TKey Scalar<T, TKey>(this IDbCommand dbCmd,
8080
var ev = dbCmd.GetDialectProvider().SqlExpression<T>();
8181
ev.Select(field).Where(predicate);
8282
string sql = ev.SelectInto<T>();
83-
return dbCmd.Scalar<TKey>(sql);
83+
return dbCmd.Scalar<TKey>(ev.Params, sql);
8484
}
8585

8686
internal static long Count<T>(this IDbCommand dbCmd)
@@ -108,12 +108,12 @@ internal static long Count<T>(this IDbCommand dbCmd, Expression<Func<T, bool>> p
108108
var ev = dbCmd.GetDialectProvider().SqlExpression<T>();
109109
ev.Where(predicate);
110110
var sql = ev.ToCountStatement();
111-
return GetCount(dbCmd, sql);
111+
return GetCount(dbCmd, sql, ev.Params);
112112
}
113113

114-
internal static long GetCount(this IDbCommand dbCmd, string sql)
114+
internal static long GetCount(this IDbCommand dbCmd, string sql, IEnumerable<IDbDataParameter> sqlParams = null)
115115
{
116-
return dbCmd.Column<long>(sql).Sum();
116+
return dbCmd.Column<long>(sqlParams, sql).Sum();
117117
}
118118

119119
internal static long RowCount<T>(this IDbCommand dbCmd, SqlExpression<T> expression)

src/ServiceStack.OrmLite/Expressions/SqlExpression.Join.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ private SqlExpression<T> InternalJoin<Source, Target>(string joinType,
8686

8787
private string InternalCreateSqlFromExpression(Expression joinExpr, bool isCrossJoin)
8888
{
89-
return "{0} {1}".Fmt((isCrossJoin ? "WHERE" : "ON"), Visit(joinExpr).ToString());
89+
return "{0} {1}".Fmt((isCrossJoin ? "WHERE" : "ON"), VisitJoin(joinExpr).ToString());
9090
}
9191

9292
private string InternalCreateSqlFromDefinitions(ModelDefinition sourceDef, ModelDefinition targetDef, bool isCrossJoin)

src/ServiceStack.OrmLite/Expressions/SqlExpression.cs

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ public abstract partial class SqlExpression<T> : ISqlExpression, IHasUntypedSqlE
2828
protected bool useFieldName = false;
2929
protected bool selectDistinct = false;
3030
protected bool CustomSelect { get; set; }
31+
protected bool SkipParameterizationForThisExpression { get; set; }
32+
private bool visitedExpressionIsTableColumn = false;
3133
private ModelDefinition modelDef;
3234
public bool PrefixFieldWithTableName { get; set; }
3335
public bool WhereStatementWithoutWhereString { get; set; }
@@ -44,6 +46,7 @@ public SqlExpression(IOrmLiteDialectProvider dialectProvider)
4446
modelDef = typeof(T).GetModelDefinition();
4547
PrefixFieldWithTableName = false;
4648
WhereStatementWithoutWhereString = false;
49+
SkipParameterizationForThisExpression = false;
4750
DialectProvider = dialectProvider;
4851
Params = new List<IDbDataParameter>();
4952
tableDefs.Add(modelDef);
@@ -927,6 +930,7 @@ protected internal bool UseFieldName
927930

928931
protected internal virtual object Visit(Expression exp)
929932
{
933+
visitedExpressionIsTableColumn = false;
930934

931935
if (exp == null) return string.Empty;
932936
switch (exp.NodeType)
@@ -987,6 +991,14 @@ protected internal virtual object Visit(Expression exp)
987991
}
988992
}
989993

994+
protected internal virtual object VisitJoin(Expression exp)
995+
{
996+
SkipParameterizationForThisExpression = true;
997+
var visitedExpression = Visit(exp);
998+
SkipParameterizationForThisExpression = false;
999+
return visitedExpression;
1000+
}
1001+
9901002
protected virtual object VisitLambda(LambdaExpression lambda)
9911003
{
9921004
if (lambda.Body.NodeType == ExpressionType.MemberAccess && sep == " ")
@@ -1005,6 +1017,8 @@ protected virtual object VisitLambda(LambdaExpression lambda)
10051017

10061018
protected virtual object VisitBinary(BinaryExpression b)
10071019
{
1020+
var skipParameterizationForThisVisit = false;
1021+
10081022
object left, right;
10091023
var operand = BindOperant(b.NodeType); //sep= " " ??
10101024
if (operand == "AND" || operand == "OR")
@@ -1026,7 +1040,7 @@ protected virtual object VisitBinary(BinaryExpression b)
10261040
if (left as PartialSqlString == null && right as PartialSqlString == null)
10271041
{
10281042
var result = Expression.Lambda(b).Compile().DynamicInvoke();
1029-
return new PartialSqlString(DialectProvider.GetQuotedValue(result, result.GetType()));
1043+
return result;
10301044
}
10311045

10321046
if (left as PartialSqlString == null)
@@ -1039,8 +1053,15 @@ protected virtual object VisitBinary(BinaryExpression b)
10391053
left = Visit(b.Left);
10401054
right = Visit(b.Right);
10411055

1056+
if (visitedExpressionIsTableColumn || (right is DateTimeOffset))
1057+
skipParameterizationForThisVisit = true;
1058+
10421059
var leftEnum = left as EnumMemberAccess;
10431060
var rightEnum = right as EnumMemberAccess;
1061+
1062+
if (leftEnum != null && rightEnum != null)
1063+
skipParameterizationForThisVisit = true;
1064+
10441065
var rightNeedsCoercing = leftEnum != null && rightEnum == null;
10451066
var leftNeedsCoercing = rightEnum != null && leftEnum == null;
10461067

@@ -1049,7 +1070,9 @@ protected virtual object VisitBinary(BinaryExpression b)
10491070
var rightPartialSql = right as PartialSqlString;
10501071
if (rightPartialSql == null)
10511072
{
1052-
right = DialectProvider.GetQuotedValue(right, leftEnum.EnumType);
1073+
right = SkipParameterizationForThisExpression
1074+
? DialectProvider.GetQuotedValue(right, leftEnum.EnumType)
1075+
: DialectProvider.GetValue(right, leftEnum.EnumType);
10531076
}
10541077
}
10551078
else if (leftNeedsCoercing)
@@ -1068,13 +1091,21 @@ protected virtual object VisitBinary(BinaryExpression b)
10681091
else if (left as PartialSqlString == null)
10691092
left = DialectProvider.GetQuotedValue(left, left != null ? left.GetType() : null);
10701093
else if (right as PartialSqlString == null)
1071-
right = DialectProvider.GetQuotedValue(right, right != null ? right.GetType() : null);
1072-
1094+
{
1095+
right = SkipParameterizationForThisExpression
1096+
? DialectProvider.GetQuotedValue(right, right != null ? right.GetType() : null)
1097+
: DialectProvider.GetValue(right, right != null ? right.GetType() : null);
1098+
}
10731099
}
10741100

10751101
if (operand == "=" && right.ToString().Equals("null", StringComparison.OrdinalIgnoreCase)) operand = "is";
10761102
else if (operand == "<>" && right.ToString().Equals("null", StringComparison.OrdinalIgnoreCase)) operand = "is not";
10771103

1104+
if (operand == "AND" || operand == "OR" || operand == "is" || operand == "is not")
1105+
skipParameterizationForThisVisit = true;
1106+
1107+
DoParameterization(skipParameterizationForThisVisit, ref right, b.Right);
1108+
10781109
switch (operand)
10791110
{
10801111
case "MOD":
@@ -1085,6 +1116,17 @@ protected virtual object VisitBinary(BinaryExpression b)
10851116
}
10861117
}
10871118

1119+
private void DoParameterization(bool skipParameterizationForThisVisit, ref object right, Expression rightExpression)
1120+
{
1121+
if (skipParameterizationForThisVisit)
1122+
return;
1123+
1124+
if (SkipParameterizationForThisExpression)
1125+
return;
1126+
1127+
ConvertToPlaceholderAndParameter(ref right, rightExpression);
1128+
}
1129+
10881130
protected virtual object VisitMemberAccess(MemberExpression m)
10891131
{
10901132
if (m.Expression != null
@@ -1103,6 +1145,9 @@ protected virtual object VisitMemberAccess(MemberExpression m)
11031145
}
11041146

11051147
var tableDef = modelType.GetModelDefinition();
1148+
if (tableDef != null)
1149+
visitedExpressionIsTableColumn = true;
1150+
11061151
if (propertyInfo.PropertyType.IsEnum)
11071152
return new EnumMemberAccess(
11081153
GetQuotedColumnName(tableDef, m.Member.Name), propertyInfo.PropertyType);
@@ -1609,14 +1654,20 @@ protected virtual object VisitColumnAccessMethod(MethodCallExpression m)
16091654
return new PartialSqlString(statement);
16101655
}
16111656

1657+
protected virtual void ConvertToPlaceholderAndParameter(ref object right, Expression rightExpression)
1658+
{
1659+
}
1660+
16121661
public IDbDataParameter CreateParam(string name,
16131662
object value = null,
16141663
ParameterDirection direction = ParameterDirection.Input,
1615-
DbType? dbType = null)
1664+
DbType? dbType = null,
1665+
DataRowVersion sourceVersion = DataRowVersion.Default)
16161666
{
16171667
var p = new OrmLiteDataParameter {
16181668
ParameterName = DialectProvider.GetParam(name),
1619-
Direction = direction
1669+
Direction = direction,
1670+
SourceVersion = sourceVersion
16201671
};
16211672
if (value != null)
16221673
{
@@ -1631,7 +1682,6 @@ public IDbDataParameter CreateParam(string name,
16311682

16321683
return p;
16331684
}
1634-
16351685
public IUntypedSqlExpression GetUntyped()
16361686
{
16371687
return new UntypedSqlExpressionProxy<T>(this);

0 commit comments

Comments
 (0)