diff --git a/rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpParser.cs b/rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpParser.cs index 8e752706aa..b2e05db429 100644 --- a/rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpParser.cs +++ b/rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpParser.cs @@ -6308,7 +6308,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy // Skip the . in the member binding (it's part of ?.) var nullSafe = AdvancePastDotWithNullSafe(memberBinding.OperatorToken); - // Parse the method name with NullSafe marker + // Parse the method name — NullSafe marker goes on the MethodInvocation Identifier name; JContainer? typeParameters = null; @@ -6319,7 +6319,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy name = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([nullSafe]), + Markers.Empty, [], genericName.Identifier.Text, null, @@ -6334,7 +6334,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy name = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([nullSafe]), + Markers.Empty, [], memberBinding.Name.Identifier.Text, null, @@ -6347,7 +6347,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy return new MethodInvocation( Guid.NewGuid(), prefix, - Markers.Empty, + Markers.Build([nullSafe]), new JRightPadded(targetExpr, operatorSpace, Markers.Empty), name, typeParameters, @@ -6433,7 +6433,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy var name = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([nullSafe]), + Markers.Empty, [], memberBinding.Name.Identifier.Text, null, @@ -6443,7 +6443,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy return new FieldAccess( Guid.NewGuid(), prefix, - Markers.Empty, + Markers.Build([nullSafe]), targetExpr, new JLeftPadded(operatorSpace, name), null @@ -6502,7 +6502,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy firstName = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([innerNullSafe]), + Markers.Empty, [], genericName.Identifier.Text, null, @@ -6517,7 +6517,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy firstName = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([innerNullSafe]), + Markers.Empty, [], innerMemberBinding.Name.Identifier.Text, null, @@ -6530,7 +6530,7 @@ public override J VisitConditionalAccessExpression(ConditionalAccessExpressionSy firstAccess = new MethodInvocation( Guid.NewGuid(), Space.Empty, - Markers.Empty, + Markers.Build([innerNullSafe]), new JRightPadded(targetExpr, operatorSpace, Markers.Empty), firstName, firstTypeParams, @@ -6595,19 +6595,21 @@ private ArrayAccess ParseNullConditionalElementAccess(Space prefix, Expression t var closeBracketSpace = ExtractSpaceBefore(argList.CloseBracketToken); _cursor = argList.CloseBracketToken.Span.End; - // Combine operatorSpace (before ?) with bracketPrefix (before [) into dimension prefix - // The NullSafe marker on dimension tells printer to print ?[ - var combinedPrefix = CombineSpaces(operatorSpace, bracketPrefix); + // operatorSpace = space before ?, bracketPrefix = space between ? and [ + // NullSafe.DotPrefix stores space between ? and [ (like space between ? and . in MI) + var singleDimNs = bracketPrefix.IsEmpty + ? NullSafe.Instance + : new NullSafe(Guid.NewGuid(), bracketPrefix); return new ArrayAccess( Guid.NewGuid(), prefix, - Markers.Empty, + Markers.Build([singleDimNs]), target, new ArrayDimension( Guid.NewGuid(), - combinedPrefix, - Markers.Build([NullSafe.Instance]), + operatorSpace, + Markers.Empty, new JRightPadded(index, closeBracketSpace, Markers.Empty) ), null @@ -6633,18 +6635,20 @@ private ArrayAccess ParseNullConditionalElementAccess(Space prefix, Expression t var currentSpaceBeforeComma = ExtractSpaceBefore(firstSeparator); _cursor = firstSeparator.Span.End; - var combinedFirstPrefix = CombineSpaces(operatorSpace, firstBracketPrefix); + var multiDimNs = firstBracketPrefix.IsEmpty + ? NullSafe.Instance + : new NullSafe(Guid.NewGuid(), firstBracketPrefix); - // Innermost ArrayAccess with NullSafe marker on dimension + // Innermost ArrayAccess with NullSafe marker on the ArrayAccess ArrayAccess current = new ArrayAccess( Guid.NewGuid(), prefix, - Markers.Empty, + Markers.Build([multiDimNs]), target, new ArrayDimension( Guid.NewGuid(), - combinedFirstPrefix, - Markers.Build([NullSafe.Instance]), + operatorSpace, + Markers.Empty, new JRightPadded(firstIndexExpr, Space.Empty, Markers.Empty) ), null @@ -6741,7 +6745,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target name = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([nullSafe]), + Markers.Empty, [], genericName.Identifier.Text, null, @@ -6756,7 +6760,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target name = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([nullSafe]), + Markers.Empty, [], memberBinding.Name.Identifier.Text, null, @@ -6769,7 +6773,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target return new MethodInvocation( Guid.NewGuid(), prefix, - Markers.Empty, + Markers.Build([nullSafe]), new JRightPadded(target, operatorSpace, Markers.Empty), name, typeParameters, @@ -6802,7 +6806,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target var name = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([tbNullSafe]), + Markers.Empty, [], terminalBinding.Name.Identifier.Text, null, @@ -6811,7 +6815,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target return new FieldAccess( Guid.NewGuid(), prefix, - Markers.Empty, + Markers.Build([tbNullSafe]), target, new JLeftPadded(operatorSpace, name), null @@ -6848,7 +6852,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target firstName = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([innerNs]), + Markers.Empty, [], genericName.Identifier.Text, null, @@ -6863,7 +6867,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target firstName = new Identifier( Guid.NewGuid(), namePrefix, - Markers.Build([innerNs]), + Markers.Empty, [], innerMemberBinding.Name.Identifier.Text, null, @@ -6876,7 +6880,7 @@ private Expression ProcessConditionalWhenNotNull(Space prefix, Expression target firstCall = new MethodInvocation( Guid.NewGuid(), Space.Empty, - Markers.Empty, + Markers.Build([innerNs]), new JRightPadded(target, operatorSpace, Markers.Empty), firstName, firstTypeParams, @@ -6934,7 +6938,7 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat var namePrefix = ExtractSpaceBefore(genericName.Identifier); _cursor = genericName.Identifier.Span.End; name = new Identifier( - Guid.NewGuid(), namePrefix, Markers.Build([mbNs]), + Guid.NewGuid(), namePrefix, Markers.Empty, [], genericName.Identifier.Text, null, null @@ -6946,14 +6950,14 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat var namePrefix = ExtractSpaceBefore(memberBinding.Name.Identifier); _cursor = memberBinding.Name.Identifier.Span.End; name = new Identifier( - Guid.NewGuid(), namePrefix, Markers.Build([mbNs]), + Guid.NewGuid(), namePrefix, Markers.Empty, [], memberBinding.Name.Identifier.Text, null, null ); } return new FieldAccess( - Guid.NewGuid(), Space.Empty, Markers.Empty, + Guid.NewGuid(), Space.Empty, Markers.Build([mbNs]), target, new JLeftPadded(operatorSpace, name), null); } @@ -6997,16 +7001,17 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat JRightPadded select; Identifier name; JContainer? typeParams = null; + NullSafe? nullSafeMarker = null; if (invocation.Expression is MemberBindingExpressionSyntax binding) { - var bNs = AdvancePastDotWithNullSafe(binding.OperatorToken); + nullSafeMarker = AdvancePastDotWithNullSafe(binding.OperatorToken); if (binding.Name is GenericNameSyntax gn) { var np = ExtractSpaceBefore(gn.Identifier); _cursor = gn.Identifier.Span.End; name = new Identifier( - Guid.NewGuid(), np, Markers.Build([bNs]), + Guid.NewGuid(), np, Markers.Empty, [], gn.Identifier.Text, null, null @@ -7018,7 +7023,7 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat var np = ExtractSpaceBefore(binding.Name.Identifier); _cursor = binding.Name.Identifier.Span.End; name = new Identifier( - Guid.NewGuid(), np, Markers.Build([bNs]), + Guid.NewGuid(), np, Markers.Empty, [], binding.Name.Identifier.Text, null, null @@ -7064,7 +7069,8 @@ private Expression ParseConditionalAccessSegment(Expression target, Space operat var args = ParseArgumentList(invocation.ArgumentList); return new MethodInvocation( - Guid.NewGuid(), Space.Empty, Markers.Empty, + Guid.NewGuid(), Space.Empty, + nullSafeMarker != null ? Markers.Build([nullSafeMarker]) : Markers.Empty, select, name, typeParams, args, null); } diff --git a/rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpPrinter.cs b/rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpPrinter.cs index 94f8aa9a5d..47decbdf7e 100644 --- a/rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpPrinter.cs +++ b/rewrite-csharp/csharp/OpenRewrite/CSharp/CSharpPrinter.cs @@ -250,8 +250,8 @@ public override J VisitFieldAccess(FieldAccess fieldAccess, PrintOutputCapture

