Skip to content

Fixed type nullability within DataLoader`s #8482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ public void WriteDataLoaderInterface(
kind is DataLoaderKind.Group
? ": global::GreenDonut.IDataLoader<{0}, {1}[]>"
: ": global::GreenDonut.IDataLoader<{0}, {1}>",
key.ToFullyQualified(),
value.ToFullyQualified());
key.ToFullyQualifiedWithNullRefQualifier(),
value.ToFullyQualifiedWithNullRefQualifier());

_writer.DecreaseIndent();
_writer.WriteIndentedLine("{");
Expand Down Expand Up @@ -97,8 +97,8 @@ public void WriteBeginDataLoaderClass(
kind is DataLoaderKind.Group
? ": global::GreenDonut.DataLoaderBase<{0}, {1}[]>"
: ": global::GreenDonut.DataLoaderBase<{0}, {1}>",
key.ToFullyQualified(),
value.ToFullyQualified());
key.ToFullyQualifiedWithNullRefQualifier(),
value.ToFullyQualifiedWithNullRefQualifier());
if (withInterface)
{
_writer.WriteIndentedLine(", {0}", interfaceName);
Expand Down Expand Up @@ -174,23 +174,20 @@ public void WriteDataLoaderConstructor(
if (lookup.IsTransform)
{
_writer.WriteIndentedLine(
".Create<{0}, {1}{2}, {3}{4}>({5}.{6}, this)",
keyType.ToFullyQualified(),
valueType.ToFullyQualified(),
valueType.PrintNullRefQualifier(),
lookup.Method.Parameters[0].Type.ToFullyQualified(),
lookup.Method.Parameters[0].Type.PrintNullRefQualifier(),
lookup.Method.ContainingType.ToFullyQualified(),
".Create<{0}, {1}, {2}>({3}.{4}, this)",
keyType.ToFullyQualifiedWithNullRefQualifier(),
valueType.ToFullyQualifiedWithNullRefQualifier(),
lookup.Method.Parameters[0].Type.ToFullyQualifiedWithNullRefQualifier(),
lookup.Method.ContainingType.ToFullyQualifiedWithNullRefQualifier(),
lookup.Method.Name);
}
else
{
_writer.WriteIndentedLine(
".Create<{0}, {1}{2}>({3}.{4}, this)",
keyType.ToFullyQualified(),
valueType.ToFullyQualified(),
valueType.PrintNullRefQualifier(),
lookup.Method.ContainingType.ToFullyQualified(),
".Create<{0}, {1}>({2}.{3}, this)",
keyType.ToFullyQualifiedWithNullRefQualifier(),
valueType.ToFullyQualifiedWithNullRefQualifier(),
lookup.Method.ContainingType.ToFullyQualifiedWithNullRefQualifier(),
lookup.Method.Name);
}

Expand Down Expand Up @@ -221,18 +218,18 @@ public void WriteDataLoaderLoadMethod(
_writer.WriteIndentedLine(
"global::{0}<{1}> keys,",
WellKnownTypes.ReadOnlyList,
key.ToFullyQualified());
key.ToFullyQualifiedWithNullRefQualifier());
_writer.WriteIndentedLine(
"global::{0}<{1}<{2}{3}{4}>> results,",
"global::{0}<{1}<{2}>> results,",
WellKnownTypes.Memory,
WellKnownTypes.Result,
value.ToFullyQualified(),
kind is DataLoaderKind.Group ? "[]" : string.Empty,
value.IsValueType ? string.Empty : "?");
kind is DataLoaderKind.Group
? $"{value.ToClassNonNullableFullyQualifiedWithNullRefQualifier()}[]?"
: value.ToNullableFullyQualifiedWithNullRefQualifier());
_writer.WriteIndentedLine(
"global::{0}<{1}{2}> context,",
WellKnownTypes.DataLoaderFetchContext,
value.ToFullyQualified(),
value.ToFullyQualifiedWithNullRefQualifier(),
kind is DataLoaderKind.Group ? "[]" : string.Empty);
_writer.WriteIndentedLine(
"global::{0} ct)",
Expand All @@ -256,14 +253,14 @@ public void WriteDataLoaderLoadMethod(
"var {0} = {1}.GetRequiredService<{2}>();",
parameter.VariableName,
isScoped ? "scope.ServiceProvider" : "_services",
parameter.Type.ToFullyQualified());
parameter.Type.ToFullyQualifiedWithNullRefQualifier());
}
else if (parameter.Kind is DataLoaderParameterKind.SelectorBuilder)
{
_writer.WriteIndentedLine(
"var {0} = context.GetState<{1}>(\"{2}\")",
parameter.VariableName,
parameter.Type.ToFullyQualified(),
parameter.Type.ToFullyQualifiedWithNullRefQualifier(),
parameter.StateKey);
_writer.IncreaseIndent();
_writer.WriteIndentedLine(
Expand All @@ -275,7 +272,7 @@ public void WriteDataLoaderLoadMethod(
_writer.WriteIndentedLine(
"var {0} = context.GetState<{1}>(\"{2}\")",
parameter.VariableName,
parameter.Type.ToFullyQualified(),
parameter.Type.ToFullyQualifiedWithNullRefQualifier(),
parameter.StateKey);
_writer.IncreaseIndent();
_writer.WriteIndentedLine(
Expand All @@ -287,12 +284,12 @@ public void WriteDataLoaderLoadMethod(
_writer.WriteIndentedLine(
"var {0} = context.GetState<{1}>(\"{2}\")",
parameter.VariableName,
parameter.Type.ToFullyQualified(),
parameter.Type.ToFullyQualifiedWithNullRefQualifier(),
parameter.StateKey);
_writer.IncreaseIndent();
_writer.WriteIndentedLine(
"?? {0}.Empty;",
parameter.Type.ToFullyQualified());
parameter.Type.ToFullyQualifiedWithNullRefQualifier());
_writer.DecreaseIndent();
}
else if (parameter.Kind is DataLoaderParameterKind.QueryContext)
Expand All @@ -302,26 +299,26 @@ public void WriteDataLoaderLoadMethod(
parameter.VariableName,
WellKnownTypes.SelectorBuilder,
DataLoaderInfo.Selector,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified());
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualifiedWithNullRefQualifier());
_writer.WriteIndentedLine(
"var {0}_predicate = context.GetState<global::{1}>(\"{2}\")?.TryCompile<{3}>();",
parameter.VariableName,
WellKnownTypes.PredicateBuilder,
DataLoaderInfo.Predicate,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified());
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualifiedWithNullRefQualifier());
_writer.WriteIndentedLine(
"var {0}_sortDefinition = context.GetState<global::{1}<{2}>>(\"{3}\");",
parameter.VariableName,
WellKnownTypes.SortDefinition,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified(),
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualifiedWithNullRefQualifier(),
DataLoaderInfo.Sorting);
_writer.WriteLine();

_writer.WriteIndentedLine(
"var {0} = global::{1}<{2}>.Empty;",
parameter.VariableName,
WellKnownTypes.QueryContext,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified());
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualifiedWithNullRefQualifier());
_writer.WriteIndentedLine(
"if({0}_selector is not null || {0}_predicate is not null || {0}_sortDefinition is not null)",
parameter.VariableName);
Expand All @@ -332,7 +329,7 @@ public void WriteDataLoaderLoadMethod(
+ "{0}_predicate, {0}_sortDefinition);",
parameter.VariableName,
WellKnownTypes.QueryContext,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified());
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualifiedWithNullRefQualifier());
_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
}
Expand All @@ -341,7 +338,7 @@ public void WriteDataLoaderLoadMethod(
_writer.WriteIndentedLine(
"var {0} = context.GetRequiredState<{1}>(\"{2}\");",
parameter.VariableName,
parameter.Type.ToFullyQualified(),
parameter.Type.ToFullyQualifiedWithNullRefQualifier(),
parameter.StateKey);
}
else if (parameter.Kind is DataLoaderParameterKind.ContextData)
Expand All @@ -352,28 +349,26 @@ public void WriteDataLoaderLoadMethod(
var defaultValueString = ConvertDefaultValueToString(defaultValue, parameter.Type);

_writer.WriteIndentedLine(
"var {0} = context.GetStateOrDefault<{1}{2}>(\"{3}\", {4});",
"var {0} = context.GetStateOrDefault<{1}>(\"{2}\", {3});",
parameter.VariableName,
parameter.Type.ToFullyQualified(),
parameter.Type.PrintNullRefQualifier(),
parameter.Type.ToFullyQualifiedWithNullRefQualifier(),
parameter.StateKey,
defaultValueString);
}
else if (parameter.Type.IsNullableType())
{
_writer.WriteIndentedLine(
"var {0} = context.GetState<{1}{2}>(\"{3}\");",
"var {0} = context.GetState<{1}>(\"{2}\");",
parameter.VariableName,
parameter.Type.ToFullyQualified(),
parameter.Type.PrintNullRefQualifier(),
parameter.Type.ToFullyQualifiedWithNullRefQualifier(),
parameter.StateKey);
}
else
{
_writer.WriteIndentedLine(
"var {0} = context.GetRequiredState<{1}>(\"{2}\");",
parameter.VariableName,
parameter.Type.ToFullyQualified(),
parameter.Type.ToFullyQualifiedWithNullRefQualifier(),
parameter.StateKey);
}
}
Expand All @@ -395,9 +390,8 @@ public void WriteDataLoaderLoadMethod(
_writer.WriteIndented("var value = ");
WriteFetchCall(method, containingType, kind, parameters);
_writer.WriteIndentedLine(
"results.Span[i] = Result<{0}{1}>.Resolve(value);",
value.ToFullyQualified(),
value.IsValueType ? string.Empty : "?");
"results.Span[i] = Result<{0}>.Resolve(value);",
value.ToNullableFullyQualifiedWithNullRefQualifier());
}

_writer.WriteIndentedLine("}");
Expand All @@ -407,9 +401,8 @@ public void WriteDataLoaderLoadMethod(
using (_writer.IncreaseIndent())
{
_writer.WriteIndentedLine(
"results.Span[i] = Result<{0}{1}>.Reject(ex);",
value.ToFullyQualified(),
value.IsValueType ? string.Empty : "?");
"results.Span[i] = Result<{0}>.Reject(ex);",
value.ToNullableFullyQualifiedWithNullRefQualifier());
}

_writer.WriteIndentedLine("}");
Expand Down Expand Up @@ -439,14 +432,12 @@ public void WriteDataLoaderLoadMethod(
_writer.WriteIndentedLine(
"global::{0}<{1}> keys,",
WellKnownTypes.ReadOnlyList,
key.ToFullyQualified());
key.ToFullyQualifiedWithNullRefQualifier());
_writer.WriteIndentedLine(
"global::{0}<{1}<{2}{3}{4}>> results,",
"global::{0}<{1}<{2}>> results,",
WellKnownTypes.Span,
WellKnownTypes.Result,
value.ToFullyQualified(),
kind is DataLoaderKind.Group ? "[]" : string.Empty,
value.IsValueType ? string.Empty : "?");
kind is DataLoaderKind.Group ? $"{value.ToClassNonNullableFullyQualifiedWithNullRefQualifier()}[]?" : value.ToNullableFullyQualifiedWithNullRefQualifier());
_writer.WriteIndentedLine(
"{0} resultMap)",
ExtractMapType(method.ReturnType).ToFullyQualifiedWithNullRefQualifier());
Expand Down Expand Up @@ -501,10 +492,9 @@ public void WriteDataLoaderLoadMethod(
using (_writer.IncreaseIndent())
{
_writer.WriteIndentedLine(
"results[i] = global::{0}<{1}{2}>.Resolve(value);",
"results[i] = global::{0}<{1}>.Resolve(value);",
WellKnownTypes.Result,
value.ToFullyQualified(),
value.IsValueType ? string.Empty : "?");
value.ToNullableFullyQualifiedWithNullRefQualifier());
}

_writer.WriteIndentedLine("}");
Expand All @@ -514,11 +504,10 @@ public void WriteDataLoaderLoadMethod(
using (_writer.IncreaseIndent())
{
_writer.WriteIndentedLine(
"results[i] = global::{0}<{1}{2}>.Resolve(default({3}));",
"results[i] = global::{0}<{1}>.Resolve(default({2}));",
WellKnownTypes.Result,
value.ToFullyQualified(),
value.IsValueType ? string.Empty : "?",
value.ToFullyQualified());
value.ToNullableFullyQualifiedWithNullRefQualifier(),
value.IsValueType ? value.ToNullableFullyQualifiedWithNullRefQualifier() : value.ToFullyQualified());
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a special handling to avoid changing all snapshots. Using value.ToNullableFullyQualifiedWithNullRefQualifier() would also always be correct.
In fact, we could also emit a null literal: value.IsNullableType() ? "null" : $"default({value.ToFullyQualified()})"

}

_writer.WriteIndentedLine("}");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ public static string ToFullyQualifiedWithNullRefQualifier(this ITypeSymbol typeS
return typeSymbol.ToDisplayString(format);
}

public static string ToNullableFullyQualifiedWithNullRefQualifier(this ITypeSymbol typeSymbol)
{
var value = typeSymbol.ToFullyQualifiedWithNullRefQualifier();
return value[value.Length - 1] != '?' ? value + "?" : value;
}

public static string ToClassNonNullableFullyQualifiedWithNullRefQualifier(this ITypeSymbol typeSymbol)
{
var value = typeSymbol.ToFullyQualifiedWithNullRefQualifier();
return !typeSymbol.IsValueType && value[value.Length - 1] == '?' ? value.Substring(0, value.Length - 1) : value;
}

public static bool IsParent(this IParameterSymbol parameter)
=> parameter.IsThis
|| parameter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,4 +698,95 @@ public static Task<IDictionary<int, string>> GetEntityByIdAsync(
}
""").MatchMarkdownAsync();
}

[Fact]
public async Task Generate_DataLoader_NullableAnnotated_AnonymousType_AsKey_MatchesSnapshot()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using HotChocolate;
using GreenDonut;

namespace TestNamespace;

public record Id1();
public record Id2();
public record Stuff();

public class Dataloaders
{
[DataLoader]
public static async Task<ILookup<(Id1, Id2?), Stuff>> GetStuff(
IReadOnlyList<(Id1, Id2?)> keys,
CancellationToken cancellationToken)
{
await Task.CompletedTask;
return null!;
}
}
""").MatchMarkdownAsync();
}

[Fact]
public async Task GenerateSource_BatchDataLoader_ReturnsNullableStruct_MatchesSnapshot()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
using System.Collections.Generic;
using System.Threading;
using System.Threading.Tasks;
using HotChocolate;
using GreenDonut;

namespace TestNamespace;

internal static class TestClass
{
[DataLoader]
public static Task<IReadOnlyDictionary<int, Entity?>> GetEntityByIdAsync(
IReadOnlyList<int> entityIds,
CancellationToken cancellationToken)
=> default!;
}

public struct Entity
{
public int Id { get; set; }
}
""").MatchMarkdownAsync();
}

[Fact]
public async Task GenerateSource_GroupedDataLoader_ReturnsNullableStruct_MatchesSnapshot()
{
await TestHelper.GetGeneratedSourceSnapshot(
"""
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using HotChocolate;
using GreenDonut;

namespace TestNamespace;

internal static class TestClass
{
[DataLoader]
public static Task<ILookup<int, Entity?>> GetEntitiesByIdAsync(
IReadOnlyList<int> entityIds,
CancellationToken cancellationToken)
=> default!;
}

public struct Entity
{
public int Id { get; set; }
}
""").MatchMarkdownAsync();
}
}
Loading