diff --git a/changelog.md b/changelog.md index 38a6770ea0..78257e590a 100644 --- a/changelog.md +++ b/changelog.md @@ -2,9 +2,12 @@ ## Unreleased +## [`v29.6.0`](https://github.com/ignite/cli/releases/tag/v29.6.0) + ### Changes - [#4825](https://github.com/ignite/cli/pull/4825) Bump `cosmos-sdk` version to `v0.53.4` and `ibc-go` version to `v10.4.0`. +- [#4828](https://github.com/ignite/cli/pull/4828) Add remove func call in `xast` package. ## [`v29.5.0`](https://github.com/ignite/cli/releases/tag/v29.5.0) diff --git a/ignite/pkg/xast/function.go b/ignite/pkg/xast/function.go index 5ac7e3cece..92ecdcea4e 100644 --- a/ignite/pkg/xast/function.go +++ b/ignite/pkg/xast/function.go @@ -25,6 +25,7 @@ type ( appendCode []string // Code to append at the end. returnVars []string // Return variables to modify. appendSwitch functionSwitches // Switch cases to append. + removeCalls []string // Function calls to remove. } // FunctionOptions configures code generation. @@ -219,6 +220,15 @@ func AppendSwitchCase(condition, switchCase, switchBody string) FunctionOptions } } +// RemoveFuncCall removes function calls with the specified name from within a function. +// The callName can be either a simple function name like "doSomething" or a qualified +// name like "pkg.DoSomething". +func RemoveFuncCall(callName string) FunctionOptions { + return func(c *functionOpts) { + c.removeCalls = append(c.removeCalls, callName) + } +} + // newFunctionOptions creates a new functionOpts with defaults. func newFunctionOptions() functionOpts { return functionOpts{ @@ -230,6 +240,7 @@ func newFunctionOptions() functionOpts { appendTestCase: make([]string, 0), appendCode: make([]string, 0), returnVars: make([]string, 0), + removeCalls: make([]string, 0), } } @@ -635,6 +646,13 @@ func applyFunctionOptions(fileSet *token.FileSet, f *ast.FuncDecl, opts *functio switchesCasesMapCheck = opts.appendSwitch.Map() ) + // Remove function calls if specified. + if len(opts.removeCalls) > 0 { + if err := removeFunctionCalls(f, opts.removeCalls); err != nil { + return err + } + } + // Apply all modifications. var errInspect error ast.Inspect(f, func(n ast.Node) bool { @@ -920,3 +938,146 @@ func ModifyCaller(content, callerExpr string, modifiers func([]string) ([]string return string(result), nil } + +// RemoveFunction removes a function declaration from the file content. +func RemoveFunction(content, funcName string) (string, error) { + // Parse source into AST. + fset := token.NewFileSet() + file, err := parser.ParseFile(fset, "", content, parser.ParseComments) + if err != nil { + return "", errors.Errorf("failed to parse file: %w", err) + } + + cmap := ast.NewCommentMap(fset, file, file.Comments) + + // Find the function to remove. + var found bool + var newDecls []ast.Decl + for _, decl := range file.Decls { + if fd, ok := decl.(*ast.FuncDecl); ok && fd.Name.Name == funcName { + found = true + // Remove comments associated with this function. + delete(cmap, decl) + continue // Skip this declaration to remove it. + } + newDecls = append(newDecls, decl) + } + + if !found { + return "", errors.Errorf("function %q not found", funcName) + } + + // Update file declarations and comments. + file.Decls = newDecls + file.Comments = cmap.Filter(file).Comments() + + return formatNode(fset, file) +} + +// removeFunctionCalls removes all function calls matching the specified names from a function. +func removeFunctionCalls(f *ast.FuncDecl, callNames []string) error { + if f.Body == nil { + return nil + } + + // Create a map for faster lookup. + callMap := make(map[string]bool) + for _, name := range callNames { + callMap[name] = true + } + + // Helper to check if a call expression matches any of the names to remove. + matchesCall := func(callExpr *ast.CallExpr) bool { + switch fun := callExpr.Fun.(type) { + case *ast.Ident: + // Simple function call like doSomething(). + return callMap[fun.Name] + case *ast.SelectorExpr: + // Qualified function call like pkg.DoSomething(). + if ident, ok := fun.X.(*ast.Ident); ok { + qualified := ident.Name + "." + fun.Sel.Name + return callMap[qualified] + } + } + return false + } + + // Filter statements to remove matching function calls. + var filterStmts func([]ast.Stmt) []ast.Stmt + filterStmts = func(stmts []ast.Stmt) []ast.Stmt { + var filtered []ast.Stmt + for _, stmt := range stmts { + keep := true + + // Check if this is an expression statement with a call expression. + if exprStmt, ok := stmt.(*ast.ExprStmt); ok { + if callExpr, ok := exprStmt.X.(*ast.CallExpr); ok { + if matchesCall(callExpr) { + keep = false + } + } + } + + // Recursively handle block statements. + if blockStmt, ok := stmt.(*ast.BlockStmt); ok { + blockStmt.List = filterStmts(blockStmt.List) + } + + // Recursively handle if statements. + if ifStmt, ok := stmt.(*ast.IfStmt); ok { + if ifStmt.Body != nil { + ifStmt.Body.List = filterStmts(ifStmt.Body.List) + } + if ifStmt.Else != nil { + if elseBlock, ok := ifStmt.Else.(*ast.BlockStmt); ok { + elseBlock.List = filterStmts(elseBlock.List) + } + } + } + + // Recursively handle for statements. + if forStmt, ok := stmt.(*ast.ForStmt); ok { + if forStmt.Body != nil { + forStmt.Body.List = filterStmts(forStmt.Body.List) + } + } + + // Recursively handle range statements. + if rangeStmt, ok := stmt.(*ast.RangeStmt); ok { + if rangeStmt.Body != nil { + rangeStmt.Body.List = filterStmts(rangeStmt.Body.List) + } + } + + // Recursively handle switch statements. + if switchStmt, ok := stmt.(*ast.SwitchStmt); ok { + if switchStmt.Body != nil { + for _, caseClause := range switchStmt.Body.List { + if cc, ok := caseClause.(*ast.CaseClause); ok { + cc.Body = filterStmts(cc.Body) + } + } + } + } + + // Recursively handle type switch statements. + if typeSwitchStmt, ok := stmt.(*ast.TypeSwitchStmt); ok { + if typeSwitchStmt.Body != nil { + for _, caseClause := range typeSwitchStmt.Body.List { + if cc, ok := caseClause.(*ast.CaseClause); ok { + cc.Body = filterStmts(cc.Body) + } + } + } + } + + if keep { + filtered = append(filtered, stmt) + } + } + return filtered + } + + f.Body.List = filterStmts(f.Body.List) + return nil +} diff --git a/ignite/pkg/xast/function_test.go b/ignite/pkg/xast/function_test.go index 3a3cf6aa3d..43105f9314 100644 --- a/ignite/pkg/xast/function_test.go +++ b/ignite/pkg/xast/function_test.go @@ -1177,3 +1177,223 @@ func main() { }) } } + +func TestRemoveFunction(t *testing.T) { + tests := []struct { + name string + content string + funcName string + expected string + expectError bool + }{ + { + name: "remove a simple function", + content: `package main + +func main() { + println("hello") +} + +func anotherFunction() { + println("another") +} + +func thirdFunction() { + println("third") +} +`, + funcName: "anotherFunction", + expected: `package main + +func main() { + println("hello") +} + +func thirdFunction() { + println("third") +}`, + }, + { + name: "remove first function", + content: `package main + +func first() { + println("first") +} + +func second() { + println("second") +} +`, + funcName: "first", + expected: `package main + +func second() { + println("second") +}`, + }, + { + name: "remove last function", + content: `package main + +func first() { + println("first") +} + +func second() { + println("second") +} +`, + funcName: "second", + expected: `package main + +func first() { + println("first") +}`, + }, + { + name: "remove function with comments", + content: `package main + +// main is the entry point +func main() { + println("main") +} + +// helperFunc does something +func helperFunc() { + println("helper") +} +`, + funcName: "helperFunc", + expected: `package main + +// main is the entry point +func main() { + println("main") +}`, + }, + { + name: "function not found", + content: `package main + +func main() { + println("hello") +} +`, + funcName: "notFound", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := RemoveFunction(tt.content, tt.funcName) + + if tt.expectError { + require.Error(t, err) + return + } + + require.NoError(t, err) + require.Equal(t, tt.expected, result) + }) + } +} + +func TestRemoveFuncCall(t *testing.T) { + tests := []struct { + name string + content string + funcName string + callName string + expected string + }{ + { + name: "remove a function call", + content: `package main + +func main() { + fmt.Println("before") + doSomething() + fmt.Println("after") +} +`, + funcName: "main", + callName: "doSomething", + expected: `package main + +func main() { + fmt.Println("before") + + fmt.Println("after") +}`, + }, + { + name: "remove qualified function call", + content: `package main + +func main() { + fmt.Println("hello") + pkg.DoSomething() + fmt.Println("world") +} +`, + funcName: "main", + callName: "pkg.DoSomething", + expected: `package main + +func main() { + fmt.Println("hello") + + fmt.Println("world") +}`, + }, + { + name: "remove multiple calls to same function", + content: `package main + +func main() { + doSomething() + fmt.Println("middle") + doSomething() +} +`, + funcName: "main", + callName: "doSomething", + expected: `package main + +func main() { + + fmt.Println("middle") + +}`, + }, + { + name: "remove call with arguments", + content: `package main + +func process() { + validate(arg1, arg2) + execute() +} +`, + funcName: "process", + callName: "validate", + expected: `package main + +func process() { + + execute() +}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := ModifyFunction(tt.content, tt.funcName, RemoveFuncCall(tt.callName)) + require.NoError(t, err) + require.Equal(t, tt.expected, result) + }) + } +}