Skip to content

Commit 6174e0b

Browse files
authored
Fix handling if dictionary inferrence in source generator. (#8962)
1 parent 8b56bc5 commit 6174e0b

File tree

3 files changed

+159
-4
lines changed

3 files changed

+159
-4
lines changed

src/HotChocolate/Core/src/Types.Analyzers/Helpers/TypeReferenceBuilder.cs

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -115,9 +115,8 @@ private static (string TypeStructure, string TypeDefinition, bool IsSimpleType)
115115
isNullable = false;
116116
}
117117

118-
if (underlyingType is INamedTypeSymbol namedType && IsListType(namedType))
118+
if (underlyingType is INamedTypeSymbol namedType && TryGetListElementType(namedType, out var listElementType))
119119
{
120-
var listElementType = namedType.TypeArguments[0];
121120
var (typeStructure, typeDefinition, _) = CreateTypeKey(listElementType);
122121

123122
if (isNullable)
@@ -240,10 +239,13 @@ private static bool IsArrayType(ITypeSymbol namedType, [NotNullWhen(true)] out I
240239
return false;
241240
}
242241

243-
private static bool IsListType(INamedTypeSymbol namedType)
242+
private static bool TryGetListElementType(
243+
INamedTypeSymbol namedType,
244+
[NotNullWhen(true)] out ITypeSymbol? elementType)
244245
{
245246
if (!namedType.IsGenericType)
246247
{
248+
elementType = null;
247249
return false;
248250
}
249251

@@ -254,10 +256,13 @@ private static bool IsListType(INamedTypeSymbol namedType)
254256
if (WellKnownTypes.ListInterfaceTypes.Contains(typeDefinition)
255257
|| WellKnownTypes.ListClassTypes.Contains(typeDefinition))
256258
{
259+
elementType = namedType.TypeArguments[0];
257260
return true;
258261
}
259262

260-
// Check if the type implements any of the known list interfaces
263+
// Check if the type implements any of the known list interfaces.
264+
// This handles cases like Dictionary<K,V> which implements IEnumerable<KeyValuePair<K,V>>.
265+
// We extract the element type from the interface, not from the type's own type arguments.
261266
foreach (var interfaceType in namedType.AllInterfaces)
262267
{
263268
if (!interfaceType.IsGenericType)
@@ -268,6 +273,7 @@ private static bool IsListType(INamedTypeSymbol namedType)
268273
var interfaceDefinition = GetGenericTypeDefinition(interfaceType.OriginalDefinition);
269274
if (WellKnownTypes.ListInterfaceTypes.Contains(interfaceDefinition))
270275
{
276+
elementType = interfaceType.TypeArguments[0];
271277
return true;
272278
}
273279
}
@@ -285,12 +291,14 @@ private static bool IsListType(INamedTypeSymbol namedType)
285291
var baseDefinition = GetGenericTypeDefinition(currentType.OriginalDefinition);
286292
if (WellKnownTypes.ListClassTypes.Contains(baseDefinition))
287293
{
294+
elementType = currentType.TypeArguments[0];
288295
return true;
289296
}
290297

291298
currentType = currentType.BaseType;
292299
}
293300

301+
elementType = null;
294302
return false;
295303
}
296304

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
namespace HotChocolate.Types;
2+
3+
public class CollectionInferenceTests
4+
{
5+
[Fact]
6+
public async Task Infer_Dictionary_As_List_Of_KeyValuePair()
7+
{
8+
await TestHelper.GetGeneratedSourceSnapshot(
9+
"""
10+
using System;
11+
using System.Collections.Generic;
12+
using System.Threading;
13+
using System.Threading.Tasks;
14+
using HotChocolate;
15+
using HotChocolate.Types;
16+
17+
namespace TestNamespace;
18+
19+
[QueryType]
20+
internal static partial class Query
21+
{
22+
public static Task<Dictionary<int, string?>> GetStuffAsync(
23+
CancellationToken cancellationToken)
24+
=> default;
25+
}
26+
""").MatchMarkdownAsync();
27+
}
28+
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Infer_Dictionary_As_List_Of_KeyValuePair
2+
3+
## HotChocolateTypeModule.735550c.g.cs
4+
5+
```csharp
6+
// <auto-generated/>
7+
8+
#nullable enable
9+
#pragma warning disable
10+
11+
using System;
12+
using System.Runtime.CompilerServices;
13+
using HotChocolate;
14+
using HotChocolate.Types;
15+
using HotChocolate.Execution.Configuration;
16+
17+
namespace Microsoft.Extensions.DependencyInjection
18+
{
19+
public static partial class TestsTypesRequestExecutorBuilderExtensions
20+
{
21+
public static IRequestExecutorBuilder AddTestsTypes(this IRequestExecutorBuilder builder)
22+
{
23+
builder.ConfigureDescriptorContext(ctx => ctx.TypeConfiguration.TryAdd(
24+
"Tests::TestNamespace.Query",
25+
global::HotChocolate.Types.OperationTypeNames.Query,
26+
() => global::TestNamespace.Query.Initialize));
27+
builder.ConfigureSchema(
28+
b => b.TryAddRootType(
29+
() => new global::HotChocolate.Types.ObjectType(
30+
d => d.Name(global::HotChocolate.Types.OperationTypeNames.Query)),
31+
HotChocolate.Language.OperationType.Query));
32+
return builder;
33+
}
34+
}
35+
}
36+
37+
```
38+
39+
## Query.WaAdMHmlGJHjtEI4nqY7WA.hc.g.cs
40+
41+
```csharp
42+
// <auto-generated/>
43+
44+
#nullable enable
45+
#pragma warning disable
46+
47+
using System;
48+
using System.Runtime.CompilerServices;
49+
using HotChocolate;
50+
using HotChocolate.Types;
51+
using HotChocolate.Execution.Configuration;
52+
using Microsoft.Extensions.DependencyInjection;
53+
using HotChocolate.Internal;
54+
55+
namespace TestNamespace
56+
{
57+
internal static partial class Query
58+
{
59+
internal static void Initialize(global::HotChocolate.Types.IObjectTypeDescriptor descriptor)
60+
{
61+
var extension = descriptor.Extend();
62+
var configuration = extension.Configuration;
63+
var thisType = typeof(global::TestNamespace.Query);
64+
var bindingResolver = extension.Context.ParameterBindingResolver;
65+
var resolvers = new __Resolvers();
66+
67+
HotChocolate.Internal.ConfigurationHelper.ApplyConfiguration(
68+
extension.Context,
69+
descriptor,
70+
null,
71+
new global::HotChocolate.Types.QueryTypeAttribute());
72+
configuration.ConfigurationsAreApplied = true;
73+
74+
var naming = descriptor.Extend().Context.Naming;
75+
76+
descriptor
77+
.Field(naming.GetMemberName("Stuff", global::HotChocolate.Types.MemberKind.ObjectField))
78+
.ExtendWith(static (field, context) =>
79+
{
80+
var configuration = field.Configuration;
81+
var typeInspector = field.Context.TypeInspector;
82+
var bindingResolver = field.Context.ParameterBindingResolver;
83+
var naming = field.Context.Naming;
84+
85+
configuration.Type = global::HotChocolate.Types.Descriptors.TypeReference.Create(
86+
typeInspector.GetTypeRef(typeof(global::System.Collections.Generic.KeyValuePair<int, string>), HotChocolate.Types.TypeContext.Output),
87+
new global::HotChocolate.Language.NonNullTypeNode(new global::HotChocolate.Language.ListTypeNode(new global::HotChocolate.Language.NonNullTypeNode(new global::HotChocolate.Language.NamedTypeNode("global__System_Collections_Generic_KeyValuePairOfintAndstring")))));
88+
configuration.ResultType = typeof(global::System.Collections.Generic.Dictionary<int, string?>);
89+
90+
configuration.SetSourceGeneratorFlags();
91+
92+
configuration.Resolvers = context.Resolvers.GetStuffAsync();
93+
configuration.ResultPostProcessor = global::HotChocolate.Execution.ListPostProcessor<global::System.Collections.Generic.KeyValuePair<int, string>>.Default;
94+
},
95+
(Resolvers: resolvers, ThisType: thisType));
96+
97+
Configure(descriptor);
98+
}
99+
100+
static partial void Configure(global::HotChocolate.Types.IObjectTypeDescriptor descriptor);
101+
102+
private sealed class __Resolvers
103+
{
104+
public HotChocolate.Resolvers.FieldResolverDelegates GetStuffAsync()
105+
=> new global::HotChocolate.Resolvers.FieldResolverDelegates(resolver: GetStuffAsync);
106+
107+
private async global::System.Threading.Tasks.ValueTask<global::System.Object?> GetStuffAsync(global::HotChocolate.Resolvers.IResolverContext context)
108+
{
109+
var args0 = context.RequestAborted;
110+
var result = await global::TestNamespace.Query.GetStuffAsync(args0);
111+
return result;
112+
}
113+
}
114+
}
115+
}
116+
117+
118+
```
119+

0 commit comments

Comments
 (0)