instead of . var isPointerMemberAccess = fieldAccess.Target is PointerDereference pd && pd.Markers.FindFirst() != null; - // Check for NullSafe marker on the name - if present, print ?. instead of . - var nullSafe = fieldAccess.Name.Element.Markers.FindFirst(); + // Check for NullSafe marker on the FieldAccess - if present, print ?. instead of . + var nullSafe = fieldAccess.Markers.FindFirst(); if (isPointerMemberAccess) { p.Append("->"); @@ -356,9 +356,13 @@ private void PrintArrayAccessWithoutClosingBracket(ArrayAccess aa, PrintOutputCa VisitSpace(aa.Prefix, p); Visit(aa.Indexed, p); VisitSpace(aa.Dimension.Prefix, p); - // Check for NullSafe marker - if present, print ?[ instead of [ - var isNullSafe = aa.Dimension.Markers.FindFirst() != null; - p.Append(isNullSafe ? "?[" : "["); + var nullSafe = aa.Markers.FindFirst(); + if (nullSafe != null) + { + p.Append('?'); + VisitSpace(nullSafe.DotPrefix, p); + } + p.Append('['); Visit(aa.Dimension.Index.Element, p); VisitSpace(aa.Dimension.Index.After, p); // Don't print ] - parent will @@ -368,9 +372,16 @@ private void PrintArrayAccessWithoutClosingBracket(ArrayAccess aa, PrintOutputCa public override J VisitArrayDimension(ArrayDimension dimension, PrintOutputCapture

