@@ -142,11 +142,24 @@ private static async Task<Document> ConvertWebHostCreateDefaultBuilder(Document
142142
143143 // Find the full expression chain that contains this invocation
144144 var fullExpression = GetFullExpressionChain ( invocation ) ;
145-
145+
146+ SyntaxTriviaList leadingTrivia = fullExpression . GetLeadingTrivia ( ) ;
147+ if ( ! fullExpression . HasLeadingTrivia )
148+ {
149+ // Try to find some leading trivia from a parent
150+ // e.g. using (var host = WebHost.CreateDefaultBuilder(args)) would not have leading trivia on the WebHost call
151+ var parent = fullExpression . Parent ;
152+ while ( parent != null && ! parent . HasLeadingTrivia )
153+ {
154+ parent = parent . Parent ;
155+ }
156+ leadingTrivia = parent ? . GetLeadingTrivia ( ) ?? SyntaxFactory . TriviaList ( ) ;
157+ }
158+
146159 // Transform the entire expression
147- var transformedExpression = TransformExpression ( fullExpression ) ;
160+ var transformedExpression = TransformExpression ( fullExpression , leadingTrivia ) ;
148161
149- var newRoot = root . ReplaceNode ( fullExpression , transformedExpression . WithLeadingTrivia ( fullExpression . GetLeadingTrivia ( ) ) ) ;
162+ var newRoot = root . ReplaceNode ( fullExpression , transformedExpression ) ;
150163
151164 // Add the required using statement if not already present
152165 if ( root is CompilationUnitSyntax compilationUnit )
@@ -209,21 +222,22 @@ private static BlockSyntax TransformMethodBody(BlockSyntax body)
209222
210223 private static ArrowExpressionClauseSyntax TransformExpressionBody ( ArrowExpressionClauseSyntax expressionBody )
211224 {
212- var transformedExpression = TransformExpression ( expressionBody . Expression ) ;
225+ var transformedExpression = TransformExpression ( expressionBody . Expression , expressionBody . Expression . GetLeadingTrivia ( ) ) ;
213226 return expressionBody . WithExpression ( transformedExpression ) ;
214227 }
215228
216229 private static StatementSyntax TransformStatement ( StatementSyntax statement )
217230 {
218231 if ( statement is ReturnStatementSyntax returnStatement && returnStatement . Expression != null )
219232 {
220- var transformedExpression = TransformExpression ( returnStatement . Expression ) ;
233+ var transformedExpression = TransformExpression ( returnStatement . Expression ,
234+ new SyntaxTriviaList ( SyntaxFactory . ElasticCarriageReturnLineFeed ) . AddRange ( returnStatement . GetLeadingTrivia ( ) ) . Add ( SyntaxFactory . ElasticTab ) ) ;
221235 return returnStatement . WithExpression ( transformedExpression ) ;
222236 }
223237 return statement ;
224238 }
225239
226- private static ExpressionSyntax TransformExpression ( ExpressionSyntax expression )
240+ private static ExpressionSyntax TransformExpression ( ExpressionSyntax expression , SyntaxTriviaList leadingTrivia )
227241 {
228242 // Transform WebHost.CreateDefaultBuilder(args).ConfigureServices(...).UseStartup<Startup>()
229243 // to Host.CreateDefaultBuilder(args).ConfigureWebHostDefaults(webBuilder => webBuilder.ConfigureServices(...).UseStartup<Startup>())
@@ -242,26 +256,27 @@ private static ExpressionSyntax TransformExpression(ExpressionSyntax expression)
242256 . WithArgumentList ( webHostCreateCall . ArgumentList ) ;
243257
244258 // Create the webBuilder expression chain
245- var webBuilderChain = CreateWebBuilderChain ( chainedCalls ) ;
259+ var webBuilderChain = CreateWebBuilderChain ( chainedCalls , leadingTrivia ) ;
246260
247261 // Create the ConfigureWebHostDefaults lambda with proper formatting
248262 var lambda = SyntaxFactory . SimpleLambdaExpression (
249263 SyntaxFactory . Parameter ( SyntaxFactory . Identifier ( "webBuilder" ) ) ,
250- webBuilderChain )
251- . WithArrowToken ( SyntaxFactory . Token ( SyntaxKind . EqualsGreaterThanToken ) ) ;
264+ webBuilderChain ) ;
252265
253266 // Create Host.CreateDefaultBuilder().ConfigureWebHostDefaults(...)
254267 var configureCall = SyntaxFactory . InvocationExpression (
255268 SyntaxFactory . MemberAccessExpression (
256269 SyntaxKind . SimpleMemberAccessExpression ,
257- hostCreateCall . WithTrailingTrivia ( ) ,
270+ hostCreateCall . WithLeadingTrivia ( leadingTrivia ) ,
258271 SyntaxFactory . IdentifierName ( "ConfigureWebHostDefaults" ) )
259272 . WithOperatorToken ( SyntaxFactory . Token ( SyntaxKind . DotToken )
260- . WithLeadingTrivia ( SyntaxFactory . ElasticCarriageReturnLineFeed )
261- . WithTrailingTrivia ( ) ) )
273+ . WithLeadingTrivia ( leadingTrivia )
274+ ) )
262275 . WithArgumentList ( SyntaxFactory . ArgumentList (
263276 SyntaxFactory . SingletonSeparatedList (
264- SyntaxFactory . Argument ( lambda ) ) ) ) ;
277+ SyntaxFactory . Argument ( lambda ) ) ) )
278+ // Adds new line and indentation for remaining chain calls e.g. Build(), Run(), etc.
279+ . WithTrailingTrivia ( new SyntaxTriviaList ( SyntaxFactory . ElasticCarriageReturnLineFeed ) . AddRange ( leadingTrivia ) ) ;
265280
266281 // If there's a remaining chain (like .Build()), append it
267282 ExpressionSyntax result = configureCall ;
@@ -276,8 +291,7 @@ private static ExpressionSyntax TransformExpression(ExpressionSyntax expression)
276291 }
277292 }
278293
279- return result . WithTrailingTrivia ( expression . GetTrailingTrivia ( ) )
280- . WithLeadingTrivia ( expression . GetLeadingTrivia ( ) ) ;
294+ return result ;
281295 }
282296
283297 // Handle standalone WebHost.CreateDefaultBuilder without chaining
@@ -320,9 +334,9 @@ invocation.Expression is MemberAccessExpressionSyntax memberAccess &&
320334 "UseDefaultServiceProvider"
321335 } ;
322336
323- private static ( InvocationExpressionSyntax ? webHostCreateCall , List < ( SimpleNameSyntax methodName , ArgumentListSyntax arguments ) > chainedCalls , ExpressionSyntax ? remainingChain ) ExtractWebHostBuilderChain ( ExpressionSyntax expression )
337+ private static ( InvocationExpressionSyntax ? webHostCreateCall , List < ( SimpleNameSyntax methodName , ArgumentListSyntax arguments , SyntaxTriviaList leadingTrivia ) > chainedCalls , ExpressionSyntax ? remainingChain ) ExtractWebHostBuilderChain ( ExpressionSyntax expression )
324338 {
325- var chainedCalls = new List < ( SimpleNameSyntax methodName , ArgumentListSyntax arguments ) > ( ) ;
339+ var chainedCalls = new List < ( SimpleNameSyntax methodName , ArgumentListSyntax arguments , SyntaxTriviaList leadingTrivia ) > ( ) ;
326340 InvocationExpressionSyntax ? webHostCreateCall = null ;
327341 ExpressionSyntax ? remainingChain = null ;
328342
@@ -373,42 +387,55 @@ memberAccess.Expression is IdentifierNameSyntax identifier &&
373387 // Check if it's a WebHostBuilder method that should go inside ConfigureWebHostDefaults
374388 if ( WebHostBuilderMethods . Contains ( methodName . Identifier . ValueText ) )
375389 {
376- chainedCalls . Add ( ( methodName , arguments ) ) ;
390+ chainedCalls . Add ( ( methodName , arguments , invocationExpr . GetLeadingTrivia ( ) ) ) ;
377391 }
378392 else
379393 {
380394 // This method should remain outside the lambda (like Build(), Run(), etc.)
381395 nonWebHostMethods . Add ( ( methodName , arguments , invocationExpr ) ) ;
382- break ; // Stop processing once we hit a non-WebHostBuilder method
396+ // Add any remaining methods to the non-WebHostBuilder list
397+ nonWebHostMethods . AddRange ( methodCalls ) ;
398+ methodCalls . Clear ( ) ;
399+ // Stop processing once we hit a non-WebHostBuilder method
400+ break ;
383401 }
384402 }
385403 }
386-
404+
387405 // Build the remaining chain from non-WebHostBuilder methods
388406 if ( nonWebHostMethods . Count > 0 )
389407 {
390408 // Create a placeholder for the Host.CreateDefaultBuilder().ConfigureWebHostDefaults(...) call
391409 // This will be replaced later, but we need something to chain the remaining methods to
392410 var placeholder = SyntaxFactory . IdentifierName ( "HOST_PLACEHOLDER" ) ;
393-
411+
394412 // Chain the remaining methods
395413 ExpressionSyntax current = placeholder ;
396- foreach ( var ( methodName , arguments , _ ) in nonWebHostMethods )
414+ foreach ( var ( methodName , arguments , invocation ) in nonWebHostMethods )
397415 {
416+ SyntaxTriviaList leadingTrivia = default ;
417+ if ( invocation . Expression is MemberAccessExpressionSyntax memberAccessExpr )
418+ {
419+ leadingTrivia = memberAccessExpr . Expression . GetLeadingTrivia ( ) ;
420+ }
421+
398422 var memberAccess = SyntaxFactory . MemberAccessExpression (
399423 SyntaxKind . SimpleMemberAccessExpression ,
400- current ,
424+ // Since we're appending method calls,
425+ // we need to add trailing trivia after each one to affect the new method calls formatting
426+ current . WithTrailingTrivia ( current . GetTrailingTrivia ( ) . AddRange ( leadingTrivia ) ) ,
401427 methodName ) ;
428+
402429 current = SyntaxFactory . InvocationExpression ( memberAccess , arguments ) ;
403430 }
404-
431+
405432 remainingChain = current ;
406433 }
407-
434+
408435 return ( webHostCreateCall , chainedCalls , remainingChain ) ;
409436 }
410437
411- private static ExpressionSyntax CreateWebBuilderChain ( List < ( SimpleNameSyntax methodName , ArgumentListSyntax arguments ) > chainedCalls )
438+ private static ExpressionSyntax CreateWebBuilderChain ( List < ( SimpleNameSyntax methodName , ArgumentListSyntax arguments , SyntaxTriviaList leadingTrivia ) > chainedCalls , SyntaxTriviaList leadingTrivia2 )
412439 {
413440 if ( chainedCalls . Count == 0 )
414441 {
@@ -421,16 +448,25 @@ private static ExpressionSyntax CreateWebBuilderChain(List<(SimpleNameSyntax met
421448 // Chain all the method calls with proper formatting
422449 for ( int i = 0 ; i < chainedCalls . Count ; i ++ )
423450 {
424- var ( methodName , arguments ) = chainedCalls [ i ] ;
425-
451+ var ( methodName , arguments , leadingTrivia ) = chainedCalls [ i ] ;
452+
426453 var memberAccess = SyntaxFactory . MemberAccessExpression (
427454 SyntaxKind . SimpleMemberAccessExpression ,
428455 current ,
429456 methodName ) ; // Use the SimpleNameSyntax directly to preserve generics
430457
431458 current = SyntaxFactory . InvocationExpression ( memberAccess , arguments ) ;
459+ current = current . WithTrailingTrivia ( ) ;
460+
461+ if ( i < chainedCalls . Count - 1 )
462+ {
463+ // Add a line break and indentation for all but the last method call
464+ var triviaList = new SyntaxTriviaList ( SyntaxFactory . ElasticCarriageReturnLineFeed ) . AddRange ( leadingTrivia ) ;
465+ triviaList = triviaList . Add ( SyntaxFactory . ElasticTab ) ;
466+ current = current . WithTrailingTrivia ( triviaList ) ;
467+ }
432468 }
433-
469+
434470 return current ;
435471 }
436472}
0 commit comments