Skip to content

Commit 4f805f0

Browse files
authored
Fix handling of enum default values in RDF and RDG (#63086)
* Fix handling of enum default values in RDF and RDG * Feedback
1 parent 52b1c18 commit 4f805f0

File tree

3 files changed

+131
-55
lines changed

3 files changed

+131
-55
lines changed

src/Http/Http.Extensions/src/RequestDelegateFactory.cs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1801,7 +1801,7 @@ private static Expression BindParameterFromValue(ParameterInfo parameter, Expres
18011801
? Expression.IfThen(TempSourceStringNotNullExpr, tryParseExpression)
18021802
: Expression.IfThenElse(TempSourceStringNotNullExpr, tryParseExpression,
18031803
Expression.Assign(argument,
1804-
Expression.Constant(parameter.DefaultValue, parameter.ParameterType)));
1804+
CreateDefaultValueExpression(parameter.DefaultValue, parameter.ParameterType)));
18051805

18061806
var loopExit = Expression.Label();
18071807

@@ -1963,7 +1963,7 @@ private static Expression BindParameterFromExpression(
19631963
return Expression.Block(
19641964
Expression.Condition(Expression.NotEqual(valueExpression, Expression.Constant(null)),
19651965
valueExpression,
1966-
Expression.Convert(Expression.Constant(parameter.DefaultValue), parameter.ParameterType)));
1966+
Expression.Convert(CreateDefaultValueExpression(parameter.DefaultValue, parameter.ParameterType), parameter.ParameterType)));
19671967
}
19681968

