Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### Features

- [#4790](https://github.com/ignite/cli/pull/4790) Remove global vars and struct placeholders.
- [#4797](https://github.com/ignite/cli/pull/4797) Add `xast` function to append new switch clause.

## [`v29.3.1`](https://github.com/ignite/cli/releases/tag/v29.3.1)

Expand Down
184 changes: 148 additions & 36 deletions ignite/pkg/xast/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,30 @@ import (
type (
// functionOpts represent the options for functions.
functionOpts struct {
funcName string // Name of the function to modify.
newParams []functionParam // Parameters to add to the function.
body string // New function body content.
newLines []functionLine // Lines to insert at specific positions.
insideCall functionCalls // Function calls to modify.
insideStruct functionStructs // Struct literals to modify.
appendTestCase []string // Test cases to append.
appendCode []string // Code to append at the end.
returnVars []string // Return variables to modify.
newParams []functionParam // Parameters to add to the function.
body string // New function body content.
newLines []functionLine // Lines to insert at specific positions.
insideCall functionCalls // Function calls to modify.
insideStruct functionStructs // Struct literals to modify.
appendTestCase []string // Test cases to append.
appendCode []string // Code to append at the end.
returnVars []string // Return variables to modify.
appendSwitch functionSwitches // Switch cases to append.
}

// FunctionOptions configures code generation.
FunctionOptions func(*functionOpts)

// functionStruct represents a struct literal to modify.
functionSwitch struct {
condition string // Condition to find.
switchCase string // Switch case to insert.
switchBody string // Code to insert.
}

functionSwitches []functionSwitch
functionSwitchesMap map[string]functionSwitches

// functionStruct represents a struct literal to modify.
functionStruct struct {
name string // Name of the struct type.
Expand Down Expand Up @@ -83,6 +93,19 @@ func (s functionStructs) Map() functionStructsMap {
return structMap
}

// Map converts a slice of functionStructs to a map keyed by struct name.
func (s functionSwitches) Map() functionSwitchesMap {
switchesMap := make(functionSwitchesMap)
for _, c := range s {
switches, ok := switchesMap[c.condition]
if !ok {
switches = make(functionSwitches, 0)
}
switchesMap[c.condition] = append(switches, c)
}
return switchesMap
}

// Map converts a slice of functionCalls to a map keyed by function name.
func (c functionCalls) Map() functionCallsMap {
callMap := make(functionCallsMap)
Expand Down Expand Up @@ -173,10 +196,20 @@ func NewFuncReturn(returnVars ...string) FunctionOptions {
}
}

// AppendSwitchCase inserts a new case with the code at a specific switch condition statement.
func AppendSwitchCase(condition, switchCase, switchBody string) FunctionOptions {
return func(c *functionOpts) {
c.appendSwitch = append(c.appendSwitch, functionSwitch{
condition: condition,
switchCase: switchCase,
switchBody: switchBody,
})
}
}

// newFunctionOptions creates a new functionOpts with defaults.
func newFunctionOptions(funcName string) functionOpts {
func newFunctionOptions() functionOpts {
return functionOpts{
funcName: funcName,
newParams: make([]functionParam, 0),
body: "",
newLines: make([]functionLine, 0),
Expand All @@ -191,7 +224,7 @@ func newFunctionOptions(funcName string) functionOpts {
// ModifyFunction modifies a function in Go source code using functional options.
func ModifyFunction(content string, funcName string, functions ...FunctionOptions) (string, error) {
// Collect all function options.
opts := newFunctionOptions(funcName)
opts := newFunctionOptions()
for _, fn := range functions {
fn(&opts)
}
Expand Down Expand Up @@ -256,6 +289,55 @@ func modifyReturnVars(fileSet *token.FileSet, returnVars []string) ([]ast.Expr,
return stmts, nil
}

// appendSwitchCase appends a new case to a switch statement.
func appendSwitchCase(fileSet *token.FileSet, stmt ast.Node, fs functionSwitches) error {
for _, f := range fs {
// Parse the new case code
newRetExpr, err := parser.ParseExprFrom(fileSet, "", []byte(f.switchCase), parser.ParseComments)
if err != nil {
return err
}

bodyStmt, err := codeToBlockStmt(fileSet, f.switchBody)
if err != nil {
return err
}

// Create a new case clause
newCase := &ast.CaseClause{
List: []ast.Expr{newRetExpr},
Body: bodyStmt.List,
Case: token.NoPos, // Keep first item aligned with case keyword
Colon: token.NoPos, // Keep colon aligned with case keyword
}

// Handle different types of switch statements
switch statement := stmt.(type) {
case *ast.TypeSwitchStmt:
statement.Body.List = appendCaseToList(statement.Body.List, newCase)
case *ast.SwitchStmt:
statement.Body.List = appendCaseToList(statement.Body.List, newCase)
default:
return errors.Errorf("unsupported switch statement type: %T", stmt)
}
}
return nil
}

// appendCaseToList handles inserting a case clause into a list of statements,
// placing it before any default case if one exists.
func appendCaseToList(list []ast.Stmt, newCase *ast.CaseClause) []ast.Stmt {
if len(list) > 0 {
lastCase, isDefault := list[len(list)-1].(*ast.CaseClause)
if isDefault && len(lastCase.List) == 0 {
// Insert before default.
return append(list[:len(list)-1], newCase, list[len(list)-1])
}
}

return append(list, newCase)
}

// addParams adds new parameters to a function declaration.
func addParams(funcDecl *ast.FuncDecl, newParams []functionParam) error {
for _, p := range newParams {
Expand Down Expand Up @@ -371,7 +453,7 @@ func addTestCase(fSet *token.FileSet, funcDecl *ast.FuncDecl, testCase []string)
// structToBlockStmt parses struct literal code into AST expression.
func structToBlockStmt(fSet *token.FileSet, code string) (ast.Expr, error) {
newFuncContent := toStruct(code)
newContent, err := parser.ParseExprFrom(fSet, "temp.go", newFuncContent, parser.AllErrors)
newContent, err := parser.ParseExprFrom(fSet, "", newFuncContent, parser.AllErrors)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -498,7 +580,8 @@ func formatNode(fileSet *token.FileSet, n ast.Node) (string, error) {
return "", err
}

return buf.String(), nil
node := strings.TrimSpace(buf.String())
return node, nil
}

// applyFunctionOptions applies all modifications to a function.
Expand Down Expand Up @@ -526,20 +609,19 @@ func applyFunctionOptions(fileSet *token.FileSet, f *ast.FuncDecl, opts *functio

// Create maps for tracking modifications.
var (
callMap = opts.insideCall.Map()
callMapCheck = opts.insideCall.Map()
structMap = opts.insideStruct.Map()
structMapCheck = opts.insideStruct.Map()
callMap = opts.insideCall.Map()
callMapCheck = opts.insideCall.Map()
structMap = opts.insideStruct.Map()
structMapCheck = opts.insideStruct.Map()
switchesCasesMap = opts.appendSwitch.Map()
switchesCasesMapCheck = opts.appendSwitch.Map()
)

// Apply all modifications.
var (
found bool
errInspect error
)
var errInspect error
ast.Inspect(f, func(n ast.Node) bool {
funcDecl, ok := n.(*ast.FuncDecl)
if !ok || funcDecl.Name.Name != opts.funcName {
if !ok {
return true
}

Expand Down Expand Up @@ -567,6 +649,39 @@ func applyFunctionOptions(fileSet *token.FileSet, f *ast.FuncDecl, opts *functio
return false
}

for _, bodyList := range funcDecl.Body.List {
var stmt ast.Stmt
var buf bytes.Buffer
switch expr := bodyList.(type) {
case *ast.TypeSwitchStmt:
stmt = expr
if err := format.Node(&buf, fileSet, expr.Assign); err != nil {
errInspect = err
return false
}
case *ast.SwitchStmt:
stmt = expr
if err := format.Node(&buf, fileSet, expr.Tag); err != nil {
errInspect = err
return false
}
default:
continue
}

switchCase, ok := switchesCasesMap[buf.String()]
if !ok {
continue
}

if err := appendSwitchCase(fileSet, stmt, switchCase); err != nil {
errInspect = err
return false
}

delete(switchesCasesMapCheck, buf.String())
}

// Modify function calls and struct literals.
ast.Inspect(funcDecl, func(n ast.Node) bool {
switch expr := n.(type) {
Expand Down Expand Up @@ -610,31 +725,28 @@ func applyFunctionOptions(fileSet *token.FileSet, f *ast.FuncDecl, opts *functio
return false
}

// Verify all modifications were applied.
if len(callMapCheck) > 0 {
errInspect = errors.Errorf("function calls not found: %v", callMapCheck)
return false
}
if len(structMapCheck) > 0 {
errInspect = errors.Errorf("function structs not found: %v", structMapCheck)
return false
}

// Add test cases.
if err := addTestCase(fileSet, funcDecl, opts.appendTestCase); err != nil {
errInspect = err
return false
}

found = true
return false
})

if errInspect != nil {
return errInspect
}
if !found {
return errors.Errorf("function %s not found in file content", opts.funcName)

// Verify all modifications were applied.
if len(callMapCheck) > 0 {
return errors.Errorf("function calls not found: %v", callMapCheck)
}
if len(structMapCheck) > 0 {
return errors.Errorf("function structs not found: %v", structMapCheck)
}
if len(switchesCasesMapCheck) > 0 {
return errors.Errorf("function switch not found: %v", switchesCasesMapCheck)
}

return nil
Expand Down
Loading
Loading