Skip to content
This repository was archived by the owner on Dec 24, 2022. It is now read-only.

Commit 4119e1a

Browse files
committed
Merge branch 'pr/446'
2 parents e857320 + 36875d9 commit 4119e1a

17 files changed

+323
-56
lines changed

src/ServiceStack.OrmLite.Oracle.Tests/ServiceStack.OrmLite.Oracle.Tests.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,9 @@
148148
<Compile Include="..\..\tests\ServiceStack.OrmLite.Tests\Expression\SelectExpressionTests.cs">
149149
<Link>Expression\SelectExpressionTests.cs</Link>
150150
</Compile>
151+
<Compile Include="..\..\tests\ServiceStack.OrmLite.Tests\Expression\SqlExpressionParamTests.cs">
152+
<Link>Expression\SqlExpressionParamTests.cs</Link>
153+
</Compile>
151154
<Compile Include="..\..\tests\ServiceStack.OrmLite.Tests\Expression\SqlExpressionTests.cs">
152155
<Link>Expression\SqlExpressionTests.cs</Link>
153156
</Compile>

src/ServiceStack.OrmLite.Oracle/OracleOrmLiteDialectProvider.cs

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,18 @@ public static OracleOrmLiteDialectProvider Instance
5959
private readonly DbProviderFactory _factory;
6060
private readonly OracleTimestampConverter _timestampConverter;
6161

62+
public bool ParameterizeStatement;
63+
6264
public OracleOrmLiteDialectProvider()
6365
: this(false, false)
6466
{
6567
}
6668

67-
public OracleOrmLiteDialectProvider(bool compactGuid, bool quoteNames, string clientProvider = OdpProvider)
69+
public OracleOrmLiteDialectProvider(bool compactGuid, bool quoteNames, string clientProvider = OdpProvider, bool parameterizeStatement = false)
6870
{
6971
ClientProvider = clientProvider;
7072
CompactGuid = compactGuid;
73+
ParameterizeStatement = parameterizeStatement;
7174
QuoteNames = quoteNames;
7275
BoolColumnDefinition = "NUMBER(1)";
7376
GuidColumnDefinition = CompactGuid ? CompactGuidDefinition : StringGuidDefinition;
@@ -298,6 +301,73 @@ public override string GetQuotedValue(object value, Type fieldType)
298301
return base.GetQuotedValue(value, fieldType);
299302
}
300303

304+
public override object GetParamValue(object value, Type fieldType)
305+
{
306+
if (!ParameterizeStatement)
307+
return GetQuotedValue(value, fieldType);
308+
309+
if (value == null) return DBNull.Value;
310+
311+
if (fieldType == typeof(Guid))
312+
{
313+
var guid = (Guid)value;
314+
315+
if (CompactGuid)
316+
return guid.ToByteArray();
317+
318+
return guid.ToString();
319+
}
320+
321+
if (fieldType == typeof(DateTimeOffset) || fieldType == typeof(DateTimeOffset?))
322+
{
323+
return GetQuotedDateTimeOffsetValue((DateTimeOffset)value);
324+
}
325+
326+
if ((value is TimeSpan) && (fieldType == typeof(Int64) || fieldType == typeof(Int64?)))
327+
{
328+
var longValue = ((TimeSpan)value).Ticks;
329+
return base.GetQuotedValue(longValue, fieldType);
330+
}
331+
332+
if (fieldType == typeof(TimeSpan))
333+
return ((TimeSpan)value).Ticks;
334+
335+
if (fieldType == typeof(bool?) || fieldType == typeof(bool))
336+
{
337+
var boolValue = (bool)value;
338+
return boolValue ? 1 : 0;
339+
}
340+
341+
if (fieldType.IsEnum)
342+
{
343+
if (value is int && !fieldType.IsEnumFlags())
344+
{
345+
value = fieldType.GetEnumName(value);
346+
}
347+
348+
var enumValue = StringSerializer.SerializeToString(value);
349+
// Oracle stores empty strings in varchar columns as null so match that behavior here
350+
if (enumValue == null)
351+
return null;
352+
enumValue = enumValue.Trim('"');
353+
return enumValue == ""
354+
? "null"
355+
: enumValue;
356+
}
357+
358+
if (fieldType == typeof(byte[]))
359+
{
360+
return "hextoraw('" + BitConverter.ToString((byte[])value).Replace("-", "") + "')";
361+
}
362+
363+
if (fieldType.IsRefType())
364+
{
365+
return StringSerializer.SerializeToString(value);
366+
}
367+
368+
return value;
369+
}
370+
301371
const string IsoDateFormat = "yyyy-MM-dd";
302372
const string IsoTimeFormat = "HH:mm:ss";
303373
const string IsoMillisecondFormat = "fffffff";

