Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 41 additions & 35 deletions rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpParser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6308,7 +6308,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy
// Skip the . in the member binding (it's part of ?.)
var nullSafe = AdvancePastDotWithNullSafe(memberBinding.OperatorToken);

// Parse the method name with NullSafe marker
// Parse the method name NullSafe marker goes on the MethodInvocation
Identifier name;
JContainer<Expression>? typeParameters = null;

Expand All @@ -6319,7 +6319,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy
name = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([nullSafe]),
Markers.Empty,
[],
genericName.Identifier.Text,
null,
Expand All @@ -6334,7 +6334,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy
name = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([nullSafe]),
Markers.Empty,
[],
memberBinding.Name.Identifier.Text,
null,
Expand All @@ -6347,7 +6347,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy
return new MethodInvocation(
Guid.NewGuid(),
prefix,
Markers.Empty,
Markers.Build([nullSafe]),
new JRightPadded<Expression>(targetExpr, operatorSpace, Markers.Empty),
name,
typeParameters,
Expand Down Expand Up @@ -6433,7 +6433,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy
var name = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([nullSafe]),
Markers.Empty,
[],
memberBinding.Name.Identifier.Text,
null,
Expand All @@ -6443,7 +6443,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy
return new FieldAccess(
Guid.NewGuid(),
prefix,
Markers.Empty,
Markers.Build([nullSafe]),
targetExpr,
new JLeftPadded<Identifier>(operatorSpace, name),
null
Expand Down Expand Up @@ -6502,7 +6502,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy
firstName = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([innerNullSafe]),
Markers.Empty,
[],
genericName.Identifier.Text,
null,
Expand All @@ -6517,7 +6517,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy
firstName = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([innerNullSafe]),
Markers.Empty,
[],
innerMemberBinding.Name.Identifier.Text,
null,
Expand All @@ -6530,7 +6530,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy
firstAccess = new MethodInvocation(
Guid.NewGuid(),
Space.Empty,
Markers.Empty,
Markers.Build([innerNullSafe]),
new JRightPadded<Expression>(targetExpr, operatorSpace, Markers.Empty),
firstName,
firstTypeParams,
Expand Down Expand Up @@ -6595,19 +6595,21 @@ private ArrayAccess ParseNullConditionalElementAccess(Space prefix, Expression t
var closeBracketSpace = ExtractSpaceBefore(argList.CloseBracketToken);
_cursor = argList.CloseBracketToken.Span.End;

// Combine operatorSpace (before ?) with bracketPrefix (before [) into dimension prefix
// The NullSafe marker on dimension tells printer to print ?[
var combinedPrefix = CombineSpaces(operatorSpace, bracketPrefix);
// operatorSpace = space before ?, bracketPrefix = space between ? and [
// NullSafe.DotPrefix stores space between ? and [ (like space between ? and . in MI)
var singleDimNs = bracketPrefix.IsEmpty
? NullSafe.Instance
: new NullSafe(Guid.NewGuid(), bracketPrefix);

return new ArrayAccess(
Guid.NewGuid(),
prefix,
Markers.Empty,
Markers.Build([singleDimNs]),
target,
new ArrayDimension(
Guid.NewGuid(),
combinedPrefix,
Markers.Build([NullSafe.Instance]),
operatorSpace,
Markers.Empty,
new JRightPadded<Expression>(index, closeBracketSpace, Markers.Empty)
),
null
Expand All @@ -6633,18 +6635,20 @@ private ArrayAccess ParseNullConditionalElementAccess(Space prefix, Expression t
var currentSpaceBeforeComma = ExtractSpaceBefore(firstSeparator);
_cursor = firstSeparator.Span.End;

var combinedFirstPrefix = CombineSpaces(operatorSpace, firstBracketPrefix);
var multiDimNs = firstBracketPrefix.IsEmpty
? NullSafe.Instance
: new NullSafe(Guid.NewGuid(), firstBracketPrefix);

// Innermost ArrayAccess with NullSafe marker on dimension
// Innermost ArrayAccess with NullSafe marker on the ArrayAccess
ArrayAccess current = new ArrayAccess(
Guid.NewGuid(),
prefix,
Markers.Empty,
Markers.Build([multiDimNs]),
target,
new ArrayDimension(
Guid.NewGuid(),
combinedFirstPrefix,
Markers.Build([NullSafe.Instance]),
operatorSpace,
Markers.Empty,
new JRightPadded<Expression>(firstIndexExpr, Space.Empty, Markers.Empty)
),
null
Expand Down Expand Up @@ -6741,7 +6745,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target
name = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([nullSafe]),
Markers.Empty,
[],
genericName.Identifier.Text,
null,
Expand All @@ -6756,7 +6760,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target
name = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([nullSafe]),
Markers.Empty,
[],
memberBinding.Name.Identifier.Text,
null,
Expand All @@ -6769,7 +6773,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target
return new MethodInvocation(
Guid.NewGuid(),
prefix,
Markers.Empty,
Markers.Build([nullSafe]),
new JRightPadded<Expression>(target, operatorSpace, Markers.Empty),
name,
typeParameters,
Expand Down Expand Up @@ -6802,7 +6806,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target
var name = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([tbNullSafe]),
Markers.Empty,
[],
terminalBinding.Name.Identifier.Text,
null,
Expand All @@ -6811,7 +6815,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target
return new FieldAccess(
Guid.NewGuid(),
prefix,
Markers.Empty,
Markers.Build([tbNullSafe]),
target,
new JLeftPadded<Identifier>(operatorSpace, name),
null
Expand Down Expand Up @@ -6848,7 +6852,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target
firstName = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([innerNs]),
Markers.Empty,
[],
genericName.Identifier.Text,
null,
Expand All @@ -6863,7 +6867,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target
firstName = new Identifier(
Guid.NewGuid(),
namePrefix,
Markers.Build([innerNs]),
Markers.Empty,
[],
innerMemberBinding.Name.Identifier.Text,
null,
Expand All @@ -6876,7 +6880,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target
firstCall = new MethodInvocation(
Guid.NewGuid(),
Space.Empty,
Markers.Empty,
Markers.Build([innerNs]),
new JRightPadded<Expression>(target, operatorSpace, Markers.Empty),
firstName,
firstTypeParams,
Expand Down Expand Up @@ -6934,7 +6938,7 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat
var namePrefix = ExtractSpaceBefore(genericName.Identifier);
_cursor = genericName.Identifier.Span.End;
name = new Identifier(
Guid.NewGuid(), namePrefix, Markers.Build([mbNs]),
Guid.NewGuid(), namePrefix, Markers.Empty,
[],
genericName.Identifier.Text, null,
null
Expand All @@ -6946,14 +6950,14 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat
var namePrefix = ExtractSpaceBefore(memberBinding.Name.Identifier);
_cursor = memberBinding.Name.Identifier.Span.End;
name = new Identifier(
Guid.NewGuid(), namePrefix, Markers.Build([mbNs]),
Guid.NewGuid(), namePrefix, Markers.Empty,
[],
memberBinding.Name.Identifier.Text, null,
null
);
}
return new FieldAccess(
Guid.NewGuid(), Space.Empty, Markers.Empty,
Guid.NewGuid(), Space.Empty, Markers.Build([mbNs]),
target, new JLeftPadded<Identifier>(operatorSpace, name), null);
}

Expand Down Expand Up @@ -6997,16 +7001,17 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat
JRightPadded<Expression> select;
Identifier name;
JContainer<Expression>? typeParams = null;
NullSafe? nullSafeMarker = null;

if (invocation.Expression is MemberBindingExpressionSyntax binding)
{
var bNs = AdvancePastDotWithNullSafe(binding.OperatorToken);
nullSafeMarker = AdvancePastDotWithNullSafe(binding.OperatorToken);
if (binding.Name is GenericNameSyntax gn)
{
var np = ExtractSpaceBefore(gn.Identifier);
_cursor = gn.Identifier.Span.End;
name = new Identifier(
Guid.NewGuid(), np, Markers.Build([bNs]),
Guid.NewGuid(), np, Markers.Empty,
[],
gn.Identifier.Text, null,
null
Expand All @@ -7018,7 +7023,7 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat
var np = ExtractSpaceBefore(binding.Name.Identifier);
_cursor = binding.Name.Identifier.Span.End;
name = new Identifier(
Guid.NewGuid(), np, Markers.Build([bNs]),
Guid.NewGuid(), np, Markers.Empty,
[],
binding.Name.Identifier.Text, null,
null
Expand Down Expand Up @@ -7064,7 +7069,8 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat

var args = ParseArgumentList(invocation.ArgumentList);
return new MethodInvocation(
Guid.NewGuid(), Space.Empty, Markers.Empty,
Guid.NewGuid(), Space.Empty,
nullSafeMarker != null ? Markers.Build([nullSafeMarker]) : Markers.Empty,
select, name, typeParams, args, null);
}

Expand Down
31 changes: 21 additions & 10 deletions rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpPrinter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -250,8 +250,8 @@ public override J VisitFieldAccess(FieldAccess fieldAccess, PrintOutputCapture<P
// Check for PointerMemberAccess marker on target PointerDereference - if present, print -> instead of .
var isPointerMemberAccess = fieldAccess.Target is PointerDereference pd
&& pd.Markers.FindFirst<PointerMemberAccess>() != null;
// Check for NullSafe marker on the name - if present, print ?. instead of .
var nullSafe = fieldAccess.Name.Element.Markers.FindFirst<NullSafe>();
// Check for NullSafe marker on the FieldAccess - if present, print ?. instead of .
var nullSafe = fieldAccess.Markers.FindFirst<NullSafe>();
if (isPointerMemberAccess)
{
p.Append("->");
Expand Down Expand Up @@ -356,9 +356,13 @@ private void PrintArrayAccessWithoutClosingBracket(ArrayAccess aa, PrintOutputCa
VisitSpace(aa.Prefix, p);
Visit(aa.Indexed, p);
VisitSpace(aa.Dimension.Prefix, p);
// Check for NullSafe marker - if present, print ?[ instead of [
var isNullSafe = aa.Dimension.Markers.FindFirst<NullSafe>() != null;
p.Append(isNullSafe ? "?[" : "[");
var nullSafe = aa.Markers.FindFirst<NullSafe>();
if (nullSafe != null)
{
p.Append('?');
VisitSpace(nullSafe.DotPrefix, p);
}
p.Append('[');
Visit(aa.Dimension.Index.Element, p);
VisitSpace(aa.Dimension.Index.After, p);
// Don't print ] - parent will
Expand All @@ -368,9 +372,16 @@ private void PrintArrayAccessWithoutClosingBracket(ArrayAccess aa, PrintOutputCa
public override J VisitArrayDimension(ArrayDimension dimension, PrintOutputCapture<P> p)
{
BeforeSyntax(dimension, p);
// Check for NullSafe marker - if present, print ?[ instead of [
var isNullSafe = dimension.Markers.FindFirst<NullSafe>() != null;
p.Append(isNullSafe ? "?[" : "[");
// NullSafe marker lives on the parent ArrayAccess — check via cursor
// Dimension prefix holds space before ?, NullSafe.DotPrefix holds space between ? and [
var nullSafe = Cursor.Value is ArrayAccess aa
? aa.Markers.FindFirst<NullSafe>() : null;
if (nullSafe != null)
{
p.Append('?');
VisitSpace(nullSafe.DotPrefix, p);
}
p.Append('[');
Visit(dimension.Index.Element, p);
VisitSpace(dimension.Index.After, p);
p.Append(']');
Expand All @@ -393,9 +404,9 @@ public override J VisitMethodInvocation(MethodInvocation mi, PrintOutputCapture<
// For delegate invocation, skip the dot and name (it's syntactic sugar for .Invoke())
if (!isDelegateInvocation)
{
// Check for NullSafe marker on the name - if present, print ?. instead of .
// Check for NullSafe marker on the MethodInvocation - if present, print ?. instead of .
// Check for PointerMemberAccess marker on select PointerDereference - if present, print -> instead of .
var nullSafe = mi.Name.Markers.FindFirst<NullSafe>();
var nullSafe = mi.Markers.FindFirst<NullSafe>();
var isPointerDeref = mi.Select.Element is PointerDereference selectPd
&& selectPd.Markers.FindFirst<PointerMemberAccess>() != null;
if (isPointerDeref)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,9 @@ public J GetTree()

var comparator = new PatternMatchingComparator(_captures);
var captured = comparator.Match(patternTree, tree, cursor);
return captured != null ? new MatchResult(captured) : null;
if (captured == null) return null;
var nullSafe = comparator.NullSafeBindings;
return new MatchResult(captured, nullSafe.Count > 0 ? nullSafe : null);
}

private bool IsCapturePlaceholder(J node)
Expand Down
12 changes: 11 additions & 1 deletion rewrite-csharp/csharp/OpenRewrite/CSharp/Template/MatchResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@ namespace OpenRewrite.CSharp.Template;
public sealed class MatchResult
{
private readonly Dictionary<string, object> _captures;
private readonly IReadOnlyDictionary<string, NullSafe>? _nullSafeCaptures;

internal MatchResult(Dictionary<string, object> captures)
internal MatchResult(Dictionary<string, object> captures,
IReadOnlyDictionary<string, NullSafe>? nullSafeCaptures = null)
{
_captures = captures;
_nullSafeCaptures = nullSafeCaptures;
}

/// <summary>
Expand Down Expand Up @@ -121,5 +124,12 @@ public IReadOnlyList<T> GetList<T>(string name) where T : class, J
/// </summary>
public bool Has(ICapture capture) => _captures.ContainsKey(capture.Name);

/// <summary>
/// Get the NullSafe marker associated with a capture, if any.
/// Present when the capture was the Select of a null-conditional MI/FA in the matched tree.
/// </summary>
internal NullSafe? GetNullSafe(string name) =>
_nullSafeCaptures != null && _nullSafeCaptures.TryGetValue(name, out var ns) ? ns : null;

internal Dictionary<string, object> AsDict() => _captures;
}
Loading