Skip to content
This repository was archived by the owner on Dec 24, 2022. It is now read-only.

Commit 73751a5

Browse files
committed
Add support for EnumFlags in SqlExpression
1 parent 1ae18e1 commit 73751a5

File tree

7 files changed

+52
-29
lines changed

7 files changed

+52
-29
lines changed

src/ServiceStack.OrmLite.Oracle/OracleOrmLiteDialectProvider.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,7 @@ public override void PrepareParameterizedInsertStatement<T>(IDbCommand dbCommand
358358
public override void SetParameterValues<T>(IDbCommand dbCmd, object obj)
359359
{
360360
var modelDef = GetModel(typeof(T));
361-
var fieldMap = modelDef.GetFieldDefinitionMap(SanitizeFieldNameForParamName);
361+
var fieldMap = GetFieldDefinitionMap(modelDef);
362362

363363
foreach (IDataParameter p in dbCmd.Parameters)
364364
{

src/ServiceStack.OrmLite/Expressions/SqlExpression.cs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -911,15 +911,15 @@ protected virtual object VisitBinary(BinaryExpression b)
911911
var rightPartialSql = right as PartialSqlString;
912912
if (rightPartialSql == null)
913913
{
914-
right = ConvertToEnum(leftEnum.EnumType, right.ToString(), right);
914+
right = DialectProvider.GetQuotedValue(right, leftEnum.EnumType);
915915
}
916916
}
917917
else if (leftNeedsCoercing)
918918
{
919919
var leftPartialSql = left as PartialSqlString;
920920
if (leftPartialSql == null)
921921
{
922-
left = ConvertToEnum(rightEnum.EnumType, left.ToString(), left);
922+
left = DialectProvider.GetQuotedValue(left, rightEnum.EnumType);
923923
}
924924
}
925925
else if (left as PartialSqlString == null && right as PartialSqlString == null)
@@ -947,16 +947,6 @@ protected virtual object VisitBinary(BinaryExpression b)
947947
}
948948
}
949949