src/ServiceStack.OrmLite.Oracle/OracleSqlExpression.cs

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
namespace ServiceStack.OrmLite.Oracle
66
{
7-
public class OracleSqlExpression<T> : SqlExpression<T>
7+
public class OracleSqlExpression<T> : ParameterizedSqlExpression<T>
88
{
99
public OracleSqlExpression(IOrmLiteDialectProvider dialectProvider)
1010
: base(dialectProvider) {}
@@ -19,18 +19,29 @@ protected override object VisitColumnAccessMethod(MethodCallExpression m)
1919
if (args.Count == 2)
2020
{
2121
var length = Int32.Parse(args[1].ToString());
22-
return new PartialSqlString(string.Format("subStr({0},{1},{2})",
23-
quotedColName,
24-
startIndex,
25-
length));
22+
return new PartialSqlString(string.Format(
23+
"subStr({0},{1},{2})", quotedColName, startIndex, length));
2624
}
2725

28-
return new PartialSqlString(string.Format("subStr({0},{1})",
29-
quotedColName,
30-
startIndex));
26+
return new PartialSqlString(string.Format(
27+
"subStr({0},{1})", quotedColName, startIndex));
3128
}
3229
return base.VisitColumnAccessMethod(m);
3330
}
31+
32+
protected override void ConvertToPlaceholderAndParameter(ref object right)
33+
{
34+
if (!((OracleOrmLiteDialectProvider)DialectProvider).ParameterizeStatement)
35+
return;
36+
37+
var paramName = Params.Count.ToString();
38+
var paramValue = right;
39+
40+
var parameter = CreateParam(paramName, paramValue);
41+
Params.Add(parameter);
42+
43+
right = parameter.ParameterName;
44+
}
3445
}
3546
}
3647

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
using System;
2+
using System.Linq.Expressions;
3+
4+
namespace ServiceStack.OrmLite
5+
{
6+
public abstract partial class ParameterizedSqlExpression<T> : SqlExpression<T>
7+
{
8+
protected bool visitedExpressionIsTableColumn = false;
9+
protected bool SkipParameterizationForThisExpression { get; set; }
10+
11+
protected ParameterizedSqlExpression(IOrmLiteDialectProvider dialectProvider)
12+
: base(dialectProvider)
13+
{
14+
SkipParameterizationForThisExpression = false;
15+
}
16+
17+
protected internal override object Visit(Expression exp)
18+
{
19+
visitedExpressionIsTableColumn = false;
20+
return base.Visit(exp);
21+
}
22+
23+
protected internal override object VisitJoin(Expression exp)
24+
{
25+
SkipParameterizationForThisExpression = true;
26+
var visitedExpression = Visit(exp);
27+
SkipParameterizationForThisExpression = false;
28+
return visitedExpression;
29+
}
30+
31+
protected virtual void ConvertToPlaceholderAndParameter(ref object right)
32+
{
33+
}
34+
35+
public override object GetValue(object value, Type type)
36+
{
37+
return SkipParameterizationForThisExpression
38+
? DialectProvider.GetQuotedValue(value, type)
39+
: DialectProvider.GetParamValue(value, type);
40+
}
41+
42+
protected override void VisitFilter(string operand, object originalLeft, object originalRight, ref object left, ref object right)
43+
{
44+
if (SkipParameterizationForThisExpression)
45+
return;
46+
47+
if (visitedExpressionIsTableColumn || (originalRight is DateTimeOffset))
48+
return;
49+
50+
var leftEnum = originalLeft as EnumMemberAccess;
51+
var rightEnum = originalRight as EnumMemberAccess;
52+
53+
if (leftEnum != null && rightEnum != null)
54+
return;
55+
56+
if (operand == "AND" || operand == "OR" || operand == "is" || operand == "is not")
57+
return;
58+
59+
ConvertToPlaceholderAndParameter(ref right);
60+
}
61+
62+
protected virtual void OnVisitMemberType(Type modelType)
63+
{
64+
var tableDef = modelType.GetModelDefinition();
65+
if (tableDef != null)
66+
visitedExpressionIsTableColumn = true;
67+
}
68+
}
69+
}