19691969
private static Expression BindParameterFromProperty(ParameterInfo parameter, MemberExpression property, PropertyInfo itemProperty, string key, RequestDelegateFactoryContext factoryContext, string source)
@@ -1981,6 +1981,34 @@ private static Expression BindParameterFromProperty(ParameterInfo parameter, Mem
19811981
type == typeof(StringValues?) ? typeof(StringValues?) :
19821982
null;
19831983

1984+
private static Expression CreateDefaultValueExpression(object? defaultValue, Type parameterType)
1985+
{
1986+
if (defaultValue is null)
1987+
{
1988+
return Expression.Constant(null, parameterType);
1989+
}
1990+
1991+
var underlyingType = Nullable.GetUnderlyingType(parameterType);
1992+
var isNullable = underlyingType != null;
1993+
var targetType = isNullable ? underlyingType! : parameterType;
1994+
var converted = defaultValue;
1995+
1996+
// Apply a conversion for scenarios where the default value's type
1997+
// doesn't match the parameter type
1998+
if (targetType.IsEnum && defaultValue.GetType() != targetType)
1999+
{
2000+
converted = Enum.ToObject(targetType, defaultValue);
2001+
}
2002+
else if (!targetType.IsAssignableFrom(defaultValue.GetType()))
2003+
{
2004+
converted = Convert.ChangeType(defaultValue, targetType, CultureInfo.InvariantCulture);
2005+
}
2006+
2007+
var constant = Expression.Constant(converted, targetType);
2008+
// Cast nullable types as needed
2009+
return isNullable ? Expression.Convert(constant, parameterType) : constant;
2010+
}
2011+
19842012
private static Expression BindParameterFromRouteValueOrQueryString(ParameterInfo parameter, string key, RequestDelegateFactoryContext factoryContext)
19852013
{
19862014
var routeValue = GetValueFromProperty(RouteValuesExpr, RouteValuesIndexerProperty, key);
@@ -2358,7 +2386,7 @@ private static Expression BindParameterFromBody(ParameterInfo parameter, bool al
23582386
{
23592387
// Convert(bodyValue ?? SomeDefault, Todo)
23602388
return Expression.Convert(
2361-
Expression.Coalesce(BodyValueExpr, Expression.Constant(parameter.DefaultValue)),
2389+
Expression.Coalesce(BodyValueExpr, CreateDefaultValueExpression(parameter.DefaultValue, typeof(object))),
23622390
parameter.ParameterType);
23632391
}
23642392

src/Http/Http.Extensions/test/RequestDelegateGenerator/RequestDelegateCreationTests.SpecialTypes.cs

Lines changed: 60 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -154,56 +154,66 @@ public static object[][] DefaultValues
154154
{
155155
get
156156
{
157-
return new[]
158-
{
159-
new object[] { "string?", "default", default(string), true },
160-
new object[] { "string", "\"test\"", "test", true },
161-
new object[] { "string", "\"a\" + \"b\"", "ab", true },
162-
new object[] { "DateOnly?", "default", default(DateOnly?), false },
163-
new object[] { "bool", "default", default(bool), true },
164-
new object[] { "bool", "false", false, true },
165-
new object[] { "bool", "true", true, true},
166-
new object[] { "System.Threading.CancellationToken", "default", default(CancellationToken), false },
167-
new object[] { "Todo?", "default", default(Todo), false },
168-
new object[] { "char", "\'a\'", 'a', true },
169-
new object[] { "int", "default", 0, true },
170-
new object[] { "int", "1234", 1234, true },
171-
new object[] { "int", "1234 * 4", 1234 * 4, true },
172-
new object[] { "double", "1.0", 1.0, true },
173-
new object[] { "double", "double.NaN", double.NaN, true },
174-
new object[] { "double", "double.PositiveInfinity", double.PositiveInfinity, true },
175-
new object[] { "double", "double.NegativeInfinity", double.NegativeInfinity, true },
176-
new object[] { "double", "double.E", double.E, true },
177-
new object[] { "double", "double.Epsilon", double.Epsilon, true },
178-
new object[] { "double", "double.NegativeZero", double.NegativeZero, true },
179-
new object[] { "double", "double.MaxValue", double.MaxValue, true },
180-
new object[] { "double", "double.MinValue", double.MinValue, true },
181-
new object[] { "double", "double.Pi", double.Pi, true },
182-
new object[] { "double", "double.Tau", double.Tau, true },
183-
new object[] { "float", "float.NaN", float.NaN, true },
184-
new object[] { "float", "float.PositiveInfinity", float.PositiveInfinity, true },
185-
new object[] { "float", "float.NegativeInfinity", float.NegativeInfinity, true },
186-
new object[] { "float", "float.E", float.E, true },
187-
new object[] { "float", "float.Epsilon", float.Epsilon, true },
188-
new object[] { "float", "float.NegativeZero", float.NegativeZero, true },
189-
new object[] { "float", "float.MaxValue", float.MaxValue, true },
190-
new object[] { "float", "float.MinValue", float.MinValue, true },
191-
new object[] { "float", "float.Pi", float.Pi, true },
192-
new object[] { "float", "float.Tau", float.Tau, true },
193-
new object[] {"decimal", "decimal.MaxValue", decimal.MaxValue, true },
194-
new object[] {"decimal", "decimal.MinusOne", decimal.MinusOne, true },
195-
new object[] {"decimal", "decimal.MinValue", decimal.MinValue, true },
196-
new object[] {"decimal", "decimal.One", decimal.One, true },
197-
new object[] {"decimal", "decimal.Zero", decimal.Zero, true },
198-
new object[] {"long", "long.MaxValue", long.MaxValue, true },
199-
new object[] {"long", "long.MinValue", long.MinValue, true },
200-
new object[] {"short", "short.MaxValue", short.MaxValue, true },
201-
new object[] {"short", "short.MinValue", short.MinValue, true },
202-
new object[] {"ulong", "ulong.MaxValue", ulong.MaxValue, true },
203-
new object[] {"ulong", "ulong.MinValue", ulong.MinValue, true },
204-
new object[] {"ushort", "ushort.MaxValue", ushort.MaxValue, true },
205-
new object[] {"ushort", "ushort.MinValue", ushort.MinValue, true },
206-
};
157+
return
158+
[
159+
["string?", "default", default(string), true],
160+
["string", "\"test\"", "test", true],
161+
["string", "\"a\" + \"b\"", "ab", true],
162+
["DateOnly?", "default", default(DateOnly?), false],
163+
["bool", "default", default(bool), true],
164+
["bool", "false", false, true],
165+
["bool", "true", true, true],
166+
["System.Threading.CancellationToken", "default", default(CancellationToken), false],
167+
["Todo?", "default", default(Todo), false],
168+
["char", "\'a\'", 'a', true],
169+
["int", "default", 0, true],
170+
["int", "1234", 1234, true],
171+
["int", "1234 * 4", 1234 * 4, true],
172+
["double", "1.0", 1.0, true],
173+
["double", "double.NaN", double.NaN, true],
174+
["double", "double.PositiveInfinity", double.PositiveInfinity, true],
175+
["double", "double.NegativeInfinity", double.NegativeInfinity, true],
176+
["double", "double.E", double.E, true],
177+
["double", "double.Epsilon", double.Epsilon, true],
178+
["double", "double.NegativeZero", double.NegativeZero, true],
179+
["double", "double.MaxValue", double.MaxValue, true],
180+
["double", "double.MinValue", double.MinValue, true],
181+
["double", "double.Pi", double.Pi, true],
182+
["double", "double.Tau", double.Tau, true],
183+
["float", "float.NaN", float.NaN, true],
184+
["float", "float.PositiveInfinity", float.PositiveInfinity, true],
185+
["float", "float.NegativeInfinity", float.NegativeInfinity, true],
186+
["float", "float.E", float.E, true],
187+
["float", "float.Epsilon", float.Epsilon, true],
188+
["float", "float.NegativeZero", float.NegativeZero, true],
189+
["float", "float.MaxValue", float.MaxValue, true],
190+
["float", "float.MinValue", float.MinValue, true],
191+
["float", "float.Pi", float.Pi, true],
192+
["float", "float.Tau", float.Tau, true],
193+
["decimal", "decimal.MaxValue", decimal.MaxValue, true],
194+
["decimal", "decimal.MinusOne", decimal.MinusOne, true],
195+
["decimal", "decimal.MinValue", decimal.MinValue, true],
196+
["decimal", "decimal.One", decimal.One, true],
197+
["decimal", "decimal.Zero", decimal.Zero, true],
198+
["long", "long.MaxValue", long.MaxValue, true],
199+
["long", "long.MinValue", long.MinValue, true],
200+
["long", "(long)3.14", (long)3, true],
201+
["short", "short.MaxValue", short.MaxValue, true],
202+
["short", "short.MinValue", short.MinValue, true],
203+
["ulong", "ulong.MaxValue", ulong.MaxValue, true],
204+
["ulong", "ulong.MinValue", ulong.MinValue, true],
205+
["ushort", "ushort.MaxValue", ushort.MaxValue, true],
206+
["ushort", "ushort.MinValue", ushort.MinValue, true],
207+
["TodoStatus", "TodoStatus.Done", TodoStatus.Done, true],
208+
["TodoStatus", "TodoStatus.InProgress", TodoStatus.InProgress, true],
209+
["TodoStatus", "TodoStatus.NotDone", TodoStatus.NotDone, true],
210+
["TodoStatus", "(TodoStatus)1", TodoStatus.Done, true],
211+
["MyEnum", "MyEnum.ValueA", MyEnum.ValueA, true],
212+
["MyEnum", "MyEnum.ValueB", MyEnum.ValueB, true],
213+
// Test nullable enum values
214+
["TodoStatus?", "TodoStatus.Done", (TodoStatus?)TodoStatus.Done, false],
215+
["TodoStatus?", "default", default(TodoStatus?), false]
216+
];
207217
}
208218
}
209219

src/Shared/RoslynUtils/SymbolExtensions.cs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,27 @@ public static string GetDefaultValueString(this IParameterSymbol parameterSymbol
177177
{
178178
return !parameterSymbol.HasExplicitDefaultValue
179179
? "null"
180-
: InnerGetDefaultValueString(parameterSymbol.ExplicitDefaultValue);
180+
: InnerGetDefaultValueString(parameterSymbol.ExplicitDefaultValue, parameterSymbol.Type);
181181
}
182182

183-
private static string InnerGetDefaultValueString(object? defaultValue)
183+
private static string InnerGetDefaultValueString(object? defaultValue, ITypeSymbol parameterType)
184184
{
185+
// Handle enum types with proper casting
186+
if (IsEnumType(parameterType, out var enumType))
187+
{
188+
return $"({enumType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}){SymbolDisplay.FormatPrimitive(defaultValue!, false, false)}";
189+
}
190+
191+
// Handle nullable enum types
192+
if (IsNullableEnumType(parameterType, out var underlyingEnumType))
193+
{
194+
if (defaultValue == null)
195+
{
196+
return "default";
197+
}
198+
return $"({underlyingEnumType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)}){SymbolDisplay.FormatPrimitive(defaultValue!, false, false)}";
199+
}
200+
185201
return defaultValue switch
186202
{
187203
string s => SymbolDisplay.FormatLiteral(s, true),
@@ -227,4 +243,26 @@ public static string GetParameterInfoFromConstructorCode(this IParameterSymbol p
227243
}
228244
return "null";
229245
}
246+
247+
private static bool IsEnumType(ITypeSymbol typeSymbol, out ITypeSymbol enumType)
248+
{
249+
enumType = typeSymbol;
250+
return typeSymbol.TypeKind == TypeKind.Enum;
251+
}
252+
253+
private static bool IsNullableEnumType(ITypeSymbol typeSymbol, [NotNullWhen(true)] out ITypeSymbol? underlyingEnumType)
254+
{
255+
underlyingEnumType = null;
256+
if (typeSymbol.OriginalDefinition?.SpecialType == SpecialType.System_Nullable_T &&
257+
typeSymbol is INamedTypeSymbol namedType)
258+
{
259+
var underlyingType = namedType.TypeArguments.FirstOrDefault();
260+
if (underlyingType?.TypeKind == TypeKind.Enum)
261+
{
262+
underlyingEnumType = underlyingType;
263+
return true;
264+
}
265+
}
266+
return false;
267+
}
230268
}

0 commit comments

Comments
 (0)