950-
private string ConvertToEnum(Type enumType, string enumStr, object otherExpr)
951-
{
952-
//enum value was returned by Visit(b.Right)
953-
long numvericVal;
954-
var result = Int64.TryParse(enumStr, out numvericVal)
955-
? DialectProvider.GetQuotedValue(Enum.ToObject(enumType, numvericVal).ToString(), typeof(string))
956-
: DialectProvider.GetQuotedValue(otherExpr, otherExpr.GetType());
957-
return result;
958-
}
959-
960950
protected virtual object VisitMemberAccess(MemberExpression m)
961951
{
962952
if (m.Expression != null

src/ServiceStack.OrmLite/IOrmLiteDialectProvider.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ string GetColumnDefinition(
8686

8787
void SetParameterValues<T>(IDbCommand dbCmd, object obj);
8888

89+
Dictionary<string, FieldDefinition> GetFieldDefinitionMap(ModelDefinition modelDef);
90+
91+
object GetFieldValue(FieldDefinition fieldDef, object value);
92+
8993
string ToUpdateRowStatement(object objWithProperties, ICollection<string> UpdateFields = null);
9094

9195
string ToDeleteRowStatement(object objWithProperties);

src/ServiceStack.OrmLite/ModelDefinition.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public Dictionary<string, FieldDefinition> GetFieldDefinitionMap(Func<string, st
9898
{
9999
if (fieldDefinitionMap == null || fieldNameSanitizer != sanitizeFieldName)
100100
{
101-
fieldDefinitionMap = new Dictionary<string, FieldDefinition>();
101+
fieldDefinitionMap = new Dictionary<string, FieldDefinition>(StringComparer.OrdinalIgnoreCase);
102102
fieldNameSanitizer = sanitizeFieldName;
103103
foreach (var fieldDef in FieldDefinitionsArray)
104104
{

src/ServiceStack.OrmLite/OrmLiteConfigExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ internal static ModelDefinition GetModelDefinition(this Type modelType)
114114
: propertyInfo.PropertyType;
115115

116116
Type treatAsType = null;
117-
if (propertyType.IsEnum && propertyType.HasAttribute<FlagsAttribute>())
117+
if (propertyType.IsEnumFlags())
118118
{
119119
treatAsType = Enum.GetUnderlyingType(propertyType);
120120
}

src/ServiceStack.OrmLite/OrmLiteDialectProviderBase.cs

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -709,7 +709,7 @@ public virtual void SetParameter(FieldDefinition fieldDef, IDbDataParameter p)
709709
public virtual void SetParameterValues<T>(IDbCommand dbCmd, object obj)
710710
{
711711
var modelDef = GetModel(typeof(T));
712-
var fieldMap = modelDef.GetFieldDefinitionMap(SanitizeFieldNameForParamName);
712+
var fieldMap = GetFieldDefinitionMap(modelDef);
713713

714714
foreach (IDataParameter p in dbCmd.Parameters)
715715
{
@@ -724,6 +724,11 @@ public virtual void SetParameterValues<T>(IDbCommand dbCmd, object obj)
724724
}
725725
}
726726

727+
public Dictionary<string, FieldDefinition> GetFieldDefinitionMap(ModelDefinition modelDef)
728+
{
729+
return modelDef.GetFieldDefinitionMap(SanitizeFieldNameForParamName);
730+
}
731+
727732
public virtual void SetParameterValue<T>(FieldDefinition fieldDef, IDataParameter p, object obj)
728733
{
729734
var value = GetValueOrDbNull<T>(fieldDef, obj);
@@ -736,12 +741,17 @@ protected virtual object GetValue<T>(FieldDefinition fieldDef, object obj)
736741
? fieldDef.GetValue(obj)
737742
: GetAnonValue<T>(fieldDef, obj);
738743

744+
return GetFieldValue(fieldDef, value);
745+
}
746+
747+
public object GetFieldValue(FieldDefinition fieldDef, object value)
748+
{
739749
if (value != null)
740750
{
741751
if (fieldDef.IsRefType)
742752
{
743753
//Let ADO.NET providers handle byte[]
744-
if (fieldDef.FieldType == typeof(byte[]))
754+
if (fieldDef.FieldType == typeof (byte[]))
745755
{
746756
return value;
747757
}
@@ -754,9 +764,9 @@ protected virtual object GetValue<T>(FieldDefinition fieldDef, object obj)
754764
? enumValue.Trim('"')
755765
: null;
756766
}
757-
if (fieldDef.FieldType == typeof(TimeSpan))
767+
if (fieldDef.FieldType == typeof (TimeSpan))
758768
{
759-
var timespan = (TimeSpan)value;
769+
var timespan = (TimeSpan) value;
760770
return timespan.Ticks;
761771
}
762772
}
@@ -1338,6 +1348,22 @@ public virtual string GetQuotedValue(object value, Type fieldType)
13381348
return dialectProvider.GetQuotedValue(dialectProvider.StringSerializer.SerializeToString(value));
13391349
}
13401350

1351+
if (fieldType.IsEnum)
1352+
{
1353+
var isEnumFlags = fieldType.IsEnumFlags();
1354+
long enumValue;
1355+
if (!isEnumFlags && Int64.TryParse(value.ToString(), out enumValue))
1356+
{
1357+
value = Enum.ToObject(fieldType, enumValue).ToString();
1358+
}
1359+
1360+
var enumString = dialectProvider.StringSerializer.SerializeToString(value);
1361+
1362+
return !isEnumFlags
1363+
? dialectProvider.GetQuotedValue(enumString.Trim('"'))
1364+
: enumString;
1365+
}
1366+
13411367
var typeCode = fieldType.GetTypeCode();
13421368
switch (typeCode)
13431369
{
@@ -1363,14 +1389,6 @@ public virtual string GetQuotedValue(object value, Type fieldType)
13631389

13641390
if (fieldType == typeof(TimeSpan))
13651391
return ((TimeSpan)value).Ticks.ToString(CultureInfo.InvariantCulture);
1366-
1367-
if (fieldType.IsEnum)
1368-
{
1369-
var enumValue = dialectProvider.StringSerializer.SerializeToString(value);
1370-
return enumValue != null
1371-
? dialectProvider.GetQuotedValue(enumValue.Trim('"'))
1372-
: null;
1373-
}
13741392

13751393
return ShouldQuoteValue(fieldType)
13761394
? dialectProvider.GetQuotedValue(value.ToString())

src/ServiceStack.OrmLite/OrmLiteReadExtensions.cs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,17 +138,28 @@ internal static void SetParameters<T>(this IDbCommand dbCmd, object anonType, bo
138138
dbCmd.Parameters.Clear();
139139
lastQueryType = null;
140140

141+
var dialectProvider = OrmLiteConfig.DialectProvider;
142+
var fieldMap = typeof(T).IsUserType() //Ensure T != Scalar<int>()
143+
? dialectProvider.GetFieldDefinitionMap(typeof(T).GetModelDefinition())
144+
: null;
145+
141146
anonType.ForEachParam<T>(excludeDefaults, (pi, columnName, value) =>
142147
{
143148
var p = dbCmd.CreateParameter();
144149
p.ParameterName = columnName;
145-
p.DbType = OrmLiteConfig.DialectProvider.GetColumnDbType(pi.PropertyType);
150+
p.DbType = dialectProvider.GetColumnDbType(pi.PropertyType);
146151
p.Direction = ParameterDirection.Input;
152+
153+
FieldDefinition fieldDef;
154+
if (fieldMap != null && fieldMap.TryGetValue(columnName, out fieldDef))
155+
value = dialectProvider.GetFieldValue(fieldDef, value);
156+
147157
p.Value = value == null ?
148158
DBNull.Value
149159
: p.DbType == DbType.String ?
150160
value.ToString() :
151161
value;
162+
152163
dbCmd.Parameters.Add(p);
153164
});
154165
}
@@ -429,7 +440,7 @@ internal static List<T> SelectNonDefaults<T>(this IDbCommand dbCmd, object filte
429440

430441
internal static List<T> SelectNonDefaults<T>(this IDbCommand dbCmd, string sql, object anonType = null)
431442
{
432-
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults: false);
443+
if (anonType != null) dbCmd.SetParameters<T>(anonType, excludeDefaults:true);
433444

434445
return dbCmd.ConvertToList<T>(OrmLiteConfig.DialectProvider.ToSelectStatement(typeof(T), sql));
435446
}

0 commit comments

Comments
 (0)