src/ServiceStack.OrmLite/Expressions/ReadExpressionCommandExtensions.cs

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,21 @@ internal static List<T> Select<T>(this IDbCommand dbCmd, Func<SqlExpression<T>,
1414
var expr = dbCmd.GetDialectProvider().SqlExpression<T>();
1515
string sql = expression(expr).SelectInto<T>();
1616

17-
return dbCmd.ExprConvertToList<T>(sql);
17+
return dbCmd.ExprConvertToList<T>(sql, expr.Params);
1818
}
1919

2020
internal static List<Into> Select<Into, From>(this IDbCommand dbCmd, Func<SqlExpression<From>, SqlExpression<From>> expression)
2121
{
2222
var expr = dbCmd.GetDialectProvider().SqlExpression<From>();
2323
string sql = expression(expr).SelectInto<Into>();
2424

25-
return dbCmd.ExprConvertToList<Into>(sql);
25+
return dbCmd.ExprConvertToList<Into>(sql, expr.Params);
2626
}
2727

2828
internal static List<Into> Select<Into, From>(this IDbCommand dbCmd, SqlExpression<From> expression)
2929
{
3030
string sql = expression.SelectInto<Into>();
31-
return dbCmd.ExprConvertToList<Into>(sql);
31+
return dbCmd.ExprConvertToList<Into>(sql, expression.Params);
3232
}
3333

3434
internal static List<T> Select<T>(this IDbCommand dbCmd, SqlExpression<T> expression)
@@ -43,7 +43,7 @@ internal static List<T> Select<T>(this IDbCommand dbCmd, Expression<Func<T, bool
4343
var expr = dbCmd.GetDialectProvider().SqlExpression<T>();
4444
string sql = expr.Where(predicate).SelectInto<T>();
4545

46-
return dbCmd.ExprConvertToList<T>(sql);
46+
return dbCmd.ExprConvertToList<T>(sql, expr.Params);
4747
}
4848

4949
internal static T Single<T>(this IDbCommand dbCmd, Func<SqlExpression<T>, SqlExpression<T>> expression)
@@ -63,7 +63,7 @@ internal static T Single<T>(this IDbCommand dbCmd, SqlExpression<T> expression)
6363
{
6464
string sql = expression.Limit(1).SelectInto<T>();
6565

66-
return dbCmd.ExprConvertTo<T>(sql);
66+
return dbCmd.ExprConvertTo<T>(sql, expression.Params);
6767
}
6868

6969
public static TKey Scalar<T, TKey>(this IDbCommand dbCmd, Expression<Func<T, TKey>> field)
@@ -80,7 +80,7 @@ internal static TKey Scalar<T, TKey>(this IDbCommand dbCmd,
8080
var ev = dbCmd.GetDialectProvider().SqlExpression<T>();
8181
ev.Select(field).Where(predicate);
8282
string sql = ev.SelectInto<T>();
83-
return dbCmd.Scalar<TKey>(sql);
83+
return dbCmd.Scalar<TKey>(sql, ev.Params);
8484
}
8585

8686
internal static long Count<T>(this IDbCommand dbCmd)
@@ -108,12 +108,12 @@ internal static long Count<T>(this IDbCommand dbCmd, Expression<Func<T, bool>> p
108108
var ev = dbCmd.GetDialectProvider().SqlExpression<T>();
109109
ev.Where(predicate);
110110
var sql = ev.ToCountStatement();
111-
return GetCount(dbCmd, sql);
111+
return GetCount(dbCmd, sql, ev.Params);
112112
}
113113

114-
internal static long GetCount(this IDbCommand dbCmd, string sql)
114+
internal static long GetCount(this IDbCommand dbCmd, string sql, IEnumerable<IDbDataParameter> sqlParams = null)
115115
{
116-
return dbCmd.Column<long>(sql).Sum();
116+
return dbCmd.Column<long>(sql, sqlParams).Sum();
117117
}
118118

119119
internal static long RowCount<T>(this IDbCommand dbCmd, SqlExpression<T> expression)

src/ServiceStack.OrmLite/Expressions/SqlExpression.Join.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ private SqlExpression<T> InternalJoin<Source, Target>(string joinType,
8686

8787
private string InternalCreateSqlFromExpression(Expression joinExpr, bool isCrossJoin)
8888
{
89-
return "{0} {1}".Fmt((isCrossJoin ? "WHERE" : "ON"), Visit(joinExpr).ToString());
89+
return "{0} {1}".Fmt((isCrossJoin ? "WHERE" : "ON"), VisitJoin(joinExpr).ToString());
9090
}
9191

9292
private string InternalCreateSqlFromDefinitions(ModelDefinition sourceDef, ModelDefinition targetDef, bool isCrossJoin)

0 commit comments

Comments
 (0)