p) { BeforeSyntax(dimension, p); - // Check for NullSafe marker - if present, print ?[ instead of [ - var isNullSafe = dimension.Markers.FindFirst() != null; - p.Append(isNullSafe ? "?[" : "["); + // NullSafe marker lives on the parent ArrayAccess — check via cursor + // Dimension prefix holds space before ?, NullSafe.DotPrefix holds space between ? and [ + var nullSafe = Cursor.Value is ArrayAccess aa + ? aa.Markers.FindFirst() : null; + if (nullSafe != null) + { + p.Append('?'); + VisitSpace(nullSafe.DotPrefix, p); + } + p.Append('['); Visit(dimension.Index.Element, p); VisitSpace(dimension.Index.After, p); p.Append(']'); @@ -393,9 +404,9 @@ public override J VisitMethodInvocation(MethodInvocation mi, PrintOutputCapture< // For delegate invocation, skip the dot and name (it's syntactic sugar for .Invoke()) if (!isDelegateInvocation) { - // Check for NullSafe marker on the name - if present, print ?. instead of . + // Check for NullSafe marker on the MethodInvocation - if present, print ?. instead of . // Check for PointerMemberAccess marker on select PointerDereference - if present, print -> instead of . - var nullSafe = mi.Name.Markers.FindFirst(); + var nullSafe = mi.Markers.FindFirst(); var isPointerDeref = mi.Select.Element is PointerDereference selectPd && selectPd.Markers.FindFirst() != null; if (isPointerDeref) diff --git a/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/CSharpPattern.cs b/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/CSharpPattern.cs index b7081d973e..68d439a969 100644 --- a/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/CSharpPattern.cs +++ b/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/CSharpPattern.cs @@ -177,7 +177,9 @@ public J GetTree() var comparator = new PatternMatchingComparator(_captures); var captured = comparator.Match(patternTree, tree, cursor); - return captured != null ? new MatchResult(captured) : null; + if (captured == null) return null; + var nullSafe = comparator.NullSafeBindings; + return new MatchResult(captured, nullSafe.Count > 0 ? nullSafe : null); } private bool IsCapturePlaceholder(J node) diff --git a/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/MatchResult.cs b/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/MatchResult.cs index a825d653d1..30616f9b6b 100644 --- a/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/MatchResult.cs +++ b/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/MatchResult.cs @@ -24,10 +24,13 @@ namespace OpenRewrite.CSharp.Template; public sealed class MatchResult { private readonly Dictionary _captures; + private readonly IReadOnlyDictionary? _nullSafeCaptures; - internal MatchResult(Dictionary captures) + internal MatchResult(Dictionary captures, + IReadOnlyDictionary? nullSafeCaptures = null) { _captures = captures; + _nullSafeCaptures = nullSafeCaptures; } ///

@@ -121,5 +124,12 @@ public IReadOnlyList GetList(string name) where T : class, J /// public bool Has(ICapture capture) => _captures.ContainsKey(capture.Name); + /// + /// Get the NullSafe marker associated with a capture, if any. + /// Present when the capture was the Select of a null-conditional MI/FA in the matched tree. + /// + internal NullSafe? GetNullSafe(string name) => + _nullSafeCaptures != null && _nullSafeCaptures.TryGetValue(name, out var ns) ? ns : null; + internal Dictionary AsDict() => _captures; } diff --git a/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/PatternMatchingComparator.cs b/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/PatternMatchingComparator.cs index 02046bc074..ee7a2d7fd4 100644 --- a/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/PatternMatchingComparator.cs +++ b/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/PatternMatchingComparator.cs @@ -30,6 +30,7 @@ internal class PatternMatchingComparator { private readonly IReadOnlyDictionary _captures; private readonly Dictionary _bindings = new(); + private readonly Dictionary _nullSafeBindings = new(); public PatternMatchingComparator(IReadOnlyDictionary captures) { @@ -43,9 +44,18 @@ public PatternMatchingComparator(IReadOnlyDictionary captures) public Dictionary? Match(J pattern, J candidate, Cursor cursor) { _bindings.Clear(); + _nullSafeBindings.Clear(); return MatchNode(pattern, candidate, cursor) ? new Dictionary(_bindings) : null; } + /// + /// After a successful , returns capture names mapped to the + /// marker from the candidate MethodInvocation/FieldAccess + /// when the capture was in the Select position. Used by the template engine to + /// preserve ?. through rewrites. + /// + internal IReadOnlyDictionary NullSafeBindings => _nullSafeBindings; + private bool MatchNode(J pattern, J candidate, Cursor cursor) { // Check if pattern node is a placeholder identifier @@ -73,32 +83,105 @@ private bool MatchNode(J pattern, J candidate, Cursor cursor) if (pattern.GetType() != candidate.GetType()) return MatchCrossType(pattern, candidate, cursor); - // NullSafe marker must match: ?. and . are structurally different - if (TreeHelper.HasNullSafe(pattern) != TreeHelper.HasNullSafe(candidate)) + // NullSafe: a pattern with ?. only matches candidates with ?. + // but a pattern without ?. matches both (asymmetric — patterns are lenient) + if (TreeHelper.HasNullSafe(pattern) && !TreeHelper.HasNullSafe(candidate)) return false; // Semantic matching for method invocations: when both resolve to the same // static method (same declaring type + name), skip receiver comparison. + bool matched; if (pattern is MethodInvocation patMethod && candidate is MethodInvocation candMethod) - return MatchMethodInvocation(patMethod, candMethod, cursor); - - // Generic property-based comparison: iterate all structural properties - // and compare them recursively, skipping formatting/identity fields. - if (pattern is Binary patBin && candidate is Binary candBin) + matched = MatchMethodInvocation(patMethod, candMethod, cursor); + else if (pattern is Binary patBin && candidate is Binary candBin) { - // Save bindings so we can backtrack if direct match fails + // Generic property-based comparison with backtracking for commutative ops var savedBindings = new Dictionary(_bindings); - if (MatchProperties(pattern, candidate, cursor)) - return true; + matched = MatchProperties(pattern, candidate, cursor); + if (!matched) + { + // Restore bindings and try commuted (swapped) operands for == and != + _bindings.Clear(); + foreach (var kvp in savedBindings) + _bindings[kvp.Key] = kvp.Value; + matched = MatchCommutedBinary(patBin, candBin, cursor); + } + } + else + matched = MatchProperties(pattern, candidate, cursor); + + // Record NullSafe associations for captures used as Select in MI/FA + if (matched) + RecordNullSafeForCaptures(pattern, candidate); + + return matched; + } - // Restore bindings and try commuted (swapped) operands for == and != - _bindings.Clear(); - foreach (var kvp in savedBindings) - _bindings[kvp.Key] = kvp.Value; - return MatchCommutedBinary(patBin, candBin, cursor); + /// + /// After a successful match of a MethodInvocation, FieldAccess, or ArrayAccess, + /// check if the candidate has a NullSafe marker that the pattern doesn't have. + /// If so, find the capture placeholder in the Select/Target/Indexed position and + /// record the NullSafe association so the template engine can preserve ?. + /// and ?[ through rewrites. + /// + private void RecordNullSafeForCaptures(J pattern, J candidate) + { + if (pattern is MethodInvocation patMi && candidate is MethodInvocation candMi) + { + var candNullSafe = candMi.Markers.FindFirst(); + if (candNullSafe != null && patMi.Markers.FindFirst() == null) + { + if (FindSelectCaptureName(patMi.Select) is { } captureName) + _nullSafeBindings[captureName] = candNullSafe; + } } + else if (pattern is FieldAccess patFa && candidate is FieldAccess candFa) + { + var candNullSafe = candFa.Markers.FindFirst(); + if (candNullSafe != null && patFa.Markers.FindFirst() == null) + { + if (FindTargetCaptureName(patFa.Target) is { } captureName) + _nullSafeBindings[captureName] = candNullSafe; + } + } + else if (pattern is ArrayAccess patAa && candidate is ArrayAccess candAa) + { + var candNullSafe = candAa.Markers.FindFirst(); + if (candNullSafe != null && patAa.Markers.FindFirst() == null) + { + if (FindTargetCaptureName(patAa.Indexed) is { } captureName) + _nullSafeBindings[captureName] = candNullSafe; + } + } + } - return MatchProperties(pattern, candidate, cursor); + /// + /// Check if a MI's Select (JRightPadded<Expression>?) contains a direct + /// capture placeholder, returning its name if so. + /// + private string? FindSelectCaptureName(JRightPadded? select) + { + if (select?.Element is Identifier ident) + { + var name = Placeholder.FromPlaceholder(ident.SimpleName); + if (name != null && _captures.ContainsKey(name)) + return name; + } + return null; + } + + /// + /// Check if a FieldAccess's Target is a direct capture placeholder. + /// + private string? FindTargetCaptureName(Expression target) + { + if (target is Identifier ident) + { + var name = Placeholder.FromPlaceholder(ident.SimpleName); + if (name != null && _captures.ContainsKey(name)) + return name; + } + return null; } /// diff --git a/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/TemplateEngine.cs b/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/TemplateEngine.cs index c1b19c47a4..302ca69520 100644 --- a/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/TemplateEngine.cs +++ b/rewrite-csharp/csharp/OpenRewrite/CSharp/Template/TemplateEngine.cs @@ -720,6 +720,11 @@ public override J VisitAnnotation(Annotation annotation, int p) public override J VisitMethodInvocation(MethodInvocation mi, int p) { + // Check if the select is a capture placeholder BEFORE substitution + var selectCaptureName = mi.Select?.Element is Identifier selectId + ? Placeholder.FromPlaceholder(selectId.SimpleName) + : null; + mi = (MethodInvocation)base.VisitMethodInvocation(mi, p); // Substitute placeholder in method name position @@ -733,6 +738,14 @@ public override J VisitMethodInvocation(MethodInvocation mi, int p) } } + // Transfer NullSafe from matched tree when the capture was a null-conditional select + if (selectCaptureName != null && mi.Markers.FindFirst() == null) + { + var nullSafe = _values.GetNullSafe(selectCaptureName); + if (nullSafe != null) + mi = mi.WithMarkers(mi.Markers.Add(nullSafe)); + } + // Substitute variadic placeholder in arguments mi = ExpandVariadicArgs(mi); @@ -741,6 +754,11 @@ public override J VisitMethodInvocation(MethodInvocation mi, int p) public override J VisitFieldAccess(FieldAccess fieldAccess, int p) { + // Check if the target is a capture placeholder BEFORE substitution + var targetCaptureName = fieldAccess.Target is Identifier targetId + ? Placeholder.FromPlaceholder(targetId.SimpleName) + : null; + fieldAccess = (FieldAccess)base.VisitFieldAccess(fieldAccess, p); // Substitute placeholder in field name position @@ -756,9 +774,37 @@ public override J VisitFieldAccess(FieldAccess fieldAccess, int p) } } + // Transfer NullSafe from matched tree when the capture was a null-conditional target + if (targetCaptureName != null && fieldAccess.Markers.FindFirst() == null) + { + var nullSafe = _values.GetNullSafe(targetCaptureName); + if (nullSafe != null) + fieldAccess = fieldAccess.WithMarkers(fieldAccess.Markers.Add(nullSafe)); + } + return fieldAccess; } + public override J VisitArrayAccess(ArrayAccess arrayAccess, int p) + { + // Check if the indexed expression is a capture placeholder BEFORE substitution + var indexedCaptureName = arrayAccess.Indexed is Identifier indexedId + ? Placeholder.FromPlaceholder(indexedId.SimpleName) + : null; + + arrayAccess = (ArrayAccess)base.VisitArrayAccess(arrayAccess, p); + + // Transfer NullSafe from matched tree when the capture was a null-conditional indexed expr + if (indexedCaptureName != null && arrayAccess.Markers.FindFirst() == null) + { + var nullSafe = _values.GetNullSafe(indexedCaptureName); + if (nullSafe != null) + arrayAccess = arrayAccess.WithMarkers(arrayAccess.Markers.Add(nullSafe)); + } + + return arrayAccess; + } + /// /// If any argument is a placeholder identifier bound to a variadic capture (list), /// expand it into the argument list. diff --git a/rewrite-csharp/csharp/OpenRewrite/Tests/Template/PatternMatchTests.cs b/rewrite-csharp/csharp/OpenRewrite/Tests/Template/PatternMatchTests.cs index b2ad98f8f3..2770d790c7 100644 --- a/rewrite-csharp/csharp/OpenRewrite/Tests/Template/PatternMatchTests.cs +++ b/rewrite-csharp/csharp/OpenRewrite/Tests/Template/PatternMatchTests.cs @@ -792,14 +792,15 @@ public void NullConditionalPatternDoesNotMatchRegularDotAccess() } [Fact] - public void RegularDotPatternDoesNotMatchNullConditionalAccess() + public void RegularDotPatternMatchesNullConditionalAccess() { var obj = Capture.Of("obj"); RewriteRun( spec => spec.SetRecipe(FindMethodInvocation($"{obj}.ToString()")), CSharp( - // ?. access should NOT match a regular . pattern - "class C { void M() { string s = null; var x = s?.ToString(); } }" + // A pattern without ?. matches both . and ?. access (asymmetric) + "class C { void M() { string s = null; var x = s?.ToString(); } }", + "class C { void M() { string s = null; var x = /*~~>*/s?.ToString(); } }" ) ); } diff --git a/rewrite-csharp/csharp/OpenRewrite/Tests/Template/RewriteRuleTests.cs b/rewrite-csharp/csharp/OpenRewrite/Tests/Template/RewriteRuleTests.cs index 5219541517..2c783efbf4 100644 --- a/rewrite-csharp/csharp/OpenRewrite/Tests/Template/RewriteRuleTests.cs +++ b/rewrite-csharp/csharp/OpenRewrite/Tests/Template/RewriteRuleTests.cs @@ -511,6 +511,155 @@ void M() ); } + // =============================================================== + // NullSafe preservation — ?. marker transfers through Rewrite + // =============================================================== + + [Fact] + public void PreservesNullConditionalInRewrite() + { + var x = Capture.Expression(type: "IEnumerable", typeParameters: ["T"]); + var pred = Capture.Expression(); + + RewriteRun( + spec => spec.SetRecipe(new RewriteRecipe( + CSharpTemplate.Rewrite( + CSharpPattern.Expression($"{x}.Where({pred}).First()"), + CSharpTemplate.Expression($"{x}.First({pred})")))) + .SetReferenceAssemblies(Assemblies.Net90), + CSharp( + """ + using System.Linq; + using System.Collections.Generic; + class Test + { + void M(Dictionary> dict) + { + var result = dict["key"]?.Where(x => x > 0).First(); + } + } + """, + """ + using System.Linq; + using System.Collections.Generic; + class Test + { + void M(Dictionary> dict) + { + var result = dict["key"]?.First(x => x > 0); + } + } + """ + ) + ); + } + + [Fact] + public void PreservesNullConditionalOnFieldAccess() + { + var x = Capture.Expression(); + + RewriteRun( + spec => spec.SetRecipe(new RewriteRecipe( + CSharpTemplate.Rewrite( + CSharpPattern.Expression($"{x}.Length"), + CSharpTemplate.Expression($"{x}.Count")))), + CSharp( + """ + class Test + { + void M(string? s) + { + var n = s?.Length; + } + } + """, + """ + class Test + { + void M(string? s) + { + var n = s?.Count; + } + } + """ + ) + ); + } + + [Fact] + public void PreservesNullConditionalOnElementAccess() + { + var x = Capture.Expression(); + var i = Capture.Expression(); + + RewriteRun( + spec => spec.SetRecipe(new RewriteRecipe( + CSharpTemplate.Rewrite( + CSharpPattern.Expression($"{x}[{i}]"), + CSharpTemplate.Expression($"{x}[{i}]")))), + CSharp( + """ + class Test + { + void M(int[]? arr) + { + var n = arr?[0]; + } + } + """, + """ + class Test + { + void M(int[]? arr) + { + var n = arr?[0]; + } + } + """ + ) + ); + } + + [Fact] + public void NullConditionalNotAddedWhenOriginalHasNone() + { + var x = Capture.Expression(type: "IEnumerable", typeParameters: ["T"]); + var pred = Capture.Expression(); + + RewriteRun( + spec => spec.SetRecipe(new RewriteRecipe( + CSharpTemplate.Rewrite( + CSharpPattern.Expression($"{x}.Where({pred}).First()"), + CSharpTemplate.Expression($"{x}.First({pred})")))) + .SetReferenceAssemblies(Assemblies.Net90), + CSharp( + """ + using System.Linq; + using System.Collections.Generic; + class Test + { + void M(List list) + { + var result = list.Where(x => x > 0).First(); + } + } + """, + """ + using System.Linq; + using System.Collections.Generic; + class Test + { + void M(List list) + { + var result = list.First(x => x > 0); + } + } + """ + ) + ); + } + // =============================================================== // FlattenBlock — multi-statement template spliced into parent // =============================================================== @@ -870,6 +1019,14 @@ public override ITreeVisitor GetVisitor() } } +class RewriteRecipe(CSharpVisitor visitor) : Core.Recipe +{ + public override string DisplayName => "Rewrite recipe"; + public override string Description => "Applies a CSharpTemplate.Rewrite visitor."; + + public override ITreeVisitor GetVisitor() => visitor; +} + class FallbackWithManualVisitorRecipe : Core.Recipe { public override string DisplayName => "Fallback with manual visitor";