diff --git a/client/assets/settings.go b/client/assets/settings.go index cb9fc9b2..82df01e2 100644 --- a/client/assets/settings.go +++ b/client/assets/settings.go @@ -1,11 +1,8 @@ package assets import ( - "encoding/json" "github.com/chainreactors/IoM-go/proto/client/clientpb" - "github.com/chainreactors/malice-network/helper/utils/configutil" - "io/ioutil" - "path/filepath" + "github.com/gookit/config/v2" ) //var ( @@ -20,10 +17,24 @@ type Settings struct { LocalRPCEnable bool `yaml:"localrpc_enable" config:"localrpc_enable" default:"false"` LocalRPCAddr string `yaml:"localrpc_addr" config:"localrpc_addr" default:"127.0.0.1:15004"` Github *GithubSetting `yaml:"github" config:"github"` + AI *AISettings `yaml:"ai" config:"ai"` //VtApiKey string `yaml:"vt_api_key" config:"vt_api_key" default:""` } +// AISettings holds configuration for AI assistant integration +type AISettings struct { + Enable bool `yaml:"enable" config:"enable" default:"false"` + Provider string `yaml:"provider" config:"provider" default:"openai"` // openai, claude + APIKey string `yaml:"api_key" config:"api_key" default:""` + Endpoint string `yaml:"endpoint" config:"endpoint" default:"https://api.openai.com/v1"` + Model string `yaml:"model" config:"model" default:"gpt-4"` + MaxTokens int `yaml:"max_tokens" config:"max_tokens" default:"1024"` + Timeout int `yaml:"timeout" config:"timeout" default:"30"` + HistorySize int `yaml:"history_size" config:"history_size" default:"20"` + OpsecCheck bool `yaml:"opsec_check" config:"opsec_check" default:"false"` // Enable AI OPSEC risk assessment +} + type GithubSetting struct { Repo string `yaml:"repo" config:"repo" default:""` Owner string `yaml:"owner" config:"owner" default:""` @@ -44,17 +55,24 @@ func (github *GithubSetting) ToProtobuf() *clientpb.GithubActionBuildConfig { } func LoadSettings() (*Settings, error) { - rootDir, _ := filepath.Abs(GetRootAppDir()) - //data, err := os.ReadFile(filepath.Join(rootDir, settingsFileName)) - //if err != nil { - // return defaultSettings(), err - //} - settings := defaultSettings() - err := configutil.LoadConfig(filepath.Join(rootDir, maliceProfile), settings) + setting, err := GetSetting() + if err == nil && setting != nil { + return setting, nil + } + + _, loadErr := LoadProfile() + if loadErr != nil { + return defaultSettings(), loadErr + } + + setting, err = GetSetting() if err != nil { return defaultSettings(), err } - return settings, nil + if setting == nil { + return defaultSettings(), nil + } + return setting, nil } func defaultSettings() *Settings { @@ -68,16 +86,67 @@ func defaultSettings() *Settings { } } +// setConfigs sets multiple config key-value pairs, returning the first error encountered. +func setConfigs(kvs [][2]interface{}) error { + for _, kv := range kvs { + if err := config.Set(kv[0].(string), kv[1]); err != nil { + return err + } + } + return nil +} + // SaveSettings - Save the current settings to disk func SaveSettings(settings *Settings) error { - rootDir, _ := filepath.Abs(GetRootAppDir()) if settings == nil { settings = defaultSettings() } - data, err := json.MarshalIndent(settings, "", " ") - if err != nil { + + // Ensure profile is loaded so we don't overwrite unrelated config sections. + if _, err := LoadProfile(); err != nil { + return err + } + + // Top-level settings + if err := setConfigs([][2]interface{}{ + {"settings.max_server_log_size", settings.MaxServerLogSize}, + {"settings.opsec_threshold", settings.OpsecThreshold}, + {"settings.mcp_enable", settings.McpEnable}, + {"settings.mcp_addr", settings.McpAddr}, + {"settings.localrpc_enable", settings.LocalRPCEnable}, + {"settings.localrpc_addr", settings.LocalRPCAddr}, + }); err != nil { return err } - err = ioutil.WriteFile(filepath.Join(rootDir, maliceProfile), data, 0600) - return err + + // Github settings + if settings.Github != nil { + if err := setConfigs([][2]interface{}{ + {"settings.github.repo", settings.Github.Repo}, + {"settings.github.owner", settings.Github.Owner}, + {"settings.github.token", settings.Github.Token}, + {"settings.github.workflow", settings.Github.Workflow}, + }); err != nil { + return err + } + } + + // AI settings + if settings.AI != nil { + if err := setConfigs([][2]interface{}{ + {"settings.ai.enable", settings.AI.Enable}, + {"settings.ai.provider", settings.AI.Provider}, + {"settings.ai.api_key", settings.AI.APIKey}, + {"settings.ai.endpoint", settings.AI.Endpoint}, + {"settings.ai.model", settings.AI.Model}, + {"settings.ai.max_tokens", settings.AI.MaxTokens}, + {"settings.ai.timeout", settings.AI.Timeout}, + {"settings.ai.history_size", settings.AI.HistorySize}, + {"settings.ai.opsec_check", settings.AI.OpsecCheck}, + }); err != nil { + return err + } + } + + return nil } diff --git a/client/cmd/cli/root.go b/client/cmd/cli/root.go index 86e34aec..21af4d8a 100644 --- a/client/cmd/cli/root.go +++ b/client/cmd/cli/root.go @@ -20,14 +20,18 @@ func rootCmd(con *core.Console) (*cobra.Command, error) { } cmd.TraverseChildren = true - // 添加 --mcp flag + // Add --mcp flag cmd.PersistentFlags().String("mcp", "", "enable MCP server with address (e.g., 127.0.0.1:5005)") - // 添加 --rpc flag + // Add --rpc flag cmd.PersistentFlags().String("rpc", "", "enable local gRPC server with address (e.g., 127.0.0.1:15004)") + // Add global --wizard flag + command.RegisterWizardFlag(cmd) bind := command.MakeBind(cmd, con, "golang") command.BindCommonCommands(bind) - cmd.PersistentPreRunE, cmd.PersistentPostRunE = command.ConsoleRunnerCmd(con, cmd) + // Wrap PersistentPreRunE to support wizard mode + originalPre, originalPost := command.ConsoleRunnerCmd(con, cmd) + cmd.PersistentPreRunE, cmd.PersistentPostRunE = command.WrapWithWizardSupport(con, originalPre, originalPost) cmd.AddCommand(command.ImplantCmd(con)) carapace.Gen(cmd) diff --git a/client/cmd/genhelp/gen_help.go b/client/cmd/genhelp/gen_help.go index 94ea620d..cd3f3215 100644 --- a/client/cmd/genhelp/gen_help.go +++ b/client/cmd/genhelp/gen_help.go @@ -1,7 +1,13 @@ package main import ( + "bytes" "fmt" + "io" + "os" + "sort" + "strings" + "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/IoM-go/proto/services/clientrpc" "github.com/chainreactors/malice-network/client/assets" @@ -42,8 +48,6 @@ import ( "github.com/gookit/config/v2" "github.com/gookit/config/v2/yaml" "github.com/spf13/cobra" - "io" - "os" ) func init() { diff --git a/client/command/ai/analyze.go b/client/command/ai/analyze.go new file mode 100644 index 00000000..bcbb8570 --- /dev/null +++ b/client/command/ai/analyze.go @@ -0,0 +1,135 @@ +package ai + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/chainreactors/malice-network/client/assets" + "github.com/chainreactors/malice-network/client/core" + "github.com/spf13/cobra" +) + +// AnalyzeCmd handles the analyze command - analyzes errors and provides suggestions +func AnalyzeCmd(cmd *cobra.Command, con *core.Console, args []string) error { + settings, err := assets.LoadSettings() + if err != nil { + return fmt.Errorf("failed to load settings: %w", err) + } + + if settings.AI == nil || !settings.AI.Enable { + return fmt.Errorf("AI is not enabled. Use 'ai-config --enable' to enable it") + } + + if settings.AI.APIKey == "" { + return fmt.Errorf("AI API key is not configured. Use 'ai-config --api-key ' to set it") + } + + // Get the error to analyze + var errorText string + if len(args) > 0 { + errorText = strings.Join(args, " ") + } + + if errorText == "" { + return fmt.Errorf("please provide an error message to analyze. Usage: analyze ") + } + + // Get context + historySize := settings.AI.HistorySize + if historySize <= 0 { + historySize = 20 + } + history := con.GetRecentHistory(historySize) + + // Build session context if available + sessionContext := buildSessionContext(con) + + // Build the analysis prompt + prompt := buildAnalysisPrompt(errorText, history, sessionContext) + + aiClient := core.NewAIClient(settings.AI) + + timeout := settings.AI.Timeout + if timeout <= 0 { + timeout = 30 + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + + fmt.Println("\nAnalyzing error...") + fmt.Println() + + // Use streaming for real-time output + response, err := aiClient.AskStream(ctx, prompt, nil, func(chunk string) { + fmt.Print(chunk) + }) + if err != nil { + return fmt.Errorf("AI analysis failed: %w", err) + } + + fmt.Println() + + // Parse command suggestions + commands := core.ParseCommandSuggestions(response) + if len(commands) > 0 { + fmt.Println("\nSuggested commands:") + for i, cmd := range commands { + fmt.Printf(" [%d] %s\n", i+1, cmd.Command) + } + } + + fmt.Println() + return nil +} + +func buildSessionContext(con *core.Console) string { + var sb strings.Builder + + session := con.GetInteractive() + if session != nil { + sb.WriteString(fmt.Sprintf("Current session: %s\n", session.SessionId)) + if session.Os != nil { + sb.WriteString(fmt.Sprintf("OS: %s %s\n", session.Os.Name, session.Os.Arch)) + } + if session.Process != nil { + sb.WriteString(fmt.Sprintf("Process: %s (PID: %d)\n", session.Process.Name, session.Process.Pid)) + sb.WriteString(fmt.Sprintf("User: %s\n", session.Process.Owner)) + } + } else { + sb.WriteString("No active session\n") + } + + return sb.String() +} + +func buildAnalysisPrompt(errorText string, history []string, sessionContext string) string { + var sb strings.Builder + + sb.WriteString("Analyze the following error and provide:\n") + sb.WriteString("1. Possible causes of the error\n") + sb.WriteString("2. Suggested solutions or workarounds\n") + sb.WriteString("3. Alternative commands that might work\n\n") + + sb.WriteString("Error message:\n") + sb.WriteString(errorText) + sb.WriteString("\n\n") + + if sessionContext != "" { + sb.WriteString("Session context:\n") + sb.WriteString(sessionContext) + sb.WriteString("\n") + } + + if len(history) > 0 { + sb.WriteString("Recent command history:\n") + for _, cmd := range history { + sb.WriteString(fmt.Sprintf("- %s\n", cmd)) + } + } + + sb.WriteString("\nProvide a concise analysis. Wrap any command suggestions in backticks like `command`.") + + return sb.String() +} diff --git a/client/command/ai/ask.go b/client/command/ai/ask.go new file mode 100644 index 00000000..e2188948 --- /dev/null +++ b/client/command/ai/ask.go @@ -0,0 +1,80 @@ +package ai + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/chainreactors/malice-network/client/assets" + "github.com/chainreactors/malice-network/client/core" + "github.com/spf13/cobra" +) + +// AskCmd handles the ask command +func AskCmd(cmd *cobra.Command, con *core.Console, args []string) error { + question := strings.Join(args, " ") + if question == "" { + return fmt.Errorf("please provide a question") + } + + // Load settings + settings, err := assets.LoadSettings() + if err != nil { + return fmt.Errorf("failed to load settings: %w", err) + } + + if settings.AI == nil || !settings.AI.Enable { + return fmt.Errorf("AI is not enabled. Use 'ai-config --enable --api-key ' to enable it") + } + + if settings.AI.APIKey == "" { + return fmt.Errorf("AI API key is not configured. Use 'ai-config --api-key ' to set it") + } + + // Get history settings + historySize, _ := cmd.Flags().GetInt("history") + noHistory, _ := cmd.Flags().GetBool("no-history") + + var history []string + if !noHistory { + history = con.GetRecentHistory(historySize) + } + + // Create AI client + aiClient := core.NewAIClient(settings.AI) + + // Create context with timeout + timeout := settings.AI.Timeout + if timeout <= 0 { + timeout = 30 + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + + fmt.Println("Thinking...") + + // Ask the AI + response, err := aiClient.Ask(ctx, question, history) + if err != nil { + return fmt.Errorf("AI error: %w", err) + } + + // Parse command suggestions + commands := core.ParseCommandSuggestions(response) + + // Display response + fmt.Printf("\n%s\n", response) + + // If there are command suggestions, list them + if len(commands) > 0 { + fmt.Println("\nSuggested commands:") + for i, cmd := range commands { + fmt.Printf(" [%d] %s\n", i+1, cmd.Command) + } + } + + fmt.Println() + + return nil +} diff --git a/client/command/ai/commands.go b/client/command/ai/commands.go new file mode 100644 index 00000000..c2f172e8 --- /dev/null +++ b/client/command/ai/commands.go @@ -0,0 +1,115 @@ +package ai + +import ( + "github.com/carapace-sh/carapace" + "github.com/chainreactors/malice-network/client/command/common" + "github.com/chainreactors/malice-network/client/core" + "github.com/spf13/cobra" +) + +// Commands returns all AI-related commands +func Commands(con *core.Console) []*cobra.Command { + aiConfigCmd := &cobra.Command{ + Use: "ai-config", + Short: "Configure AI assistant settings", + Long: "Configure the AI assistant with your preferred provider (OpenAI or Claude), API key, model, and other settings.", + RunE: func(cmd *cobra.Command, args []string) error { + return AIConfigCmd(cmd, con) + }, + Annotations: map[string]string{ + "static": "true", + }, + Example: `~~~ +// Enable AI with OpenAI +ai-config --enable --provider openai --api-key "sk-xxx" --model gpt-4 + +// Enable AI with Claude +ai-config --enable --provider claude --api-key "sk-ant-xxx" --endpoint "https://api.anthropic.com/v1" --model claude-3-opus-20240229 + +// Show current configuration +ai-config --show + +// Disable AI +ai-config --disable +~~~`, + } + + aiConfigCmd.Flags().Bool("enable", false, "Enable AI assistant") + aiConfigCmd.Flags().Bool("disable", false, "Disable AI assistant") + aiConfigCmd.Flags().Bool("show", false, "Show current AI configuration") + aiConfigCmd.Flags().String("provider", "", "AI provider: openai or claude") + aiConfigCmd.Flags().String("api-key", "", "API key for the AI provider") + aiConfigCmd.Flags().String("endpoint", "", "API endpoint URL") + aiConfigCmd.Flags().String("model", "", "Model name (e.g., gpt-4, claude-3-opus-20240229)") + aiConfigCmd.Flags().Int("max-tokens", 0, "Maximum tokens in response") + aiConfigCmd.Flags().Int("timeout", 0, "Request timeout in seconds") + aiConfigCmd.Flags().Int("history-size", 0, "Number of history lines to include as context") + aiConfigCmd.Flags().Bool("opsec-check", false, "Enable AI OPSEC risk assessment for high-risk commands") + + askCmd := &cobra.Command{ + Use: "ask [question]", + Short: "Ask the AI assistant a question", + Long: "Ask the AI assistant a question with command history context. This is equivalent to using '? ' syntax.", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return AskCmd(cmd, con, args) + }, + Annotations: map[string]string{ + "static": "true", + }, + Example: `~~~ +// Ask about commands +ask how do I list all sessions + +// Ask about current target +ask what commands can I run on this target + +// Ask with no history context +ask --no-history how to download a file +~~~`, + } + + askCmd.Flags().Int("history", 20, "Number of history lines to include as context") + askCmd.Flags().Bool("no-history", false, "Don't include command history in context") + + questionCmd := &cobra.Command{ + Use: "? [question]", + Short: "Ask the AI assistant (shortcut)", + Long: "Ask the AI assistant a question. This is equivalent to using '? ' syntax or the 'ask' command.", + Args: cobra.MinimumNArgs(1), + Hidden: true, + RunE: func(cmd *cobra.Command, args []string) error { + return AskCmd(cmd, con, args) + }, + Annotations: map[string]string{ + "static": "true", + }, + } + + carapace.Gen(questionCmd).PositionalAnyCompletion(common.AIQuestionCompleter(con)) + + analyzeCmd := &cobra.Command{ + Use: "analyze [error message]", + Short: "AI-powered error analysis and suggestions", + Long: "Analyze an error message using AI and get suggestions for resolution, including possible causes and alternative commands.", + Args: cobra.MinimumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + return AnalyzeCmd(cmd, con, args) + }, + Annotations: map[string]string{ + "static": "true", + }, + Example: `~~~ +// Analyze an error message +analyze Access denied when trying to read file + +// Analyze with more context +analyze "Error: permission denied for /etc/shadow" + +// Analyze a command failure +analyze "getsystem failed: UAC is enabled" +~~~`, + } + + return []*cobra.Command{aiConfigCmd, askCmd, questionCmd, analyzeCmd} +} diff --git a/client/command/ai/config.go b/client/command/ai/config.go new file mode 100644 index 00000000..ba39ea85 --- /dev/null +++ b/client/command/ai/config.go @@ -0,0 +1,160 @@ +package ai + +import ( + "fmt" + "strings" + + "github.com/chainreactors/malice-network/client/assets" + "github.com/chainreactors/malice-network/client/core" + "github.com/spf13/cobra" +) + +// AIConfigCmd handles the ai-config command +func AIConfigCmd(cmd *cobra.Command, con *core.Console) error { + showConfig, _ := cmd.Flags().GetBool("show") + enableAI, _ := cmd.Flags().GetBool("enable") + disableAI, _ := cmd.Flags().GetBool("disable") + + // Load current settings + settings, err := assets.LoadSettings() + if err != nil { + return fmt.Errorf("failed to load settings: %w", err) + } + + // Initialize AI settings if nil + if settings.AI == nil { + settings.AI = &assets.AISettings{ + Enable: false, + Provider: "openai", + Endpoint: "https://api.openai.com/v1", + Model: "gpt-4", + MaxTokens: 1024, + Timeout: 30, + HistorySize: 20, + } + } + + // Show current config + if showConfig { + printAIConfig(settings.AI) + return nil + } + + // If no flags provided, show help + if !enableAI && !disableAI && !cmd.Flags().Changed("provider") && + !cmd.Flags().Changed("api-key") && !cmd.Flags().Changed("endpoint") && + !cmd.Flags().Changed("model") && !cmd.Flags().Changed("max-tokens") && + !cmd.Flags().Changed("timeout") && !cmd.Flags().Changed("history-size") { + printAIConfig(settings.AI) + fmt.Println("\nUse --help to see available options") + return nil + } + + // Update settings based on flags + if enableAI { + settings.AI.Enable = true + } + if disableAI { + settings.AI.Enable = false + } + + if provider, _ := cmd.Flags().GetString("provider"); provider != "" { + provider = strings.ToLower(provider) + if provider == "anthropic" { + provider = "claude" + } + if provider != "openai" && provider != "claude" { + return fmt.Errorf("invalid provider: %s. Must be 'openai' or 'claude'", provider) + } + settings.AI.Provider = provider + + // Set default endpoint based on provider + if !cmd.Flags().Changed("endpoint") { + if provider == "claude" { + settings.AI.Endpoint = "https://api.anthropic.com/v1" + } else { + settings.AI.Endpoint = "https://api.openai.com/v1" + } + } + } + + if apiKey, _ := cmd.Flags().GetString("api-key"); apiKey != "" { + settings.AI.APIKey = apiKey + } + + if endpoint, _ := cmd.Flags().GetString("endpoint"); endpoint != "" { + settings.AI.Endpoint = endpoint + } + + if model, _ := cmd.Flags().GetString("model"); model != "" { + settings.AI.Model = model + } + + if maxTokens, _ := cmd.Flags().GetInt("max-tokens"); maxTokens > 0 { + settings.AI.MaxTokens = maxTokens + } + + if timeout, _ := cmd.Flags().GetInt("timeout"); timeout > 0 { + settings.AI.Timeout = timeout + } + + if historySize, _ := cmd.Flags().GetInt("history-size"); historySize > 0 { + settings.AI.HistorySize = historySize + } + + if cmd.Flags().Changed("opsec-check") { + opsecCheck, _ := cmd.Flags().GetBool("opsec-check") + settings.AI.OpsecCheck = opsecCheck + } + + // Validate configuration if enabling + if settings.AI.Enable && settings.AI.APIKey == "" { + fmt.Println("Warning: AI is enabled but API key is not set. Use --api-key to set it.") + } + + // Save settings + if err := assets.SaveSettings(settings); err != nil { + return fmt.Errorf("failed to save settings: %w", err) + } + + fmt.Println("AI configuration updated successfully") + printAIConfig(settings.AI) + + return nil +} + +func printAIConfig(ai *assets.AISettings) { + fmt.Println("\nAI Configuration:") + fmt.Println("─────────────────────────────────────") + + enabledStr := "No" + if ai.Enable { + enabledStr = "Yes" + } + fmt.Printf(" Enabled: %s\n", enabledStr) + fmt.Printf(" Provider: %s\n", ai.Provider) + fmt.Printf(" Endpoint: %s\n", ai.Endpoint) + fmt.Printf(" Model: %s\n", ai.Model) + + // Mask API key + apiKeyDisplay := "(not set)" + if ai.APIKey != "" { + if len(ai.APIKey) > 8 { + apiKeyDisplay = ai.APIKey[:4] + "..." + ai.APIKey[len(ai.APIKey)-4:] + } else { + apiKeyDisplay = "****" + } + } + fmt.Printf(" API Key: %s\n", apiKeyDisplay) + + fmt.Printf(" Max Tokens: %d\n", ai.MaxTokens) + fmt.Printf(" Timeout: %ds\n", ai.Timeout) + fmt.Printf(" History Size: %d lines\n", ai.HistorySize) + + opsecCheckStr := "No" + if ai.OpsecCheck { + opsecCheckStr = "Yes" + } + fmt.Printf(" OPSEC Check: %s\n", opsecCheckStr) + fmt.Println() +} diff --git a/client/command/build/build-beacon.go b/client/command/build/build-beacon.go index 0602a55b..39104648 100644 --- a/client/command/build/build-beacon.go +++ b/client/command/build/build-beacon.go @@ -95,8 +95,7 @@ func BeaconCmd(cmd *cobra.Command, con *core.Console) error { if err != nil { return err } - executeBuild(con, buildConfig) - return nil + return executeBuild(con, buildConfig) } // prepareBuildConfig 准备标准构建配置 diff --git a/client/command/build/build-module.go b/client/command/build/build-module.go index 458f7e53..3ef9ee8c 100644 --- a/client/command/build/build-module.go +++ b/client/command/build/build-module.go @@ -49,8 +49,10 @@ func ModulesCmd(cmd *cobra.Command, con *core.Console) error { } else { mainProfile.Implant.Modules = strings.Split(modules, ",") } - buildConfig.MaleficConfig, _ = mainProfile.ToYAML() + buildConfig.MaleficConfig, err = mainProfile.ToYAML() + if err != nil { + return err + } - executeBuild(con, buildConfig) - return nil + return executeBuild(con, buildConfig) } diff --git a/client/command/build/build-prelude.go b/client/command/build/build-prelude.go index aead3ae4..072e4734 100644 --- a/client/command/build/build-prelude.go +++ b/client/command/build/build-prelude.go @@ -42,6 +42,5 @@ func PreludeCmd(cmd *cobra.Command, con *core.Console) error { return err } - executeBuild(con, buildConfig) - return nil + return executeBuild(con, buildConfig) } diff --git a/client/command/build/build-pulse.go b/client/command/build/build-pulse.go index 3d9c5d40..5f294a0f 100644 --- a/client/command/build/build-pulse.go +++ b/client/command/build/build-pulse.go @@ -42,9 +42,11 @@ func PulseCmd(cmd *cobra.Command, con *core.Console) error { return fmt.Errorf("failed to parse pulse's build flags: %w", err) } buildConfig.MaleficConfig, err = profile.ToYAML() + if err != nil { + return fmt.Errorf("failed to encode profile: %w", err) + } - executeBuild(con, buildConfig) - return nil + return executeBuild(con, buildConfig) } func parsePulseBuildFlags(cmd *cobra.Command) (*implanttypes.ProfileConfig, error) { diff --git a/client/command/build/build.go b/client/command/build/build.go index 83d572a3..7b8cd8a7 100644 --- a/client/command/build/build.go +++ b/client/command/build/build.go @@ -2,6 +2,7 @@ package build import ( "errors" + "fmt" "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/IoM-go/proto/client/clientpb" "github.com/chainreactors/malice-network/client/command/common" @@ -60,16 +61,14 @@ func parseSourceConfig(cmd *cobra.Command, con *core.Console, buildConfig *clien } // executeBuild 执行构建逻辑 -func executeBuild(con *core.Console, buildConfig *clientpb.BuildConfig) { - go func() { - artifact, err := con.Rpc.Build(con.Context(), buildConfig) - if err != nil { - con.Log.Errorf("Build %s failed: %v\n", buildConfig.BuildType, err) - return - } - con.Log.Infof("Build started: %s (type: %s, target: %s, source: %s)\n", - artifact.Name, artifact.Type, artifact.Target, artifact.Source) - }() +func executeBuild(con *core.Console, buildConfig *clientpb.BuildConfig) error { + artifact, err := con.Rpc.Build(con.Context(), buildConfig) + if err != nil { + return fmt.Errorf("build %s failed: %w", buildConfig.BuildType, err) + } + con.Log.Infof("Build started: %s (type: %s, target: %s, source: %s)\n", + artifact.Name, artifact.Type, artifact.Target, artifact.Source) + return nil } func BindCmd(cmd *cobra.Command, con *core.Console) error { @@ -78,8 +77,7 @@ func BindCmd(cmd *cobra.Command, con *core.Console) error { return err } - executeBuild(con, buildConfig) - return nil + return executeBuild(con, buildConfig) } // parseLibFlag sets buildConfig.Lib based on the --lib flag and validates compatibility with buildType/target. diff --git a/client/command/cert/commands.go b/client/command/cert/commands.go index a8beee16..63eb786b 100644 --- a/client/command/cert/commands.go +++ b/client/command/cert/commands.go @@ -34,6 +34,8 @@ cert import --cert cert_file_path --key key_file_path --ca-cert ca_cert_path } common.BindFlag(importCmd, common.ImportSet) + _ = importCmd.MarkFlagRequired("cert") + _ = importCmd.MarkFlagRequired("key") common.BindFlagCompletions(importCmd, func(comp carapace.ActionMap) { comp["cert"] = carapace.ActionFiles().Usage("path to the cert file") comp["key"] = carapace.ActionFiles().Usage("path to the key file") diff --git a/client/command/client.go b/client/command/client.go index 10aad0ab..a80d059c 100644 --- a/client/command/client.go +++ b/client/command/client.go @@ -1,7 +1,9 @@ package command import ( + "github.com/carapace-sh/carapace" "github.com/chainreactors/IoM-go/consts" + "github.com/chainreactors/malice-network/client/command/ai" "github.com/chainreactors/malice-network/client/command/audit" "github.com/chainreactors/malice-network/client/core" "github.com/reeflective/console" @@ -24,11 +26,14 @@ import ( "github.com/chainreactors/malice-network/client/command/pipeline" "github.com/chainreactors/malice-network/client/command/sessions" "github.com/chainreactors/malice-network/client/command/website" + "github.com/chainreactors/malice-network/client/command/wizard" ) func BindCommonCommands(bind BindFunc) { bind(consts.GenericGroup, - generic.Commands) + generic.Commands, + ai.Commands, + wizard.Commands) bind(consts.ManageGroup, sessions.Commands, @@ -89,6 +94,13 @@ func BindClientsCommands(con *core.Console) console.Commands { }, } + // Register global --wizard flag + RegisterWizardFlag(client) + // Wrap PersistentPreRunE to support wizard mode + client.PersistentPreRunE = func(cmd *cobra.Command, args []string) error { + return HandleWizardFlag(cmd, con) + } + bind := MakeBind(client, con, "golang") BindCommonCommands(bind) @@ -99,6 +111,9 @@ func BindClientsCommands(con *core.Console) console.Commands { client.SetHelpFunc(help.HelpFunc) client.SetHelpCommandGroupID(consts.GenericGroup) + // Register carapace completion for root command (make PersistentFlags visible in subcommands) + carapace.Gen(client) + RegisterClientFunc(con) RegisterImplantFunc(con) return client diff --git a/client/command/common/ai_completer.go b/client/command/common/ai_completer.go new file mode 100644 index 00000000..82d77f99 --- /dev/null +++ b/client/command/common/ai_completer.go @@ -0,0 +1,119 @@ +package common + +import ( + "context" + "strings" + "time" + + "github.com/carapace-sh/carapace" + "github.com/chainreactors/malice-network/client/assets" + "github.com/chainreactors/malice-network/client/core" +) + +// AIQuestionCompleter provides AI-powered completion for questions starting with '?' +// When users type '? ' and press Tab, this completer calls the AI +// and returns suggestions based on the AI's response. +func AIQuestionCompleter(con *core.Console) carapace.Action { + return carapace.ActionCallback(func(c carapace.Context) carapace.Action { + // Build the question from args and current value (works for '?' command or '? ' style). + parts := make([]string, 0, len(c.Args)+1) + parts = append(parts, c.Args...) + if c.Value != "" { + parts = append(parts, c.Value) + } + question := strings.TrimSpace(strings.Join(parts, " ")) + question = strings.TrimSpace(strings.TrimPrefix(question, "?")) + + // If no question yet, show hint + if question == "" { + return carapace.ActionMessage("Type your question after '?', then press Tab for AI suggestions") + } + + // If question is too short, don't call AI + if len(question) < 3 { + return carapace.ActionMessage("Enter a longer question for AI suggestions") + } + + // Load settings + settings, err := assets.GetSetting() + if err != nil || settings == nil || settings.AI == nil || !settings.AI.Enable { + return carapace.ActionMessage("AI not enabled. Use 'ai-config --enable --api-key ' to enable") + } + + if settings.AI.APIKey == "" { + return carapace.ActionMessage("AI API key not configured. Use 'ai-config --api-key '") + } + + // Get command history for context + historySize := 20 + if settings.AI.HistorySize > 0 { + historySize = settings.AI.HistorySize + } + history := con.GetRecentHistory(historySize) + + // Create AI client with a shorter timeout for completion + aiClient := core.NewAIClient(settings.AI) + timeout := 15 // Shorter timeout for completion + if settings.AI.Timeout > 0 && settings.AI.Timeout < timeout { + timeout = settings.AI.Timeout + } + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(timeout)*time.Second) + defer cancel() + + // Ask AI + response, err := aiClient.Ask(ctx, question, history) + if err != nil { + return carapace.ActionMessage("AI Error: " + err.Error()) + } + + // Parse command suggestions from response + commands := core.ParseCommandSuggestions(response) + + if len(commands) == 0 { + // No specific commands found, show a truncated response + lines := strings.Split(response, "\n") + results := make([]string, 0) + shown := 0 + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + if shown >= 5 { // Limit to first 5 non-empty lines + break + } + shown++ + + // Truncate long lines + if len(line) > 80 { + line = line[:77] + "..." + } + results = append(results, line, "") + } + if len(results) > 0 { + return carapace.ActionValuesDescribed(results...).Tag("AI Response") + } + return carapace.ActionMessage("No suggestions from AI") + } + + // Build completion results from commands + results := make([]string, 0, len(commands)*2) + for _, cmd := range commands { + description := cmd.Description + if description == "" { + description = "AI suggested command" + } + results = append(results, cmd.Command, description) + } + + return carapace.ActionValuesDescribed(results...).Tag("AI Suggestions") + }) +} + +// RegisterAICompleter registers the AI question completer for the '?' prefix +// This should be called during command registration +func RegisterAICompleter(con *core.Console) { + // The AI completer is invoked through PreCmdRunLineHooks for '?' prefix + // Tab completion is handled by the carapace integration + // This function can be used for additional registration if needed +} diff --git a/client/command/common/flagset.go b/client/command/common/flagset.go index ba4e3599..1c7f28e9 100644 --- a/client/command/common/flagset.go +++ b/client/command/common/flagset.go @@ -1,6 +1,9 @@ package common import ( + "errors" + "strings" + "github.com/chainreactors/IoM-go/proto/client/clientpb" "github.com/chainreactors/IoM-go/proto/implant/implantpb" "github.com/chainreactors/malice-network/helper/cryptography" @@ -333,29 +336,35 @@ func ParseImportCertFlags(cmd *cobra.Command) (*clientpb.TLS, error) { keyPath, _ := cmd.Flags().GetString("key") caPath, _ := cmd.Flags().GetString("ca-cert") - var err error - var cert, key, ca string - if certPath != "" && keyPath != "" && caPath != "" { - cert, err = cryptography.ProcessPEM(certPath) - if err != nil { - return nil, err - } - key, err = cryptography.ProcessPEM(keyPath) - if err != nil { - return nil, err - } - ca, err = cryptography.ProcessPEM(caPath) - if err != nil { - return nil, err - } + certPath = strings.TrimSpace(certPath) + keyPath = strings.TrimSpace(keyPath) + caPath = strings.TrimSpace(caPath) + + if certPath == "" || keyPath == "" { + return nil, errors.New("cert and key are required") } - return &clientpb.TLS{ + + cert, err := cryptography.ProcessPEM(certPath) + if err != nil { + return nil, err + } + key, err := cryptography.ProcessPEM(keyPath) + if err != nil { + return nil, err + } + + tls := &clientpb.TLS{ Cert: &clientpb.Cert{ Cert: cert, Key: key, }, - Ca: &clientpb.Cert{ - Cert: ca, - }, - }, nil + } + if caPath != "" { + ca, err := cryptography.ProcessPEM(caPath) + if err != nil { + return nil, err + } + tls.Ca = &clientpb.Cert{Cert: ca} + } + return tls, nil } diff --git a/client/command/explorer/commands.go b/client/command/explorer/commands.go index 8e0be913..c3da49f1 100644 --- a/client/command/explorer/commands.go +++ b/client/command/explorer/commands.go @@ -9,8 +9,9 @@ import ( func Commands(con *core.Console) []*cobra.Command { regCommand := &cobra.Command{ - Use: consts.CommandRegExplorer, - Short: "registry explorer", + Use: consts.CommandRegExplorer + " [hive\\path]", + Short: "Interactive registry explorer", + Long: "Explore registry keys and values interactively from a starting hive/path (e.g., HKEY_LOCAL_MACHINE\\SOFTWARE).", Args: cobra.ExactArgs(1), RunE: func(cmd *cobra.Command, args []string) error { return regExplorerCmd(cmd, con) @@ -19,6 +20,10 @@ func Commands(con *core.Console) []*cobra.Command { "depend": consts.ModuleRegListKey, "thirdParty": "true", }, + Example: `~~~ +reg_explorer HKLM\\SOFTWARE +reg_explorer HKEY_CURRENT_USER\\Software +~~~`, } fileCmd := &cobra.Command{ diff --git a/client/command/generic/login.go b/client/command/generic/login.go index 08f18c59..7eab0154 100644 --- a/client/command/generic/login.go +++ b/client/command/generic/login.go @@ -3,6 +3,7 @@ package generic import ( "errors" "fmt" + "strings" "github.com/chainreactors/malice-network/client/assets" "github.com/chainreactors/malice-network/client/core" @@ -28,11 +29,19 @@ func LoginCmd(cmd *cobra.Command, con *core.Console) error { con.RPCAddr = rpcAddr } - if filename := cmd.Flags().Arg(0); filename != "" { - return Login(con, filename) - } else if filename, _ := cmd.Flags().GetString("auth"); filename != "" { + // Prefer explicit --auth flag to avoid misinterpreting subcommand arguments + // (e.g. `wizard build beacon`) as an auth file. + if filename, _ := cmd.Flags().GetString("auth"); filename != "" { return Login(con, filename) } + + // Only check Arg(0) as auth file for root command or login command + // Avoid treating subcommand arguments (e.g., 'beacon' in 'wizard build beacon') as auth file + if cmd.Parent() == nil || cmd.Use == "client" || cmd.Use == "login" { + if filename := cmd.Flags().Arg(0); strings.HasSuffix(filename, ".auth") { + return Login(con, filename) + } + } files, err := assets.GetConfigs() if err != nil { return fmt.Errorf("error retrieving YAML files: %w", err) diff --git a/client/command/privilege/commands.go b/client/command/privilege/commands.go index 6dce3149..76e3e518 100644 --- a/client/command/privilege/commands.go +++ b/client/command/privilege/commands.go @@ -1,6 +1,7 @@ package privilege import ( + "github.com/carapace-sh/carapace" "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/malice-network/client/command/common" "github.com/chainreactors/malice-network/client/core" @@ -10,7 +11,7 @@ import ( func Commands(con *core.Console) []*cobra.Command { runasCmd := &cobra.Command{ - Use: "runas --username [username] --domain [domain] --password [password] --program [program] --args [args] --use-profile --use-env --netonly", + Use: "runas --username [username] --domain [domain] --password [password] --path [path] --args [args] --use-profile --use-env --netonly", Short: "Run a program as another user", RunE: func(cmd *cobra.Command, args []string) error { return RunasCmd(cmd, con) @@ -21,7 +22,7 @@ func Commands(con *core.Console) []*cobra.Command { }, Example: `Run a program as a different user: ~~~ - sys runas --username admin --domain EXAMPLE --password admin123 --program /path/to/program --args "arg1 arg2" --use-profile --use-env + sys runas --username admin --domain EXAMPLE --password admin123 --path /path/to/program --args "arg1 arg2" --use-profile --use-env ~~~`, } @@ -35,6 +36,9 @@ func Commands(con *core.Console) []*cobra.Command { f.Bool("use-env", false, "Use user environment") f.Bool("netonly", false, "Use network credentials only") }) + common.BindFlagCompletions(runasCmd, func(comp carapace.ActionMap) { + comp["path"] = carapace.ActionFiles().Usage("path to the program to execute") + }) privsCmd := &cobra.Command{ Use: "privs", diff --git a/client/command/reg/commands.go b/client/command/reg/commands.go index 298bd4c3..153b903a 100644 --- a/client/command/reg/commands.go +++ b/client/command/reg/commands.go @@ -1,6 +1,7 @@ package reg import ( + "github.com/carapace-sh/carapace" "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/malice-network/client/core" "strings" @@ -74,6 +75,16 @@ func Commands(con *core.Console) []*cobra.Command { f.StringP("type", "t", "REG_SZ", "Value type (REG_SZ, REG_BINARY, REG_DWORD, REG_QWORD)") f.StringP("data", "d", "", "Data to set") }) + common.BindFlagCompletions(regAddCmd, func(comp carapace.ActionMap) { + comp["type"] = carapace.ActionValuesDescribed( + "REG_SZ", "String", + "REG_EXPAND_SZ", "Expandable string", + "REG_MULTI_SZ", "Multi-string", + "REG_BINARY", "Binary data", + "REG_DWORD", "32-bit number", + "REG_QWORD", "64-bit number", + ).Tag("registry value type") + }) regDeleteCmd := &cobra.Command{ Use: consts.SubCommandName(consts.ModuleRegDelete) + " --hive [hive] --path [path] --key [key]", diff --git a/client/command/sys/ps.go b/client/command/sys/ps.go index ef2e47d8..0c4e110a 100644 --- a/client/command/sys/ps.go +++ b/client/command/sys/ps.go @@ -48,7 +48,7 @@ func RegisterPsFunc(con *core.Console) { psSet := ctx.Spite.GetPsResponse() var ps []string for _, p := range psSet.GetProcesses() { - ps = append(ps, fmt.Sprintf("%s:%d:%d:%s:%s:%s:%s:%s", + ps = append(ps, fmt.Sprintf("%s:%d:%d:%s:%s:%s:%s", p.Name, p.Pid, p.Ppid, diff --git a/client/command/wizard/commands.go b/client/command/wizard/commands.go new file mode 100644 index 00000000..9900b0ff --- /dev/null +++ b/client/command/wizard/commands.go @@ -0,0 +1,329 @@ +package wizard + +import ( + "fmt" + "sort" + + "github.com/carapace-sh/carapace" + "github.com/chainreactors/malice-network/client/command/common" + "github.com/chainreactors/malice-network/client/core" + "github.com/chainreactors/malice-network/client/plugin" + wizardfw "github.com/chainreactors/malice-network/client/wizard" + "github.com/spf13/cobra" +) + +// Commands returns the wizard commands +func Commands(con *core.Console) []*cobra.Command { + wizardCmd := &cobra.Command{ + Use: "wizard", + Short: "Interactive wizard system", + Long: "Run interactive wizards for configuration and setup", + } + + listCmd := &cobra.Command{ + Use: "list", + Aliases: []string{"ls"}, + Short: "List available wizards", + RunE: func(cmd *cobra.Command, args []string) error { + return ListWizardsCmd(cmd, con) + }, + Example: `~~~ +wizard list +~~~`, + } + + runCmd := &cobra.Command{ + Use: "run [wizard-name]", + Short: "Run a wizard by name or from a spec file", + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + specFile, _ := cmd.Flags().GetString("file") + if specFile != "" { + return RunWizardFileCmd(cmd, con, specFile) + } + if len(args) != 1 { + return fmt.Errorf("wizard name is required (or use --file)") + } + return RunWizardCmd(cmd, con, args[0]) + }, + Example: `~~~ +wizard run listener_setup +wizard run tcp_pipeline +wizard run profile_create +wizard run --file ./wizards/priv_esc.yaml +~~~`, + } + runCmd.Flags().StringP("file", "f", "", "run wizard from a JSON/YAML spec file (path or embed://...)") + + common.BindFlagCompletions(runCmd, func(comp carapace.ActionMap) { + comp["file"] = carapace.ActionFiles().Usage("wizard spec file (JSON/YAML)") + }) + common.BindArgCompletions(runCmd, nil, carapace.ActionCallback(func(c carapace.Context) carapace.Action { + _ = plugin.GetGlobalMalManager() + + templates := wizardfw.ListTemplates() + results := make([]string, 0, len(templates)*2) + for _, name := range templates { + desc := "" + if wiz, ok := wizardfw.GetTemplate(name); ok && wiz != nil { + desc = wiz.Description + } + results = append(results, name, desc) + } + return carapace.ActionValuesDescribed(results...).Tag("wizard template") + })) + + wizardCmd.AddCommand(listCmd, runCmd) + + // Add category commands (build, pipeline, cert, config) + for _, cat := range wizardfw.Categories { + catCmd := createCategoryCommand(con, cat) + wizardCmd.AddCommand(catCmd) + } + + // Add standalone wizard commands (listener, profile, infra) + for _, sw := range wizardfw.StandaloneWizards { + swCmd := createStandaloneCommand(con, sw) + wizardCmd.AddCommand(swCmd) + } + + return []*cobra.Command{wizardCmd} +} + +// createCategoryCommand creates a command for a wizard category +func createCategoryCommand(con *core.Console, cat wizardfw.WizardCategory) *cobra.Command { + cmd := &cobra.Command{ + Use: cat.Name + " [type]", + Short: cat.Description, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + _ = plugin.GetGlobalMalManager() + + var wizardID string + + if len(args) == 1 { + // Direct type specified: wizard build beacon + typeName := args[0] + for _, w := range cat.Wizards { + if w.ID == typeName { + wizardID = w.FullID + break + } + } + if wizardID == "" { + return fmt.Errorf("unknown %s type: %s", cat.Name, typeName) + } + } else { + // No type specified: show interactive menu + options := make([]wizardfw.SelectOption, len(cat.Wizards)) + for i, w := range cat.Wizards { + options[i] = wizardfw.SelectOption{ + Value: w.FullID, + Label: w.ID, + Description: w.Description, + } + } + + selected, err := wizardfw.RunSelect(fmt.Sprintf("Select %s type", cat.Title), options) + if err != nil { + return err + } + wizardID = selected + } + + wiz, ok := wizardfw.GetTemplate(wizardID) + if !ok { + return fmt.Errorf("wizard '%s' not found", wizardID) + } + + return runWizard(con, wiz) + }, + } + + if len(cat.Wizards) > 0 { + results := make([]string, 0, len(cat.Wizards)*2) + for _, w := range cat.Wizards { + results = append(results, w.ID, w.Description) + } + common.BindArgCompletions(cmd, nil, carapace.ActionValuesDescribed(results...).Tag(cat.Title+" wizard")) + } + + // Add valid types to help text + var types []string + for _, w := range cat.Wizards { + types = append(types, w.ID) + } + if len(types) > 0 { + cmd.Example = fmt.Sprintf("~~~\nwizard %s\nwizard %s %s\n~~~", cat.Name, cat.Name, types[0]) + } else { + cmd.Example = fmt.Sprintf("~~~\nwizard %s\n~~~", cat.Name) + } + + return cmd +} + +// createStandaloneCommand creates a command for a standalone wizard +func createStandaloneCommand(con *core.Console, sw wizardfw.WizardEntry) *cobra.Command { + return &cobra.Command{ + Use: sw.ID, + Short: sw.Description, + Args: cobra.NoArgs, + RunE: func(cmd *cobra.Command, args []string) error { + _ = plugin.GetGlobalMalManager() + + wiz, ok := wizardfw.GetTemplate(sw.FullID) + if !ok { + return fmt.Errorf("wizard '%s' not found", sw.FullID) + } + + return runWizard(con, wiz) + }, + Example: fmt.Sprintf("~~~\nwizard %s\n~~~", sw.ID), + } +} + +// ListWizardsCmd lists all available wizards +func ListWizardsCmd(cmd *cobra.Command, con *core.Console) error { + // Ensure mal plugins are loaded so any resources/wizards specs are registered. + _ = plugin.GetGlobalMalManager() + + con.Log.Infof("Available wizards:\n\n") + + // Show categories + for _, cat := range wizardfw.Categories { + con.Log.Infof(" %s (%s):\n", cat.Title, cat.Name) + for _, w := range cat.Wizards { + con.Log.Infof(" %-12s - %s\n", w.ID, w.Description) + } + con.Log.Infof("\n") + } + + // Show standalone wizards + con.Log.Infof(" Standalone:\n") + for _, sw := range wizardfw.StandaloneWizards { + con.Log.Infof(" %-12s - %s\n", sw.ID, sw.Description) + } + con.Log.Infof("\n") + + // Show plugin wizards (those not in categories) + templates := wizardfw.ListTemplates() + knownIDs := make(map[string]bool) + for _, cat := range wizardfw.Categories { + for _, w := range cat.Wizards { + knownIDs[w.FullID] = true + } + } + for _, sw := range wizardfw.StandaloneWizards { + knownIDs[sw.FullID] = true + } + + var pluginWizards []string + for _, name := range templates { + if !knownIDs[name] { + pluginWizards = append(pluginWizards, name) + } + } + + if len(pluginWizards) > 0 { + sort.Strings(pluginWizards) + con.Log.Infof(" Plugin wizards:\n") + for _, name := range pluginWizards { + wiz, _ := wizardfw.GetTemplate(name) + if wiz != nil { + con.Log.Infof(" %-20s - %s\n", name, wiz.Description) + } else { + con.Log.Infof(" %s\n", name) + } + } + con.Log.Infof("\n") + } + + con.Log.Infof("Usage:\n") + con.Log.Infof(" wizard - Select from category (e.g., wizard build)\n") + con.Log.Infof(" wizard - Run directly (e.g., wizard build beacon)\n") + con.Log.Infof(" wizard - Run standalone wizard (e.g., wizard listener)\n") + con.Log.Infof(" wizard run - Run by full name (e.g., wizard run build_beacon)\n") + + return nil +} + +// RunWizardCmd runs a specific wizard +func RunWizardCmd(cmd *cobra.Command, con *core.Console, name string) error { + // Ensure mal plugins are loaded so any resources/wizards specs are registered. + _ = plugin.GetGlobalMalManager() + + wiz, ok := wizardfw.GetTemplate(name) + if !ok { + return fmt.Errorf("wizard '%s' not found. Use 'wizard list' to see available wizards", name) + } + + return runWizard(con, wiz) +} + +func RunWizardFileCmd(cmd *cobra.Command, con *core.Console, path string) error { + wiz, err := wizardfw.NewWizardFromFile(path) + if err != nil { + return fmt.Errorf("failed to load wizard spec %q: %w", path, err) + } + + return runWizard(con, wiz) +} + +// setupDynamicProviders sets up OptionsProvider for known dynamic fields +func setupDynamicProviders(wiz *wizardfw.Wizard) { + for _, f := range wiz.Fields { + switch f.Name { + case "profile": + f.OptionsProvider = ProfileOptionsProvider() + case "listener_id": + f.OptionsProvider = ListenerOptionsProvider() + case "pipeline", "pipeline_id": + f.OptionsProvider = PipelineOptionsProvider() + case "addresses": + f.OptionsProvider = AddressOptionsProvider() + case "address": + if wiz.ID == "build_pulse" { + f.OptionsProvider = PulseAddressOptionsProvider() + } else { + f.OptionsProvider = AddressOptionsProvider() + } + case "beacon_artifact_id": + f.OptionsProvider = ArtifactOptionsProvider("beacon") + } + } +} + +func runWizard(con *core.Console, wiz *wizardfw.Wizard) error { + // Set dynamic options providers for known fields + setupDynamicProviders(wiz) + + // Prepare dynamic options before running + wiz.PrepareOptions(con) + + runner := wizardfw.NewRunner(wiz) + result, err := runner.RunTwoPhase() + if err != nil { + return fmt.Errorf("wizard failed: %w", err) + } + + con.Log.Infof("\nWizard completed successfully!\n") + con.Log.Infof("Results:\n") + values := result.ToMap() + for _, f := range wiz.Fields { + if v, ok := values[f.Name]; ok { + con.Log.Infof(" %-20s: %v\n", f.Name, v) + } + } + + // Check if there's an executor for this wizard + if executor, ok := GetExecutor(wiz.ID); ok { + con.Log.Infof("\nExecuting wizard actions...\n") + if err := executor(con, result); err != nil { + return fmt.Errorf("wizard execution failed: %w", err) + } + } else { + con.Log.Warnf("\nNo executor registered for wizard '%s'. Results are display-only.\n", wiz.ID) + } + + return nil +} diff --git a/client/command/wizard/executors.go b/client/command/wizard/executors.go new file mode 100644 index 00000000..7267c6d3 --- /dev/null +++ b/client/command/wizard/executors.go @@ -0,0 +1,1420 @@ +package wizard + +import ( + "fmt" + "net" + "net/url" + "os" + "strconv" + "strings" + "sync" + + "github.com/chainreactors/IoM-go/consts" + "github.com/chainreactors/IoM-go/proto/client/clientpb" + "github.com/chainreactors/malice-network/client/core" + wizardfw "github.com/chainreactors/malice-network/client/wizard" + "github.com/chainreactors/malice-network/helper/cryptography" + "github.com/chainreactors/malice-network/helper/implanttypes" + serverbuild "github.com/chainreactors/malice-network/server/build" + "github.com/corpix/uarand" +) + +// ExecutorFunc is a function that executes wizard results +type ExecutorFunc func(con *core.Console, result *wizardfw.WizardResult) error + +var ( + executors = make(map[string]ExecutorFunc) + executorsMu sync.RWMutex +) + +// RegisterExecutor registers an executor for a wizard template +func RegisterExecutor(templateID string, fn ExecutorFunc) { + executorsMu.Lock() + defer executorsMu.Unlock() + executors[templateID] = fn +} + +// GetExecutor returns the executor for a wizard template +func GetExecutor(templateID string) (ExecutorFunc, bool) { + executorsMu.RLock() + defer executorsMu.RUnlock() + fn, ok := executors[templateID] + return fn, ok +} + +// HasExecutor checks if an executor is registered +func HasExecutor(templateID string) bool { + executorsMu.RLock() + defer executorsMu.RUnlock() + _, ok := executors[templateID] + return ok +} + +func init() { + // Register all built-in executors + // Pipeline executors + RegisterExecutor("tcp_pipeline", executeTCPPipeline) + RegisterExecutor("http_pipeline", executeHTTPPipeline) + RegisterExecutor("bind_pipeline", executeBindPipeline) + RegisterExecutor("rem_pipeline", executeREMPipeline) + // Build executors + RegisterExecutor("build_beacon", executeBuildBeacon) + RegisterExecutor("build_pulse", executeBuildPulse) + RegisterExecutor("build_prelude", executeBuildPrelude) + RegisterExecutor("build_module", executeBuildModule) + // Profile executor + RegisterExecutor("profile_create", executeProfileCreate) + // Listener executor + RegisterExecutor("listener_setup", executeListenerSetup) + // Infrastructure executor (composite) + RegisterExecutor("infrastructure_setup", executeInfrastructureSetup) + // Certificate executors + RegisterExecutor("cert_generate", executeCertGenerate) + RegisterExecutor("cert_import", executeCertImport) + // Config executors + RegisterExecutor("github_config", executeGithubConfig) + RegisterExecutor("notify_config", executeNotifyConfig) +} + +// Helper functions for getting values from wizard result + +func derefString(v any) (string, bool) { + switch val := v.(type) { + case string: + return val, true + case *string: + if val != nil { + return *val, true + } + } + return "", false +} + +func getString(result *wizardfw.WizardResult, key string) string { + if v, ok := result.Values[key]; ok { + if s, ok := derefString(v); ok { + return s + } + } + return "" +} + +func getInt(result *wizardfw.WizardResult, key string) int { + v, ok := result.Values[key] + if !ok { + return 0 + } + switch val := v.(type) { + case int: + return val + case int64: + return int(val) + case float64: + return int(val) + default: + if s, ok := derefString(v); ok { + if i, err := strconv.Atoi(s); err == nil { + return i + } + } + } + return 0 +} + +func getUint32(result *wizardfw.WizardResult, key string) uint32 { + return uint32(getInt(result, key)) +} + +func getBool(result *wizardfw.WizardResult, key string) bool { + v, ok := result.Values[key] + if !ok { + return false + } + switch val := v.(type) { + case bool: + return val + case *bool: + return val != nil && *val + default: + if s, ok := derefString(v); ok { + return s == "true" || s == "yes" || s == "1" + } + } + return false +} + +func getFloat64(result *wizardfw.WizardResult, key string) float64 { + v, ok := result.Values[key] + if !ok { + return 0 + } + switch val := v.(type) { + case float64: + return val + case float32: + return float64(val) + case int: + return float64(val) + case int64: + return float64(val) + default: + if s, ok := derefString(v); ok { + if f, err := strconv.ParseFloat(s, 64); err == nil { + return f + } + } + } + return 0 +} + +func splitCommaSeparated(s string) []string { + if strings.TrimSpace(s) == "" { + return nil + } + parts := strings.Split(s, ",") + out := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part != "" { + out = append(out, part) + } + } + if len(out) == 0 { + return nil + } + return out +} + +func normalizeStringSlice(values []string) []string { + if len(values) == 0 { + return nil + } + out := make([]string, 0, len(values)) + for _, value := range values { + value = strings.TrimSpace(value) + if value != "" { + out = append(out, value) + } + } + if len(out) == 0 { + return nil + } + return out +} + +func getStringSlice(result *wizardfw.WizardResult, key string) []string { + v, ok := result.Values[key] + if !ok { + return nil + } + switch val := v.(type) { + case []string: + return normalizeStringSlice(val) + case *[]string: + if val != nil { + return normalizeStringSlice(*val) + } + case []interface{}: + out := make([]string, 0, len(val)) + for _, item := range val { + if s, ok := item.(string); ok { + s = strings.TrimSpace(s) + if s != "" { + out = append(out, s) + } + } + } + return normalizeStringSlice(out) + default: + if s, ok := derefString(v); ok && s != "" { + return splitCommaSeparated(s) + } + } + return nil +} + +// pipelineParams holds common parameters for pipeline creation +type pipelineParams struct { + name string + listenerID string + host string + port uint32 + tls *clientpb.TLS +} + +// checkPortAvailable checks if a TCP port is available for binding +func checkPortAvailable(host string, port uint32) error { + addr := fmt.Sprintf("%s:%d", host, port) + ln, err := net.Listen("tcp", addr) + if err != nil { + return fmt.Errorf("port %d is already in use on %s", port, host) + } + ln.Close() + return nil +} + +// extractPipelineParams extracts common pipeline parameters from wizard result +func extractPipelineParams(result *wizardfw.WizardResult, prefix string) (*pipelineParams, error) { + listenerID := getString(result, "listener_id") + if listenerID == "" { + return nil, fmt.Errorf("listener_id is required") + } + + host := getString(result, "host") + if host == "" { + host = "0.0.0.0" + } + + port := getUint32(result, "port") + if port == 0 { + port = uint32(cryptography.RandomInRange(10240, 65535)) + } + + name := getString(result, "name") + if name == "" { + name = fmt.Sprintf("%s_%s_%d", prefix, listenerID, port) + } + + var tls *clientpb.TLS + if getBool(result, "tls") { + tls = &clientpb.TLS{Enable: true} + } + + return &pipelineParams{ + name: name, + listenerID: listenerID, + host: host, + port: port, + tls: tls, + }, nil +} + +// executeTCPPipeline executes the TCP pipeline wizard +func executeTCPPipeline(con *core.Console, result *wizardfw.WizardResult) error { + p, err := extractPipelineParams(result, "tcp") + if err != nil { + return err + } + // Check if port is available before registration + if err := checkPortAvailable(p.host, p.port); err != nil { + return fmt.Errorf("cannot create TCP pipeline: %w", err) + } + pipeline := &clientpb.Pipeline{ + Tls: p.tls, + Name: p.name, + ListenerId: p.listenerID, + Parser: consts.ImplantMalefic, + Enable: false, + Body: &clientpb.Pipeline_Tcp{Tcp: &clientpb.TCPPipeline{Name: p.name, Host: p.host, Port: p.port}}, + } + return registerAndStartPipeline(con, pipeline, "TCP") +} + +// executeHTTPPipeline executes the HTTP pipeline wizard +func executeHTTPPipeline(con *core.Console, result *wizardfw.WizardResult) error { + p, err := extractPipelineParams(result, "http") + if err != nil { + return err + } + // Check if port is available before registration + if err := checkPortAvailable(p.host, p.port); err != nil { + return fmt.Errorf("cannot create HTTP pipeline: %w", err) + } + pipeline := &clientpb.Pipeline{ + Tls: p.tls, + Name: p.name, + ListenerId: p.listenerID, + Parser: consts.ImplantMalefic, + Enable: false, + Body: &clientpb.Pipeline_Http{Http: &clientpb.HTTPPipeline{Name: p.name, Host: p.host, Port: p.port}}, + } + return registerAndStartPipeline(con, pipeline, "HTTP") +} + +// registerAndStartPipeline handles the common register and start logic for all pipelines +func registerAndStartPipeline(con *core.Console, pipeline *clientpb.Pipeline, pipelineType string) error { + if _, err := con.Rpc.RegisterPipeline(con.Context(), pipeline); err != nil { + return fmt.Errorf("failed to register %s pipeline: %w", pipelineType, err) + } + con.Log.Importantf("%s Pipeline %s registered\n", pipelineType, pipeline.Name) + + if _, err := con.Rpc.StartPipeline(con.Context(), &clientpb.CtrlPipeline{ + Name: pipeline.Name, + ListenerId: pipeline.ListenerId, + Pipeline: pipeline, + }); err != nil { + return fmt.Errorf("failed to start %s pipeline: %w", pipelineType, err) + } + con.Log.Importantf("%s Pipeline %s started successfully\n", pipelineType, pipeline.Name) + return nil +} + +// executeBindPipeline executes the Bind pipeline wizard +func executeBindPipeline(con *core.Console, result *wizardfw.WizardResult) error { + listenerID := getString(result, "listener_id") + if listenerID == "" { + return fmt.Errorf("listener_id is required") + } + + name := fmt.Sprintf("bind_%s", listenerID) + pipeline := &clientpb.Pipeline{ + Name: name, + ListenerId: listenerID, + Parser: consts.ImplantMalefic, + Enable: false, + Body: &clientpb.Pipeline_Bind{Bind: &clientpb.BindPipeline{Name: name}}, + } + return registerAndStartPipeline(con, pipeline, "Bind") +} + +// executeREMPipeline executes the REM pipeline wizard +func executeREMPipeline(con *core.Console, result *wizardfw.WizardResult) error { + listenerID := getString(result, "listener_id") + if listenerID == "" { + return fmt.Errorf("listener_id is required") + } + + name := getString(result, "name") + if name == "" { + name = fmt.Sprintf("rem_%s", listenerID) + } + + console := getString(result, "console") + if console == "" { + console = "tcp://0.0.0.0:19966" + } + + pipeline := &clientpb.Pipeline{ + Name: name, + ListenerId: listenerID, + Parser: consts.ImplantMalefic, + Secure: &clientpb.Secure{Enable: getBool(result, "secure")}, + Enable: false, + Body: &clientpb.Pipeline_Rem{Rem: &clientpb.REM{Name: name, Console: console}}, + } + return registerAndStartPipeline(con, pipeline, "REM") +} + +// executeBuildBeacon executes the build beacon wizard +func executeBuildBeacon(con *core.Console, result *wizardfw.WizardResult) error { + return executeBuild(con, result, consts.CommandBuildBeacon) +} + +// executeBuildPulse executes the build pulse wizard +func executeBuildPulse(con *core.Console, result *wizardfw.WizardResult) error { + return executeBuild(con, result, consts.CommandBuildPulse) +} + +// executeBuildPrelude executes the build prelude wizard +func executeBuildPrelude(con *core.Console, result *wizardfw.WizardResult) error { + return executeBuild(con, result, consts.CommandBuildPrelude) +} + +// executeBuildModule executes the build module wizard +func executeBuildModule(con *core.Console, result *wizardfw.WizardResult) error { + modules := getStringSlice(result, "modules") + thirdModules := getStringSlice(result, "third_modules") + + if len(modules) > 0 && len(thirdModules) > 0 { + return fmt.Errorf("please choose either modules or third_modules, not both") + } + if len(thirdModules) > 0 { + return executeBuild(con, result, consts.CommandBuild3rdModules) + } + return executeBuild(con, result, consts.CommandBuildModules) +} + +// executeBuild is the common build execution logic +func executeBuild(con *core.Console, result *wizardfw.WizardResult, buildType string) error { + target := getString(result, "target") + if target == "" { + return fmt.Errorf("target is required") + } + + // Set source + source := getString(result, "source") + if source == "" { + source = consts.ArtifactFromDocker + } + + var buildConfig *clientpb.BuildConfig + if buildType == consts.CommandBuildPrelude { + autorunZipPath := getString(result, "autorun") + if autorunZipPath == "" { + return fmt.Errorf("autorun is required") + } + zipData, err := os.ReadFile(autorunZipPath) + if err != nil { + return fmt.Errorf("failed to read autorun zip: %w", err) + } + buildConfig, err = serverbuild.ProcessAutorunZipFromBytes(zipData) + if err != nil { + return fmt.Errorf("failed to process autorun zip: %w", err) + } + buildConfig.ProfileName = getString(result, "profile") + buildConfig.Target = target + buildConfig.BuildType = buildType + buildConfig.Lib = getBool(result, "lib") + buildConfig.Source = source + } else { + buildConfig = &clientpb.BuildConfig{ + ProfileName: getString(result, "profile"), + Target: target, + BuildType: buildType, + Lib: getBool(result, "lib"), + Source: source, + } + } + + // Build profile from wizard results if no implant.yaml provided by bundle (e.g., prelude builds without implant.yaml). + if len(buildConfig.MaleficConfig) == 0 { + profile, err := buildProfileFromWizard(con, result, buildType) + if err != nil { + return fmt.Errorf("failed to build profile: %w", err) + } + buildConfig.MaleficConfig, err = profile.ToYAML() + if err != nil { + return fmt.Errorf("failed to encode profile: %w", err) + } + } + + // Check source availability + resp, err := con.Rpc.CheckSource(con.Context(), buildConfig) + if err != nil { + return fmt.Errorf("failed to check source: %w", err) + } + buildConfig.Source = resp.Source + + // Handle artifact ID for pulse builds + if buildType == consts.CommandBuildPulse { + artifactID := getUint32(result, "beacon_artifact_id") + if artifactID > 0 { + buildConfig.ArtifactId = artifactID + } + } + + // Validate lib flag + if err := validateLibFlag(buildConfig); err != nil { + return err + } + + artifact, err := con.Rpc.Build(con.Context(), buildConfig) + if err != nil { + return fmt.Errorf("build %s failed: %w", buildConfig.BuildType, err) + } + con.Log.Infof("Build started: %s (type: %s, target: %s, source: %s)\n", + artifact.Name, artifact.Type, artifact.Target, artifact.Source) + return nil +} + +// buildProfileFromWizard creates a ProfileConfig from wizard results +func buildProfileFromWizard(con *core.Console, result *wizardfw.WizardResult, buildType string) (*implanttypes.ProfileConfig, error) { + profileName := getString(result, "profile") + var profile *implanttypes.ProfileConfig + var err error + + if profileName != "" { + profilePB, err := con.Rpc.GetProfileByName(con.Context(), &clientpb.Profile{Name: profileName}) + if err != nil { + return nil, fmt.Errorf("failed to get profile %q: %w", profileName, err) + } + profile, err = implanttypes.LoadProfile(profilePB.Content) + } else { + profile, err = implanttypes.LoadProfile(consts.DefaultProfile) + } + if err != nil { + return nil, fmt.Errorf("failed to load profile: %w", err) + } + + ensureProfileSections(profile) + + // Set implant mode + if buildType == consts.CommandBuildBeacon || buildType == consts.CommandBuildBind { + profile.Implant.Mod = buildType + } + + // Apply build-specific settings + if buildType == consts.CommandBuildPulse { + if err := applyPulseSettings(result, profile); err != nil { + return nil, err + } + return profile, nil + } + + // Apply basic settings + applyBasicSettings(result, profile) + + // Parse and apply targets + if addresses := getString(result, "addresses"); addresses != "" { + targets, err := parseTargets(addresses) + if err != nil { + return nil, err + } + profile.Basic.Targets = targets + } + + // Apply modules + applyModuleSettings(result, profile) + + // Apply build settings + applyBuildSettings(result, profile) + + return profile, nil +} + +func ensureProfileSections(profile *implanttypes.ProfileConfig) { + if profile.Basic == nil { + profile.Basic = &implanttypes.BasicProfile{} + } + if profile.Pulse == nil { + profile.Pulse = &implanttypes.PulseProfile{} + } + if profile.Implant == nil { + profile.Implant = &implanttypes.ImplantProfile{} + } + if profile.Build == nil { + profile.Build = &implanttypes.BuildProfile{} + } +} + +// applyBasicSettings applies basic profile settings from wizard result +func applyBasicSettings(result *wizardfw.WizardResult, profile *implanttypes.ProfileConfig) { + // String settings + if v := getString(result, "cron"); v != "" { + profile.Basic.Cron = v + } + if v := getString(result, "encryption"); v != "" { + profile.Basic.Encryption = v + } + if v := getString(result, "key"); v != "" { + profile.Basic.Key = v + } + + // Numeric settings + if v := getFloat64(result, "jitter"); v > 0 { + profile.Basic.Jitter = v + } + if v := getInt(result, "init_retry"); v > 0 { + profile.Basic.InitRetry = v + } + if v := getInt(result, "server_retry"); v > 0 { + profile.Basic.ServerRetry = v + } + if v := getInt(result, "global_retry"); v > 0 { + profile.Basic.GlobalRetry = v + } + + // Secure mode + if getBool(result, "secure") { + if profile.Basic.Secure == nil { + profile.Basic.Secure = &implanttypes.SecureProfile{} + } + profile.Basic.Secure.Enable = true + } + + // Proxy settings + proxy, proxyUseEnv := getString(result, "proxy"), getBool(result, "proxy_use_env") + if proxy != "" || proxyUseEnv { + profile.Basic.Proxy = &implanttypes.ProxyProfile{URL: proxy, UseEnvProxy: proxyUseEnv} + } + + // Guardrail settings + applyGuardrailSettings(result, profile) +} + +type pulseAddress struct { + protocol string + target string +} + +func parsePulseAddress(address string) (*pulseAddress, error) { + address = strings.TrimSpace(address) + if address == "" { + return nil, fmt.Errorf("address is required") + } + + if strings.Contains(address, "://") { + u, err := url.Parse(address) + if err != nil { + return nil, fmt.Errorf("invalid address %q: %w", address, err) + } + if u.User != nil || u.RawQuery != "" || u.Fragment != "" || (u.Path != "" && u.Path != "/") { + return nil, fmt.Errorf("invalid address %q: only scheme://host[:port] is supported", address) + } + + scheme := strings.ToLower(strings.TrimSpace(u.Scheme)) + switch scheme { + case "http", "tcp": + case "https": + return nil, fmt.Errorf("pulse build only supports http:// or tcp:// addresses") + default: + return nil, fmt.Errorf("unsupported address scheme %q", u.Scheme) + } + + if strings.Count(u.Host, ":") > 1 && !strings.HasPrefix(u.Host, "[") { + return nil, fmt.Errorf("invalid address %q: IPv6 hosts must be in brackets, e.g. http://[::1]:80", address) + } + + host := strings.TrimSpace(u.Hostname()) + if host == "" { + return nil, fmt.Errorf("invalid address %q: missing host", address) + } + port := strings.TrimSpace(u.Port()) + if port == "" { + if scheme == "tcp" { + port = "5001" + } else { + port = "80" + } + } + if err := validatePort(port); err != nil { + return nil, err + } + + target := net.JoinHostPort(host, port) + if scheme == "tcp" { + return &pulseAddress{protocol: consts.TCPPipeline, target: target}, nil + } + return &pulseAddress{protocol: consts.HTTPPipeline, target: target}, nil + } + + if strings.ContainsAny(address, "/?#") { + return nil, fmt.Errorf("invalid address %q: expected host[:port], http://host[:port], or tcp://host[:port]", address) + } + + target, err := normalizeHostPort(address, "80") + if err != nil { + return nil, err + } + return &pulseAddress{protocol: consts.HTTPPipeline, target: target}, nil +} + +// applyPulseSettings applies pulse profile settings from wizard result +func applyPulseSettings(result *wizardfw.WizardResult, profile *implanttypes.ProfileConfig) error { + if profile.Pulse == nil { + profile.Pulse = &implanttypes.PulseProfile{} + } + + parsed, err := parsePulseAddress(getString(result, "address")) + if err != nil { + return err + } + profile.Pulse.Protocol = parsed.protocol + profile.Pulse.Target = parsed.target + + if profile.Pulse.Protocol == consts.HTTPPipeline { + if profile.Pulse.Http == nil { + profile.Pulse.Http = &implanttypes.HttpProfile{} + } + profile.Pulse.Http.Method = "POST" + profile.Pulse.Http.Version = "1.1" + profile.Pulse.Http.Host = parsed.target + if profile.Pulse.Http.Headers == nil { + profile.Pulse.Http.Headers = map[string]string{} + } + profile.Pulse.Http.Headers["Host"] = parsed.target + } + + if profile.Pulse.Http != nil { + if v := strings.TrimSpace(getString(result, "path")); v != "" { + profile.Pulse.Http.Path = v + } + if v := strings.TrimSpace(getString(result, "user_agent")); v != "" { + if profile.Pulse.Http.Headers == nil { + profile.Pulse.Http.Headers = map[string]string{} + } + profile.Pulse.Http.Headers["User-Agent"] = v + } + } + + if artifactID := getUint32(result, "beacon_artifact_id"); artifactID != 0 { + if profile.Pulse.Flags == nil { + profile.Pulse.Flags = &implanttypes.PulseFlags{} + } + profile.Pulse.Flags.ArtifactID = artifactID + } + + return nil +} + +// applyGuardrailSettings applies guardrail settings from wizard result +func applyGuardrailSettings(result *wizardfw.WizardResult, profile *implanttypes.ProfileConfig) { + ips := splitCommaSeparated(getString(result, "guardrail_ips")) + users := splitCommaSeparated(getString(result, "guardrail_users")) + servers := splitCommaSeparated(getString(result, "guardrail_servers")) + domains := splitCommaSeparated(getString(result, "guardrail_domains")) + if len(ips) == 0 && len(users) == 0 && len(servers) == 0 && len(domains) == 0 { + return + } + + if profile.Basic.Guardrail == nil { + profile.Basic.Guardrail = &implanttypes.GuardrailProfile{} + } + profile.Basic.Guardrail.Enable = true + profile.Basic.Guardrail.RequireAll = true + + if len(ips) > 0 { + profile.Basic.Guardrail.IPAddresses = ips + } + if len(users) > 0 { + profile.Basic.Guardrail.Usernames = users + } + if len(servers) > 0 { + profile.Basic.Guardrail.ServerNames = servers + } + if len(domains) > 0 { + profile.Basic.Guardrail.Domains = domains + } +} + +// addressScheme defines how to parse a URL scheme into a Target +type addressScheme struct { + prefix string + defaultPort string + configure func(host string, target *implanttypes.Target) +} + +var addressSchemes = []addressScheme{ + {"http://", "80", configureHTTP}, + {"https://", "443", configureHTTPS}, + {"tcp+tls://", "5001", configureTCPTLS}, + {"tcp://", "5001", configureTCP}, +} + +func configureHTTP(host string, target *implanttypes.Target) { + target.Http = defaultHTTPProfile() +} + +func configureHTTPS(host string, target *implanttypes.Target) { + target.Http = defaultHTTPProfile() + target.TLS = &implanttypes.TLSProfile{ + Enable: true, + SNI: host, + SkipVerification: true, + } +} + +func configureTCP(host string, target *implanttypes.Target) { + target.TCP = &implanttypes.TCPProfile{} +} + +func configureTCPTLS(host string, target *implanttypes.Target) { + target.TCP = &implanttypes.TCPProfile{} + target.TLS = &implanttypes.TLSProfile{ + Enable: true, + SNI: host, + SkipVerification: true, + } +} + +func defaultHTTPProfile() *implanttypes.HttpProfile { + return &implanttypes.HttpProfile{ + Method: "POST", + Path: "/", + Version: "1.1", + Headers: map[string]string{ + "User-Agent": uarand.GetRandom(), + "Content-Type": "application/octet-stream", + }, + } +} + +// parseTargets parses comma-separated addresses into Target slice +func parseTargets(addresses string) ([]implanttypes.Target, error) { + var targets []implanttypes.Target + for _, raw := range strings.Split(addresses, ",") { + addr := strings.TrimSpace(raw) + if addr == "" { + continue + } + target, err := parseAddress(addr) + if err != nil { + return nil, err + } + targets = append(targets, *target) + } + if len(targets) == 0 { + return nil, fmt.Errorf("no valid targets found in addresses") + } + return targets, nil +} + +// parseAddress parses a single address into a Target +func parseAddress(address string) (*implanttypes.Target, error) { + address = strings.TrimSpace(address) + if address == "" { + return nil, fmt.Errorf("address is empty") + } + + if strings.Contains(address, "://") { + u, err := url.Parse(address) + if err != nil { + return nil, fmt.Errorf("invalid address %q: %w", address, err) + } + if u.User != nil || u.RawQuery != "" || u.Fragment != "" || (u.Path != "" && u.Path != "/") { + return nil, fmt.Errorf("invalid address %q: only scheme://host[:port] is supported", address) + } + if strings.Count(u.Host, ":") > 1 && !strings.HasPrefix(u.Host, "[") { + return nil, fmt.Errorf("invalid address %q: IPv6 hosts must be in brackets, e.g. tcp://[::1]:5001", address) + } + + scheme := strings.ToLower(strings.TrimSpace(u.Scheme)) + + // Find matching scheme config in a single pass + var matched *addressScheme + for i := range addressSchemes { + if strings.TrimSuffix(addressSchemes[i].prefix, "://") == scheme { + matched = &addressSchemes[i] + break + } + } + if matched == nil { + return nil, fmt.Errorf("unsupported address scheme %q", u.Scheme) + } + + host := strings.TrimSpace(u.Hostname()) + if host == "" { + return nil, fmt.Errorf("invalid address %q: missing host", address) + } + port := strings.TrimSpace(u.Port()) + if port == "" { + port = matched.defaultPort + } + if err := validatePort(port); err != nil { + return nil, err + } + + target := &implanttypes.Target{Address: net.JoinHostPort(host, port)} + matched.configure(host, target) + return target, nil + } + + if strings.ContainsAny(address, "/?#") { + return nil, fmt.Errorf("invalid address %q: expected host[:port] or scheme://host[:port]", address) + } + + addr, err := normalizeHostPort(address, "5001") + if err != nil { + return nil, err + } + target := &implanttypes.Target{ + Address: addr, + TCP: &implanttypes.TCPProfile{}, + } + return target, nil +} + +func validatePort(port string) error { + p, err := strconv.Atoi(port) + if err != nil || p < 1 || p > 65535 { + return fmt.Errorf("invalid port: %q", port) + } + return nil +} + +func normalizeHostPort(addr string, defaultPort string) (string, error) { + addr = strings.TrimSpace(addr) + if addr == "" { + return "", fmt.Errorf("address is empty") + } + if strings.ContainsAny(addr, "/?#") { + return "", fmt.Errorf("invalid address %q: expected host[:port]", addr) + } + + port := defaultPort + + switch { + case strings.HasPrefix(addr, "["): + if strings.Contains(addr, "]") && !strings.Contains(addr, "]:") { + addr = addr + ":" + port + } + h, p, splitErr := net.SplitHostPort(addr) + if splitErr != nil { + return "", fmt.Errorf("invalid address %q: %w", addr, splitErr) + } + if strings.TrimSpace(h) == "" { + return "", fmt.Errorf("invalid address %q: missing host", addr) + } + if err := validatePort(p); err != nil { + return "", err + } + return net.JoinHostPort(h, p), nil + + case strings.Count(addr, ":") == 0: + if err := validatePort(port); err != nil { + return "", err + } + return net.JoinHostPort(addr, port), nil + + case strings.Count(addr, ":") == 1: + h, p, splitErr := net.SplitHostPort(addr) + if splitErr != nil { + return "", fmt.Errorf("invalid address %q: %w", addr, splitErr) + } + if strings.TrimSpace(h) == "" { + return "", fmt.Errorf("invalid address %q: missing host", addr) + } + if err := validatePort(p); err != nil { + return "", err + } + return net.JoinHostPort(h, p), nil + + default: + // Bare IPv6 address without brackets + ipPart := addr + if i := strings.LastIndex(ipPart, "%"); i != -1 { + ipPart = ipPart[:i] + } + if net.ParseIP(ipPart) == nil { + return "", fmt.Errorf("invalid IPv6 address %q (use [ipv6]:port)", addr) + } + if err := validatePort(port); err != nil { + return "", err + } + return net.JoinHostPort(addr, port), nil + } +} + +// applyModuleSettings applies module settings from wizard result +func applyModuleSettings(result *wizardfw.WizardResult, profile *implanttypes.ProfileConfig) { + if modules := getStringSlice(result, "modules"); len(modules) > 0 { + profile.Implant.Modules = modules + } + if thirdModules := getStringSlice(result, "third_modules"); len(thirdModules) > 0 { + profile.Implant.ThirdModules = thirdModules + profile.Implant.Enable3rd = true + } +} + +// applyBuildSettings applies build settings from wizard result +func applyBuildSettings(result *wizardfw.WizardResult, profile *implanttypes.ProfileConfig) { + if getBool(result, "ollvm") { + if profile.Build == nil { + profile.Build = &implanttypes.BuildProfile{} + } + profile.Build.OLLVM = &implanttypes.OLLVMProfile{ + Enable: true, BCFObf: true, SplitObf: true, SubObf: true, FCO: true, ConstEnc: true, + } + } + if getBool(result, "anti_sandbox") { + if profile.Implant == nil { + profile.Implant = &implanttypes.ImplantProfile{} + } + if profile.Implant.Anti == nil { + profile.Implant.Anti = &implanttypes.AntiProfile{} + } + profile.Implant.Anti.Sandbox = true + } +} + +// validateLibFlag validates the lib flag based on build type and target +func validateLibFlag(buildConfig *clientpb.BuildConfig) error { + target, ok := consts.GetBuildTarget(buildConfig.Target) + if !ok { + return fmt.Errorf("invalid target: %s", buildConfig.Target) + } + + switch buildConfig.BuildType { + case consts.CommandBuildModules, consts.CommandBuild3rdModules: + if target.OS != consts.Windows { + return fmt.Errorf("modules build only supports Windows targets") + } + buildConfig.Lib = true + case consts.CommandBuildPrelude: + buildConfig.Lib = false + case consts.CommandBuildPulse: + if target.OS != consts.Windows { + return fmt.Errorf("pulse build only supports Windows targets") + } + buildConfig.Lib = false + } + return nil +} + +// executeProfileCreate creates a new profile from wizard results +func executeProfileCreate(con *core.Console, result *wizardfw.WizardResult) error { + name := getString(result, "name") + if name == "" { + return fmt.Errorf("profile name is required") + } + + pipelineID := getString(result, "pipeline") + if pipelineID == "" { + return fmt.Errorf("pipeline is required") + } + + implantType := getString(result, "type") + modules := getStringSlice(result, "modules") + + // Build profile params + var params implanttypes.ProfileParams + if len(modules) > 0 { + params.Modules = strings.Join(modules, ",") + } + + profile := &clientpb.Profile{ + Name: name, + PipelineId: pipelineID, + Params: params.String(), + } + + // Note: implantType is stored in profile content, not params + _ = implantType + + _, err := con.Rpc.NewProfile(con.Context(), profile) + if err != nil { + return fmt.Errorf("failed to create profile: %w", err) + } + + con.Log.Importantf("Profile '%s' created successfully for pipeline '%s'\n", name, pipelineID) + return nil +} + +// executeListenerSetup displays listener setup instructions +// Note: Listeners are typically configured via server config file or listener binary +func executeListenerSetup(con *core.Console, result *wizardfw.WizardResult) error { + name := getString(result, "name") + host := getString(result, "host") + protocol := getString(result, "protocol") + port := getInt(result, "port") + tls := getBool(result, "tls") + + con.Log.Importantf("Listener Configuration:\n") + con.Log.Infof(" Name: %s\n", name) + con.Log.Infof(" Host: %s\n", host) + con.Log.Infof(" Protocol: %s\n", protocol) + con.Log.Infof(" Port: %d\n", port) + con.Log.Infof(" TLS: %v\n", tls) + con.Log.Warnf("\nNote: Listeners must be started separately using the listener binary.\n") + con.Log.Infof("Example: ./listener --config listener.yaml\n") + + return nil +} + +// executeInfrastructureSetup creates listener config, pipeline, and profile +func executeInfrastructureSetup(con *core.Console, result *wizardfw.WizardResult) error { + // Step 1: Display listener configuration (listeners need to be started separately) + listenerName := getString(result, "listener_name") + listenerHost := getString(result, "listener_host") + listenerProtocol := getString(result, "listener_protocol") + listenerPort := getInt(result, "listener_port") + listenerTLS := getBool(result, "listener_tls") + + con.Log.Importantf("=== Infrastructure Setup ===\n\n") + con.Log.Infof("[1/3] Listener Configuration:\n") + con.Log.Infof(" Name: %s\n", listenerName) + con.Log.Infof(" Host: %s\n", listenerHost) + con.Log.Infof(" Protocol: %s\n", listenerProtocol) + con.Log.Infof(" Port: %d\n", listenerPort) + con.Log.Infof(" TLS: %v\n", listenerTLS) + con.Log.Warnf(" Note: Start listener separately with: ./listener --config listener.yaml\n\n") + + // Step 2: Create pipeline + pipelineType := getString(result, "pipeline_type") + pipelineName := getString(result, "pipeline_name") + pipelineHost := getString(result, "pipeline_host") + pipelinePort := getUint32(result, "pipeline_port") + pipelineTLS := getBool(result, "pipeline_tls") + + if pipelineName == "" { + pipelineName = fmt.Sprintf("%s_%s_%d", pipelineType, listenerName, pipelinePort) + } + + con.Log.Infof("[2/3] Creating Pipeline '%s'...\n", pipelineName) + + // Check port availability + if err := checkPortAvailable(pipelineHost, pipelinePort); err != nil { + return fmt.Errorf("cannot create pipeline: %w", err) + } + + var tls *clientpb.TLS + if pipelineTLS { + tls = &clientpb.TLS{Enable: true} + } + + var pipeline *clientpb.Pipeline + if pipelineType == "http" { + pipeline = &clientpb.Pipeline{ + Tls: tls, + Name: pipelineName, + ListenerId: listenerName, + Parser: consts.ImplantMalefic, + Enable: false, + Body: &clientpb.Pipeline_Http{Http: &clientpb.HTTPPipeline{Name: pipelineName, Host: pipelineHost, Port: pipelinePort}}, + } + } else { + pipeline = &clientpb.Pipeline{ + Tls: tls, + Name: pipelineName, + ListenerId: listenerName, + Parser: consts.ImplantMalefic, + Enable: false, + Body: &clientpb.Pipeline_Tcp{Tcp: &clientpb.TCPPipeline{Name: pipelineName, Host: pipelineHost, Port: pipelinePort}}, + } + } + + if _, err := con.Rpc.RegisterPipeline(con.Context(), pipeline); err != nil { + return fmt.Errorf("failed to register pipeline: %w", err) + } + con.Log.Importantf(" Pipeline '%s' registered\n", pipelineName) + + if _, err := con.Rpc.StartPipeline(con.Context(), &clientpb.CtrlPipeline{ + Name: pipeline.Name, + ListenerId: pipeline.ListenerId, + Pipeline: pipeline, + }); err != nil { + return fmt.Errorf("failed to start pipeline: %w", err) + } + con.Log.Importantf(" Pipeline '%s' started\n\n", pipelineName) + + // Step 3: Create profile + profileName := getString(result, "profile_name") + implantType := getString(result, "implant_type") + modules := getStringSlice(result, "modules") + + con.Log.Infof("[3/3] Creating Profile '%s'...\n", profileName) + + var params implanttypes.ProfileParams + if len(modules) > 0 { + params.Modules = strings.Join(modules, ",") + } + + // Note: implantType is stored in profile content, not params + _ = implantType + + profile := &clientpb.Profile{ + Name: profileName, + PipelineId: pipelineName, + Params: params.String(), + } + + if _, err := con.Rpc.NewProfile(con.Context(), profile); err != nil { + return fmt.Errorf("failed to create profile: %w", err) + } + con.Log.Importantf(" Profile '%s' created for pipeline '%s'\n\n", profileName, pipelineName) + + con.Log.Importantf("=== Infrastructure Setup Complete ===\n") + con.Log.Infof("Next steps:\n") + con.Log.Infof(" 1. Start listener: ./listener --config listener.yaml\n") + con.Log.Infof(" 2. Build implant: wizard build beacon\n") + + return nil +} + +// executeCertGenerate generates a self-signed certificate +func executeCertGenerate(con *core.Console, result *wizardfw.WizardResult) error { + cn := getString(result, "cn") + if cn == "" { + return fmt.Errorf("Common Name (CN) is required") + } + + certSubject := &clientpb.CertificateSubject{ + Cn: cn, + O: getString(result, "o"), + C: getString(result, "c"), + L: getString(result, "l"), + Ou: getString(result, "ou"), + St: getString(result, "st"), + Validity: fmt.Sprintf("%d", getInt(result, "validity")), + } + + _, err := con.Rpc.GenerateSelfCert(con.Context(), &clientpb.Pipeline{ + Tls: &clientpb.TLS{ + CertSubject: certSubject, + Acme: false, + }, + }) + if err != nil { + return fmt.Errorf("failed to generate certificate: %w", err) + } + + con.Log.Importantf("Self-signed certificate generated successfully\n") + con.Log.Infof(" CN: %s\n", cn) + if certSubject.O != "" { + con.Log.Infof(" O: %s\n", certSubject.O) + } + if certSubject.Validity != "" && certSubject.Validity != "0" { + con.Log.Infof(" Validity: %s days\n", certSubject.Validity) + } + + return nil +} + +// executeCertImport imports an existing certificate +func executeCertImport(con *core.Console, result *wizardfw.WizardResult) error { + certPath := getString(result, "cert") + keyPath := getString(result, "key") + + if certPath == "" || keyPath == "" { + return fmt.Errorf("certificate and key files are required") + } + + // Read certificate file + certData, err := cryptography.ProcessPEM(certPath) + if err != nil { + return fmt.Errorf("failed to read certificate file: %w", err) + } + + // Read key file + keyData, err := cryptography.ProcessPEM(keyPath) + if err != nil { + return fmt.Errorf("failed to read key file: %w", err) + } + + // Read CA certificate if provided + var caCert *clientpb.Cert + caPath := getString(result, "ca_cert") + if caPath != "" { + caData, err := cryptography.ProcessPEM(caPath) + if err != nil { + return fmt.Errorf("failed to read CA certificate file: %w", err) + } + caCert = &clientpb.Cert{ + Cert: caData, + } + } + + tls := &clientpb.TLS{ + Cert: &clientpb.Cert{ + Cert: certData, + Key: keyData, + }, + Ca: caCert, + } + + _, err = con.Rpc.GenerateSelfCert(con.Context(), &clientpb.Pipeline{ + Tls: tls, + }) + if err != nil { + return fmt.Errorf("failed to import certificate: %w", err) + } + + con.Log.Importantf("Certificate imported successfully\n") + con.Log.Infof(" Certificate: %s\n", certPath) + con.Log.Infof(" Key: %s\n", keyPath) + if caPath != "" { + con.Log.Infof(" CA Cert: %s\n", caPath) + } + + return nil +} + +// executeGithubConfig configures GitHub Actions build +func executeGithubConfig(con *core.Console, result *wizardfw.WizardResult) error { + owner := getString(result, "owner") + repo := getString(result, "repo") + token := getString(result, "token") + + if owner == "" || repo == "" || token == "" { + return fmt.Errorf("owner, repo, and token are required") + } + + workflowFile := getString(result, "workflow_file") + + githubConfig := &clientpb.GithubActionBuildConfig{ + Owner: owner, + Repo: repo, + Token: token, + WorkflowId: workflowFile, + } + + _, err := con.Rpc.UpdateGithubConfig(con.Context(), githubConfig) + if err != nil { + return fmt.Errorf("failed to update GitHub config: %w", err) + } + + con.Log.Importantf("GitHub Actions configuration updated successfully\n") + con.Log.Infof(" Owner: %s\n", owner) + con.Log.Infof(" Repo: %s\n", repo) + con.Log.Infof(" Token: %s***\n", token[:minInt(4, len(token))]) + if workflowFile != "" { + con.Log.Infof(" Workflow: %s\n", workflowFile) + } + + return nil +} + +// executeNotifyConfig configures notification channels +func executeNotifyConfig(con *core.Console, result *wizardfw.WizardResult) error { + notify := &clientpb.Notify{} + hasConfig := false + + // Telegram + if getBool(result, "telegram_enable") { + notify.TelegramEnable = true + notify.TelegramApiKey = getString(result, "telegram_token") + if chatID := getString(result, "telegram_chat_id"); chatID != "" { + // Parse chat ID as int64 + var id int64 + fmt.Sscanf(chatID, "%d", &id) + notify.TelegramChatId = id + } + hasConfig = true + } + + // DingTalk + if getBool(result, "dingtalk_enable") { + notify.DingtalkEnable = true + notify.DingtalkToken = getString(result, "dingtalk_token") + notify.DingtalkSecret = getString(result, "dingtalk_secret") + hasConfig = true + } + + // Lark + if getBool(result, "lark_enable") { + notify.LarkEnable = true + notify.LarkWebhookUrl = getString(result, "lark_webhook") + hasConfig = true + } + + // ServerChan + if getBool(result, "serverchan_enable") { + notify.ServerchanEnable = true + notify.ServerchanUrl = getString(result, "serverchan_url") + hasConfig = true + } + + // PushPlus + if getBool(result, "pushplus_enable") { + notify.PushplusEnable = true + notify.PushplusToken = getString(result, "pushplus_token") + notify.PushplusTopic = getString(result, "pushplus_topic") + hasConfig = true + } + + if !hasConfig { + con.Log.Warnf("No notification channels enabled\n") + return nil + } + + _, err := con.Rpc.UpdateNotifyConfig(con.Context(), notify) + if err != nil { + return fmt.Errorf("failed to update notification config: %w", err) + } + + con.Log.Importantf("Notification configuration updated successfully\n") + if notify.TelegramEnable { + con.Log.Infof(" Telegram: enabled\n") + } + if notify.DingtalkEnable { + con.Log.Infof(" DingTalk: enabled\n") + } + if notify.LarkEnable { + con.Log.Infof(" Lark: enabled\n") + } + if notify.ServerchanEnable { + con.Log.Infof(" ServerChan: enabled\n") + } + if notify.PushplusEnable { + con.Log.Infof(" PushPlus: enabled\n") + } + + return nil +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/client/command/wizard/executors_test.go b/client/command/wizard/executors_test.go new file mode 100644 index 00000000..8f58f729 --- /dev/null +++ b/client/command/wizard/executors_test.go @@ -0,0 +1,168 @@ +package wizard + +import ( + "testing" + + "github.com/chainreactors/IoM-go/consts" +) + +func TestParseAddress(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + wantAddr string + wantTCP bool + wantHTTP bool + wantTLS bool + wantSNI string + wantErr bool + }{ + { + name: "http default port", + input: "http://example.com", + wantAddr: "example.com:80", + wantHTTP: true, + }, + { + name: "https default port and sni", + input: "https://example.com", + wantAddr: "example.com:443", + wantHTTP: true, + wantTLS: true, + wantSNI: "example.com", + }, + { + name: "tcp default port", + input: "tcp://example.com", + wantAddr: "example.com:5001", + wantTCP: true, + }, + { + name: "tcp+tls default port and sni", + input: "tcp+tls://example.com", + wantAddr: "example.com:5001", + wantTCP: true, + wantTLS: true, + wantSNI: "example.com", + }, + { + name: "raw host defaults to tcp", + input: "example.com", + wantAddr: "example.com:5001", + wantTCP: true, + }, + { + name: "raw ipv6 defaults to tcp", + input: "::1", + wantAddr: "[::1]:5001", + wantTCP: true, + }, + { + name: "raw ipv6 with port", + input: "[::1]:6000", + wantAddr: "[::1]:6000", + wantTCP: true, + }, + { + name: "unsupported scheme", + input: "ftp://example.com", + wantErr: true, + }, + { + name: "http path not allowed", + input: "http://example.com/foo", + wantErr: true, + }, + { + name: "raw address with path not allowed", + input: "example.com/foo", + wantErr: true, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + target, err := parseAddress(tt.input) + if tt.wantErr { + if err == nil { + t.Fatalf("expected error, got nil (target=%+v)", target) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if target == nil { + t.Fatalf("expected target, got nil") + } + if target.Address != tt.wantAddr { + t.Fatalf("address mismatch: got %q want %q", target.Address, tt.wantAddr) + } + if (target.TCP != nil) != tt.wantTCP { + t.Fatalf("tcp mismatch: got %v want %v", target.TCP != nil, tt.wantTCP) + } + if (target.Http != nil) != tt.wantHTTP { + t.Fatalf("http mismatch: got %v want %v", target.Http != nil, tt.wantHTTP) + } + if (target.TLS != nil) != tt.wantTLS { + t.Fatalf("tls mismatch: got %v want %v", target.TLS != nil, tt.wantTLS) + } + if tt.wantTLS && target.TLS != nil && target.TLS.SNI != tt.wantSNI { + t.Fatalf("sni mismatch: got %q want %q", target.TLS.SNI, tt.wantSNI) + } + }) + } +} + +func TestParseTargets_TrimsAndValidates(t *testing.T) { + t.Parallel() + + targets, err := parseTargets(" http://example.com , tcp://127.0.0.1:5001 , ") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(targets) != 2 { + t.Fatalf("unexpected target count: got %d want %d", len(targets), 2) + } + if targets[0].Address != "example.com:80" { + t.Fatalf("unexpected first address: %q", targets[0].Address) + } + if targets[1].Address != "127.0.0.1:5001" { + t.Fatalf("unexpected second address: %q", targets[1].Address) + } +} + +func TestParsePulseAddress(t *testing.T) { + t.Parallel() + + parsed, err := parsePulseAddress("tcp://example.com") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if parsed.protocol != consts.TCPPipeline { + t.Fatalf("unexpected protocol: got %q want %q", parsed.protocol, consts.TCPPipeline) + } + if parsed.target != "example.com:5001" { + t.Fatalf("unexpected target: got %q", parsed.target) + } + + parsed, err = parsePulseAddress("http://[::1]/") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if parsed.protocol != consts.HTTPPipeline { + t.Fatalf("unexpected protocol: got %q want %q", parsed.protocol, consts.HTTPPipeline) + } + if parsed.target != "[::1]:80" { + t.Fatalf("unexpected target: got %q", parsed.target) + } + + if _, err := parsePulseAddress("https://example.com"); err == nil { + t.Fatalf("expected error for https pulse address, got nil") + } +} diff --git a/client/command/wizard/providers.go b/client/command/wizard/providers.go new file mode 100644 index 00000000..f49030c3 --- /dev/null +++ b/client/command/wizard/providers.go @@ -0,0 +1,186 @@ +package wizard + +import ( + "fmt" + + "github.com/chainreactors/IoM-go/proto/client/clientpb" + "github.com/chainreactors/malice-network/client/core" +) + +// ProfileOptionsProvider returns a function that fetches profile names from the server +func ProfileOptionsProvider() func(ctx interface{}) []string { + return func(ctx interface{}) []string { + con, ok := ctx.(*core.Console) + if !ok || con == nil { + return nil + } + + profiles, err := con.Rpc.GetProfiles(con.Context(), &clientpb.Empty{}) + if err != nil { + return nil + } + + opts := make([]string, 0, len(profiles.Profiles)+1) + opts = append(opts, "") // Allow empty (use default profile) + for _, p := range profiles.Profiles { + opts = append(opts, p.Name) + } + return opts + } +} + +// ListenerOptionsProvider returns a function that fetches listener IDs from the server +func ListenerOptionsProvider() func(ctx interface{}) []string { + return func(ctx interface{}) []string { + con, ok := ctx.(*core.Console) + if !ok || con == nil { + return nil + } + + listeners, err := con.Rpc.GetListeners(con.Context(), &clientpb.Empty{}) + if err != nil { + return nil + } + + opts := make([]string, 0, len(listeners.Listeners)) + for _, l := range listeners.Listeners { + opts = append(opts, l.Id) + } + return opts + } +} + +// PipelineOptionsProvider returns a function that fetches pipeline names from the server +func PipelineOptionsProvider() func(ctx interface{}) []string { + return func(ctx interface{}) []string { + con, ok := ctx.(*core.Console) + if !ok || con == nil { + return nil + } + + pipelines, err := con.Rpc.ListJobs(con.Context(), &clientpb.Empty{}) + if err != nil { + return nil + } + + opts := make([]string, 0, len(pipelines.GetPipelines())+1) + opts = append(opts, "") // Allow empty + for _, p := range pipelines.GetPipelines() { + opts = append(opts, p.Name) + } + return opts + } +} + +// ArtifactOptionsProvider returns a function that fetches artifact names from the server +// filterType can be used to filter by artifact type (e.g., "beacon", "pulse", "module") +func ArtifactOptionsProvider(filterType string) func(ctx interface{}) []string { + return func(ctx interface{}) []string { + con, ok := ctx.(*core.Console) + if !ok || con == nil { + return nil + } + + artifacts, err := con.Rpc.ListArtifact(con.Context(), &clientpb.Empty{}) + if err != nil { + return nil + } + + opts := make([]string, 0, len(artifacts.Artifacts)+1) + opts = append(opts, "") // Allow empty (no artifact) + for _, a := range artifacts.Artifacts { + if filterType == "" || a.Type == filterType { + opts = append(opts, fmt.Sprintf("%d", a.Id)) + } + } + return opts + } +} + +// AddressOptionsProvider returns a function that fetches C2 addresses from pipelines +func AddressOptionsProvider() func(ctx interface{}) []string { + return func(ctx interface{}) []string { + con, ok := ctx.(*core.Console) + if !ok || con == nil { + return nil + } + + pipelines, err := con.Rpc.ListJobs(con.Context(), &clientpb.Empty{}) + if err != nil { + return nil + } + + opts := make([]string, 0) + opts = append(opts, "") // Allow empty (manual input) + seen := make(map[string]bool) + + for _, p := range pipelines.GetPipelines() { + var addr string + switch body := p.Body.(type) { + case *clientpb.Pipeline_Tcp: + tcp := body.Tcp + if tcp.Host != "" && tcp.Port != 0 { + addr = fmt.Sprintf("%s:%d", tcp.Host, tcp.Port) + } + case *clientpb.Pipeline_Http: + http := body.Http + if http.Host != "" && http.Port != 0 { + schema := "http" + if p.Tls != nil && p.Tls.Enable { + schema = "https" + } + addr = fmt.Sprintf("%s://%s:%d", schema, http.Host, http.Port) + } + } + if addr != "" && !seen[addr] { + seen[addr] = true + opts = append(opts, addr) + } + } + return opts + } +} + +// PulseAddressOptionsProvider returns a function that fetches stage-0 compatible addresses. +// Pulse currently supports only `http://` and `tcp://` targets; HTTPS is intentionally excluded. +func PulseAddressOptionsProvider() func(ctx interface{}) []string { + return func(ctx interface{}) []string { + con, ok := ctx.(*core.Console) + if !ok || con == nil { + return nil + } + + pipelines, err := con.Rpc.ListJobs(con.Context(), &clientpb.Empty{}) + if err != nil { + return nil + } + + opts := make([]string, 0) + opts = append(opts, "") // Allow empty (manual input) + seen := make(map[string]bool) + + for _, p := range pipelines.GetPipelines() { + var addr string + switch body := p.Body.(type) { + case *clientpb.Pipeline_Tcp: + tcp := body.Tcp + if tcp.Host != "" && tcp.Port != 0 { + addr = fmt.Sprintf("tcp://%s:%d", tcp.Host, tcp.Port) + } + case *clientpb.Pipeline_Http: + http := body.Http + if http.Host != "" && http.Port != 0 { + if p.Tls != nil && p.Tls.Enable { + continue + } + addr = fmt.Sprintf("http://%s:%d", http.Host, http.Port) + } + } + if addr != "" && !seen[addr] { + seen[addr] = true + opts = append(opts, addr) + } + } + return opts + } +} diff --git a/client/command/wizard_flag.go b/client/command/wizard_flag.go new file mode 100644 index 00000000..b537ed1a --- /dev/null +++ b/client/command/wizard_flag.go @@ -0,0 +1,151 @@ +package command + +import ( + "fmt" + + "github.com/chainreactors/malice-network/client/core" + "github.com/chainreactors/malice-network/client/wizard" + "github.com/spf13/cobra" +) + +// WizardFlagName is the name of the global wizard flag +const WizardFlagName = "wizard" + +// RegisterWizardFlag registers the global --wizard flag on the root command +func RegisterWizardFlag(rootCmd *cobra.Command) { + rootCmd.PersistentFlags().Bool(WizardFlagName, false, "Start interactive wizard mode") +} + +// WrapWithWizardSupport wraps the PersistentPreRunE to add wizard support +// Returns new pre and post runner functions +func WrapWithWizardSupport( + con *core.Console, + originalPre, originalPost func(cmd *cobra.Command, args []string) error, +) (pre, post func(cmd *cobra.Command, args []string) error) { + + pre = func(cmd *cobra.Command, args []string) error { + // Check if wizard mode is enabled + wizardMode, _ := cmd.Flags().GetBool(WizardFlagName) + if !wizardMode { + // Not wizard mode, execute original logic + if originalPre != nil { + return originalPre(cmd, args) + } + return nil + } + + // Wizard mode: convert command flags to wizard + wiz := wizard.CobraToWizard(cmd) + if wiz == nil { + return fmt.Errorf("cannot create wizard for command %s", cmd.Name()) + } + + // Check if there are any fields to display + if len(wiz.Fields) == 0 { + cmd.Printf("Command %s has no configurable parameters\n", cmd.Name()) + // Continue with original PreRunE + if originalPre != nil { + return originalPre(cmd, args) + } + return nil + } + + // Prepare dynamic options if console is available + if con != nil { + wiz.PrepareOptions(con) + } + + // Run wizard + runner := wizard.NewRunner(wiz) + result, err := runner.Run() + if err != nil { + return fmt.Errorf("wizard cancelled or failed: %w", err) + } + + // Apply wizard results to flags + if err := wizard.ApplyWizardResultToFlags(cmd, result); err != nil { + return fmt.Errorf("failed to apply wizard result: %w", err) + } + + // Execute original PreRunE (if any) + if originalPre != nil { + return originalPre(cmd, args) + } + return nil + } + + post = originalPost + return pre, post +} + +// ShouldRunWizard checks if the command should run in wizard mode +func ShouldRunWizard(cmd *cobra.Command) bool { + wizardMode, _ := cmd.Flags().GetBool(WizardFlagName) + return wizardMode +} + +// HandleWizardFlag handles the --wizard flag for console mode +// This is called in PersistentPreRunE for interactive console commands +func HandleWizardFlag(cmd *cobra.Command, con *core.Console) error { + // Check if wizard mode is enabled + wizardMode, _ := cmd.Flags().GetBool(WizardFlagName) + if !wizardMode { + return nil + } + + // Wizard mode: convert command flags to wizard + wiz := wizard.CobraToWizard(cmd) + if wiz == nil { + return fmt.Errorf("cannot create wizard for command %s", cmd.Name()) + } + + // Check if there are any fields to display + if len(wiz.Fields) == 0 { + cmd.Printf("Command %s has no configurable parameters\n", cmd.Name()) + return nil + } + + // Prepare dynamic options if console is available + if con != nil { + wiz.PrepareOptions(con) + } + + // Run wizard + runner := wizard.NewRunner(wiz) + result, err := runner.Run() + if err != nil { + return fmt.Errorf("wizard cancelled or failed: %w", err) + } + + // Apply wizard results to flags + if err := wizard.ApplyWizardResultToFlags(cmd, result); err != nil { + return fmt.Errorf("failed to apply wizard result: %w", err) + } + + return nil +} + +// RunWizardForCommand runs wizard for a specific command and applies results +// This can be used by subcommands that want to handle wizard mode themselves +func RunWizardForCommand(cmd *cobra.Command, con *core.Console) error { + wiz := wizard.CobraToWizard(cmd) + if wiz == nil { + return fmt.Errorf("cannot create wizard for command %s", cmd.Name()) + } + + if len(wiz.Fields) == 0 { + return nil // No fields to configure + } + + if con != nil { + wiz.PrepareOptions(con) + } + + runner := wizard.NewRunner(wiz) + result, err := runner.Run() + if err != nil { + return err + } + + return wizard.ApplyWizardResultToFlags(cmd, result) +} diff --git a/client/core/ai.go b/client/core/ai.go new file mode 100644 index 00000000..cf2fb542 --- /dev/null +++ b/client/core/ai.go @@ -0,0 +1,911 @@ +package core + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "regexp" + "strings" + "time" + + "github.com/chainreactors/malice-network/client/assets" +) + +// AIClient handles communication with AI APIs (OpenAI and Claude) +type AIClient struct { + settings *assets.AISettings + client *http.Client +} + +// NewAIClient creates a new AI client +func NewAIClient(settings *assets.AISettings) *AIClient { + timeout := 30 + if settings != nil && settings.Timeout > 0 { + timeout = settings.Timeout + } + return &AIClient{ + settings: settings, + client: &http.Client{ + Timeout: time.Duration(timeout) * time.Second, + }, + } +} + +// Message represents a chat message +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// OpenAI API structures +type OpenAIChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float64 `json:"temperature,omitempty"` + Stream bool `json:"stream,omitempty"` +} + +type OpenAIChatResponse struct { + ID string `json:"id"` + Choices []struct { + Message Message `json:"message"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` + Error *struct { + Message string `json:"message"` + Type string `json:"type"` + } `json:"error,omitempty"` +} + +// Claude API structures +type ClaudeChatRequest struct { + Model string `json:"model"` + MaxTokens int `json:"max_tokens"` + System string `json:"system,omitempty"` + Messages []ClaudeMessage `json:"messages"` +} + +type ClaudeMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +type ClaudeChatResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"content"` + StopReason string `json:"stop_reason"` + Error *struct { + Type string `json:"type"` + Message string `json:"message"` + } `json:"error,omitempty"` +} + +// CommandSuggestion represents a command extracted from AI response +type CommandSuggestion struct { + Command string + Description string +} + +// Ask sends a question to the AI with context +func (c *AIClient) Ask(ctx context.Context, question string, history []string) (string, error) { + if c.settings == nil || !c.settings.Enable { + return "", fmt.Errorf("AI is not enabled. Use 'ai-config --enable' to enable it") + } + + if c.settings.APIKey == "" { + return "", fmt.Errorf("AI API key is not configured. Use 'ai-config --api-key ' to set it") + } + + systemPrompt := c.buildSystemPrompt(history) + + switch strings.ToLower(c.settings.Provider) { + case "claude", "anthropic": + return c.askClaude(ctx, systemPrompt, question) + default: // openai and compatible + return c.askOpenAI(ctx, systemPrompt, question) + } +} + +// AskPrediction sends a request optimized for low-latency inline predictions. +func (c *AIClient) AskPrediction(ctx context.Context, question string, history []string) (string, error) { + if c.settings == nil || !c.settings.Enable { + return "", fmt.Errorf("AI is not enabled. Use 'ai-config --enable' to enable it") + } + + if c.settings.APIKey == "" { + return "", fmt.Errorf("AI API key is not configured. Use 'ai-config --api-key ' to set it") + } + + // Keep the context small to minimize latency/cost. + const maxHistory = 5 + if len(history) > maxHistory { + history = history[len(history)-maxHistory:] + } + + systemPrompt := c.buildPredictionSystemPrompt(history) + + maxTokens := 64 + if c.settings.MaxTokens > 0 && c.settings.MaxTokens < maxTokens { + maxTokens = c.settings.MaxTokens + } + + switch strings.ToLower(c.settings.Provider) { + case "claude", "anthropic": + return c.askClaudeWith(ctx, systemPrompt, question, maxTokens) + default: // openai and compatible + return c.askOpenAIWith(ctx, systemPrompt, question, maxTokens, 0.2) + } +} + +func (c *AIClient) buildSystemPrompt(history []string) string { + var sb strings.Builder + sb.WriteString("You are an AI assistant for IoM (Malice Network), a C2 framework. ") + sb.WriteString("Help users with commands, security operations, and answer questions. ") + sb.WriteString("Be concise and provide actionable suggestions when possible.\n\n") + + sb.WriteString("When suggesting commands, wrap them in backticks like `command`. ") + sb.WriteString("This helps users identify executable commands.\n\n") + + sb.WriteString("IMPORTANT: Use EXACT command names as listed below. Do NOT use plural forms or variations. ") + sb.WriteString("For example, use `session` NOT `sessions`, use `listener` NOT `listeners`.\n\n") + + if len(history) > 0 { + sb.WriteString("Recent command history:\n") + for _, cmd := range history { + sb.WriteString(fmt.Sprintf("- %s\n", cmd)) + } + sb.WriteString("\n") + } + + sb.WriteString("Available commands (use these EXACT names):\n") + sb.WriteString("- session: List and manage sessions (NOT 'sessions')\n") + sb.WriteString("- listener: List listeners in server (NOT 'listeners')\n") + sb.WriteString("- use : Switch to a session\n") + sb.WriteString("- ps: List processes\n") + sb.WriteString("- ls, cd, pwd: File system navigation\n") + sb.WriteString("- download, upload: File transfer\n") + sb.WriteString("- execute, shell, run: Run commands on target\n") + sb.WriteString("- job: List jobs\n") + sb.WriteString("- pipeline: Manage pipelines\n") + sb.WriteString("- build: Build implants\n") + + return sb.String() +} + +func (c *AIClient) buildPredictionSystemPrompt(history []string) string { + var sb strings.Builder + sb.WriteString("You are an autocomplete engine for IoM (Malice Network).\n") + sb.WriteString("Return ONLY the next argument/value wrapped in backticks (example: `--help`).\n") + sb.WriteString("If unsure, return an empty response.\n\n") + + if len(history) > 0 { + sb.WriteString("Recent command history:\n") + for _, cmd := range history { + sb.WriteString(fmt.Sprintf("- %s\n", cmd)) + } + sb.WriteString("\n") + } + + return sb.String() +} + +// doRequest sends an HTTP POST request and returns the response body. +func (c *AIClient) doRequest(ctx context.Context, endpoint string, headers map[string]string, body []byte) ([]byte, int, error) { + httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body)) + if err != nil { + return nil, 0, fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "application/json") + for k, v := range headers { + httpReq.Header.Set(k, v) + } + + resp, err := c.client.Do(httpReq) + if err != nil { + return nil, 0, fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, resp.StatusCode, fmt.Errorf("failed to read response: %w", err) + } + + return respBody, resp.StatusCode, nil +} + +// buildEndpoint constructs the API endpoint URL with the given suffix. +func (c *AIClient) buildEndpoint(suffix string) (string, error) { + base := strings.TrimSuffix(strings.TrimSpace(c.settings.Endpoint), "/") + if base == "" { + return "", fmt.Errorf("AI endpoint is not configured. Use 'ai-config --endpoint ' to set it") + } + if !strings.HasSuffix(base, suffix) { + return base + suffix, nil + } + return base, nil +} + +func (c *AIClient) askOpenAI(ctx context.Context, systemPrompt, question string) (string, error) { + return c.askOpenAIWith(ctx, systemPrompt, question, c.settings.MaxTokens, 0.7) +} + +func (c *AIClient) askOpenAIWith(ctx context.Context, systemPrompt, question string, maxTokens int, temperature float64) (string, error) { + if maxTokens <= 0 { + maxTokens = c.settings.MaxTokens + } + if temperature < 0 { + temperature = 0.7 + } + + req := OpenAIChatRequest{ + Model: c.settings.Model, + Messages: []Message{{Role: "system", Content: systemPrompt}, {Role: "user", Content: question}}, + MaxTokens: maxTokens, + Temperature: temperature, + } + + body, err := json.Marshal(req) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint, err := c.buildEndpoint("/chat/completions") + if err != nil { + return "", err + } + + respBody, statusCode, err := c.doRequest(ctx, endpoint, map[string]string{ + "Authorization": "Bearer " + c.settings.APIKey, + }, body) + if err != nil { + return "", err + } + + var chatResp OpenAIChatResponse + if err := json.Unmarshal(respBody, &chatResp); err != nil { + if statusCode < 200 || statusCode >= 300 { + return "", fmt.Errorf("API error (%d): %s", statusCode, strings.TrimSpace(string(respBody))) + } + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if statusCode < 200 || statusCode >= 300 { + if chatResp.Error != nil { + return "", fmt.Errorf("API error (%d): %s", statusCode, chatResp.Error.Message) + } + return "", fmt.Errorf("API error (%d): %s", statusCode, strings.TrimSpace(string(respBody))) + } + + if chatResp.Error != nil { + return "", fmt.Errorf("API error: %s", chatResp.Error.Message) + } + + if len(chatResp.Choices) == 0 { + return "", fmt.Errorf("no response from AI") + } + + return chatResp.Choices[0].Message.Content, nil +} + +func (c *AIClient) askClaude(ctx context.Context, systemPrompt, question string) (string, error) { + return c.askClaudeWith(ctx, systemPrompt, question, c.settings.MaxTokens) +} + +func (c *AIClient) askClaudeWith(ctx context.Context, systemPrompt, question string, maxTokens int) (string, error) { + if maxTokens <= 0 { + maxTokens = c.settings.MaxTokens + } + if maxTokens <= 0 { + maxTokens = 256 + } + + req := ClaudeChatRequest{ + Model: c.settings.Model, + MaxTokens: maxTokens, + System: systemPrompt, + Messages: []ClaudeMessage{{Role: "user", Content: question}}, + } + + body, err := json.Marshal(req) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint, err := c.buildEndpoint("/messages") + if err != nil { + return "", err + } + + respBody, statusCode, err := c.doRequest(ctx, endpoint, map[string]string{ + "x-api-key": c.settings.APIKey, + "anthropic-version": "2023-06-01", + }, body) + if err != nil { + return "", err + } + + var chatResp ClaudeChatResponse + if err := json.Unmarshal(respBody, &chatResp); err != nil { + if statusCode < 200 || statusCode >= 300 { + return "", fmt.Errorf("API error (%d): %s", statusCode, strings.TrimSpace(string(respBody))) + } + return "", fmt.Errorf("failed to parse response: %w", err) + } + + if statusCode < 200 || statusCode >= 300 { + if chatResp.Error != nil { + return "", fmt.Errorf("API error (%d): %s", statusCode, chatResp.Error.Message) + } + return "", fmt.Errorf("API error (%d): %s", statusCode, strings.TrimSpace(string(respBody))) + } + + if chatResp.Error != nil { + return "", fmt.Errorf("API error: %s", chatResp.Error.Message) + } + + if len(chatResp.Content) == 0 { + return "", fmt.Errorf("no response from AI") + } + + var result strings.Builder + for _, content := range chatResp.Content { + if content.Type == "text" { + result.WriteString(content.Text) + } + } + + return result.String(), nil +} + +// ParseCommandSuggestions extracts command suggestions from AI response +// Commands are expected to be wrapped in backticks like `command` +func ParseCommandSuggestions(response string) []CommandSuggestion { + var suggestions []CommandSuggestion + + // Match single backtick commands: `command` + singlePattern := regexp.MustCompile("`([^`\n]+)`") + matches := singlePattern.FindAllStringSubmatch(response, -1) + + seen := make(map[string]bool) + for _, match := range matches { + if len(match) > 1 { + cmd := strings.TrimSpace(match[1]) + // Skip if it looks like code/variable rather than command + if strings.Contains(cmd, "=") || strings.HasPrefix(cmd, "$") { + continue + } + // Skip shell escape syntax (! prefix) + if strings.HasPrefix(cmd, "!") { + continue + } + if !seen[cmd] { + seen[cmd] = true + suggestions = append(suggestions, CommandSuggestion{ + Command: cmd, + Description: "", + }) + } + } + } + + return suggestions +} + +// FormatResponseWithCommands formats the AI response with numbered command suggestions +func FormatResponseWithCommands(response string, commands []CommandSuggestion) string { + if len(commands) == 0 { + return response + } + + var sb strings.Builder + sb.WriteString(response) + sb.WriteString("\n\n") + sb.WriteString("Suggested commands:\n") + + for i, cmd := range commands { + sb.WriteString(fmt.Sprintf(" [%d] %s\n", i+1, cmd.Command)) + } + + return sb.String() +} + +// OpenAI streaming response structures +type OpenAIStreamChunk struct { + Choices []struct { + Delta struct { + Content string `json:"content"` + } `json:"delta"` + FinishReason string `json:"finish_reason"` + } `json:"choices"` +} + +// Claude streaming response structures +type ClaudeStreamEvent struct { + Type string `json:"type"` + Index int `json:"index,omitempty"` + Delta *struct { + Type string `json:"type"` + Text string `json:"text"` + } `json:"delta,omitempty"` +} + +// AskStream sends a question to the AI and streams the response +func (c *AIClient) AskStream(ctx context.Context, question string, history []string, onChunk func(chunk string)) (string, error) { + if c.settings == nil || !c.settings.Enable { + return "", fmt.Errorf("AI is not enabled. Use 'ai-config --enable' to enable it") + } + + if c.settings.APIKey == "" { + return "", fmt.Errorf("AI API key is not configured. Use 'ai-config --api-key ' to set it") + } + + systemPrompt := c.buildSystemPrompt(history) + + switch strings.ToLower(c.settings.Provider) { + case "claude", "anthropic": + return c.askClaudeStream(ctx, systemPrompt, question, onChunk) + default: // openai and compatible + return c.askOpenAIStream(ctx, systemPrompt, question, onChunk) + } +} + +func (c *AIClient) askOpenAIStream(ctx context.Context, systemPrompt, question string, onChunk func(chunk string)) (string, error) { + req := OpenAIChatRequest{ + Model: c.settings.Model, + Messages: []Message{{Role: "system", Content: systemPrompt}, {Role: "user", Content: question}}, + MaxTokens: c.settings.MaxTokens, + Temperature: 0.7, + Stream: true, + } + + body, err := json.Marshal(req) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint, err := c.buildEndpoint("/chat/completions") + if err != nil { + return "", err + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Authorization", "Bearer "+c.settings.APIKey) + + resp, err := c.client.Do(httpReq) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + respBody, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + } + + var fullResponse strings.Builder + scanner := bufio.NewScanner(resp.Body) + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + if data == "[DONE]" { + break + } + + var chunk OpenAIStreamChunk + if err := json.Unmarshal([]byte(data), &chunk); err != nil { + continue + } + + if len(chunk.Choices) > 0 && chunk.Choices[0].Delta.Content != "" { + content := chunk.Choices[0].Delta.Content + fullResponse.WriteString(content) + if onChunk != nil { + onChunk(content) + } + } + } + + if err := scanner.Err(); err != nil { + return fullResponse.String(), fmt.Errorf("stream read error: %w", err) + } + + return fullResponse.String(), nil +} + +func (c *AIClient) askClaudeStream(ctx context.Context, systemPrompt, question string, onChunk func(chunk string)) (string, error) { + reqBody := map[string]interface{}{ + "model": c.settings.Model, + "max_tokens": c.settings.MaxTokens, + "system": systemPrompt, + "messages": []ClaudeMessage{{Role: "user", Content: question}}, + "stream": true, + } + + body, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + endpoint, err := c.buildEndpoint("/messages") + if err != nil { + return "", err + } + + httpReq, err := http.NewRequestWithContext(ctx, "POST", endpoint, bytes.NewReader(body)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("x-api-key", c.settings.APIKey) + httpReq.Header.Set("anthropic-version", "2023-06-01") + + resp, err := c.client.Do(httpReq) + if err != nil { + return "", fmt.Errorf("request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + respBody, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("API error (%d): %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + } + + var fullResponse strings.Builder + scanner := bufio.NewScanner(resp.Body) + + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data: ") { + continue + } + + data := strings.TrimPrefix(line, "data: ") + + var event ClaudeStreamEvent + if err := json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + if event.Type == "content_block_delta" && event.Delta != nil && event.Delta.Text != "" { + fullResponse.WriteString(event.Delta.Text) + if onChunk != nil { + onChunk(event.Delta.Text) + } + } + + if event.Type == "message_stop" { + break + } + } + + if err := scanner.Err(); err != nil { + return fullResponse.String(), fmt.Errorf("stream read error: %w", err) + } + + return fullResponse.String(), nil +} + +// AICompletionEngine manages AI completions with caching and validation +type AICompletionEngine struct { + client *AIClient + cache *AICompletionCache + validator *CommandValidator +} + +// NewAICompletionEngine creates a new completion engine +func NewAICompletionEngine(client *AIClient, cache *AICompletionCache, validator *CommandValidator) *AICompletionEngine { + return &AICompletionEngine{ + client: client, + cache: cache, + validator: validator, + } +} + +// SmartComplete provides fast AI completion with caching and validation +func (e *AICompletionEngine) SmartComplete(ctx context.Context, input string, history []string, menu string) ([]string, error) { + input = strings.TrimSpace(input) + if input == "" { + return nil, nil + } + + normalizedInput := strings.TrimSpace(strings.ToLower(input)) + + // Step 1: Check cache first (instant) + if e.cache != nil { + if cached, ok := e.cache.GetScoped(menu, input); ok { + // Filter out suggestions that match input exactly + filtered := filterSameAsInput(cached, normalizedInput) + if len(filtered) > 0 { + return filtered, nil + } + } + // Try prefix match + if cached, ok := e.cache.GetPrefixScoped(menu, input); ok { + filtered := filterSameAsInput(cached, normalizedInput) + if len(filtered) > 0 { + return filtered, nil + } + } + } + + // Step 2: Call AI + if e.client == nil || e.client.settings == nil || !e.client.settings.Enable { + return nil, fmt.Errorf("AI not enabled") + } + + prompt := e.buildCompletionPrompt(input, menu) + response, err := e.client.Ask(ctx, prompt, history) + if err != nil { + return nil, err + } + + // Step 3: Parse and validate commands + suggestions := ParseCommandSuggestions(response) + validSuggestions := make([]string, 0, len(suggestions)) + seen := make(map[string]bool) + + // Prepare input prefix for stripping (handle both "cmd" and "cmd " cases) + inputLower := strings.ToLower(input) + inputTrimmed := strings.TrimSpace(inputLower) + inputWithSpace := inputLower + if !strings.HasSuffix(inputLower, " ") { + inputWithSpace = inputLower + " " + } + + // Check if user is typing a subcommand (input ends with space) + isSubcommandContext := strings.HasSuffix(input, " ") + + for _, suggestion := range suggestions { + cmd := suggestion.Command + + // Skip suggestions that are identical to the input (no point suggesting what user already typed) + if strings.TrimSpace(strings.ToLower(cmd)) == normalizedInput { + continue + } + + // Convert full command to completion by stripping input prefix + cmdLower := strings.ToLower(cmd) + completionPart := cmd + wasStripped := false + + // Case 1: Suggestion starts with user input (e.g., input="website ", suggestion="website add") + if strings.HasPrefix(cmdLower, inputWithSpace) { + completionPart = strings.TrimSpace(cmd[len(inputWithSpace):]) + wasStripped = true + } else if strings.HasPrefix(cmdLower, inputTrimmed+" ") { + completionPart = strings.TrimSpace(cmd[len(inputTrimmed)+1:]) + wasStripped = true + } else { + // Case 2: User input is in the middle (e.g., input="website ", suggestion="client website add") + // Find the input command in the suggestion and extract what follows + idx := strings.Index(cmdLower, inputTrimmed+" ") + if idx >= 0 { + completionPart = strings.TrimSpace(cmd[idx+len(inputTrimmed)+1:]) + wasStripped = true + } + } + + // Skip empty completions + if completionPart == "" { + continue + } + + // Validate and fix if validator is available + if e.validator != nil { + // Determine full command for validation + var fullCmd string + if isSubcommandContext && (wasStripped || !strings.Contains(completionPart, " ")) { + // User is typing subcommand (e.g., "website "), prepend input for validation + fullCmd = strings.TrimSpace(input) + " " + completionPart + } else { + // User is typing command prefix (e.g., "w"), validate as-is + fullCmd = completionPart + } + + fixed, valid := e.validator.ValidateAndFix(fullCmd) + if valid { + // Also skip if fixed version matches input + if strings.TrimSpace(strings.ToLower(fixed)) == normalizedInput { + continue + } + // Filter by menu: only allow commands available in the current menu + if menu != "" && !e.validator.IsCommandAllowedInMenu(menu, fixed) { + continue + } + + // Always store just the completion part for display + if !seen[completionPart] { + seen[completionPart] = true + validSuggestions = append(validSuggestions, completionPart) + } + } + } else { + // Without validator, only accept simple command-like strings + if !seen[completionPart] { + seen[completionPart] = true + validSuggestions = append(validSuggestions, completionPart) + } + } + } + + // Limit to top 10 suggestions + if len(validSuggestions) > 10 { + validSuggestions = validSuggestions[:10] + } + + // Step 4: Cache the result + if e.cache != nil && len(validSuggestions) > 0 { + e.cache.SetScoped(menu, input, validSuggestions) + } + + return validSuggestions, nil +} + +func (e *AICompletionEngine) buildCompletionPrompt(input string, menu string) string { + var sb strings.Builder + + sb.WriteString("Complete the following partial command. Return ONLY the completion part.\n\n") + sb.WriteString("RULES:\n") + sb.WriteString("1. Return ONLY the part that should be appended to complete the command\n") + sb.WriteString("2. Do NOT repeat what the user has already typed\n") + sb.WriteString("3. Return up to 10 suggestions, ONE completion per line\n") + sb.WriteString("4. Wrap EACH completion in backticks like `subcommand` or `subcommand arg`\n") + sb.WriteString("5. If input is a command with subcommands, suggest its subcommand names only\n") + sb.WriteString("6. If input looks like a typo, suggest the correct full command\n") + sb.WriteString("7. ONLY suggest commands from the AVAILABLE COMMANDS list below\n\n") + sb.WriteString("EXAMPLE:\n") + sb.WriteString("- Input: 'website ' -> suggest `add`, `list`, `remove` (NOT `website add`)\n") + sb.WriteString("- Input: 'websi' -> suggest `website` (typo correction)\n\n") + + // Add available commands if validator is present + if e.validator != nil { + // Only use commands available in the current menu + commands := e.validator.GetCommandsForMenu(menu) + if len(commands) > 0 { + if strings.TrimSpace(menu) != "" { + sb.WriteString(fmt.Sprintf("CURRENT MENU: %s (only commands below are available)\n\n", menu)) + } + sb.WriteString("AVAILABLE COMMANDS:\n") + for _, cmd := range commands { + sb.WriteString(fmt.Sprintf("- %s\n", cmd)) + } + sb.WriteString("\n") + } + } + + sb.WriteString(fmt.Sprintf("INPUT: %s\n", input)) + + return sb.String() +} + +// SetValidator updates the command validator +func (e *AICompletionEngine) SetValidator(v *CommandValidator) { + e.validator = v +} + +// SetCache updates the cache +func (e *AICompletionEngine) SetCache(c *AICompletionCache) { + e.cache = c +} + +// filterSameAsInput removes suggestions that exactly match the input +func filterSameAsInput(suggestions []string, normalizedInput string) []string { + result := make([]string, 0, len(suggestions)) + for _, s := range suggestions { + if strings.TrimSpace(strings.ToLower(s)) != normalizedInput { + result = append(result, s) + } + } + return result +} + +// PredictNextArgument predicts the next argument the user is likely to type +func (e *AICompletionEngine) PredictNextArgument(ctx context.Context, input string, history []string, menu string) (string, error) { + if strings.TrimSpace(input) == "" { + return "", nil + } + input = strings.TrimLeft(input, " \t") + + // Check cache first (use a special prefix for predictions) + cacheKey := "predict:" + input + if e.cache != nil { + if cached, ok := e.cache.GetScoped(menu, cacheKey); ok && len(cached) > 0 { + return cached[0], nil + } + } + + // Call AI for prediction + if e.client == nil || e.client.settings == nil || !e.client.settings.Enable { + return "", fmt.Errorf("AI not enabled") + } + + prompt := e.buildPredictionPrompt(input, menu) + response, err := e.client.AskPrediction(ctx, prompt, history) + if err != nil { + return "", err + } + + // Parse prediction from response + prediction := parsePrediction(response) + if prediction == "" { + return "", nil + } + + // Cache the result + if e.cache != nil { + e.cache.SetScoped(menu, cacheKey, []string{prediction}) + } + + return prediction, nil +} + +func (e *AICompletionEngine) buildPredictionPrompt(input string, menu string) string { + var sb strings.Builder + + sb.WriteString("Predict the next argument/value for the current CLI input.\n") + sb.WriteString("Return exactly ONE value wrapped in backticks (example: `--help`). If unsure, return nothing.\n") + if strings.TrimSpace(menu) != "" { + sb.WriteString(fmt.Sprintf("MENU: %s\n", menu)) + } + sb.WriteString(fmt.Sprintf("CURRENT INPUT: %s\n", input)) + sb.WriteString("NEXT ARGUMENT:") + + return sb.String() +} + +// parsePrediction extracts a single prediction from AI response +func parsePrediction(response string) string { + // Look for backtick-wrapped prediction + lines := strings.Split(response, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + // Extract content between backticks + start := strings.Index(line, "`") + if start == -1 { + continue + } + end := strings.Index(line[start+1:], "`") + if end == -1 { + continue + } + prediction := line[start+1 : start+1+end] + prediction = strings.TrimSpace(prediction) + if prediction != "" { + return prediction + } + } + + // Fallback: try to find any reasonable single-word/value + response = strings.TrimSpace(response) + if !strings.Contains(response, " ") && !strings.Contains(response, "\n") && len(response) < 50 { + return response + } + + return "" +} diff --git a/client/core/ai_cache.go b/client/core/ai_cache.go new file mode 100644 index 00000000..3330d0ff --- /dev/null +++ b/client/core/ai_cache.go @@ -0,0 +1,366 @@ +package core + +import ( + "container/list" + "strings" + "sync" + "time" +) + +// AICompletionCache provides high-speed caching for AI completions. +// It supports: +// - exact match lookups +// - prefix-based reuse of cached results +// - TTL expiration +// - LRU eviction +type AICompletionCache struct { + mu sync.Mutex + entries map[string]*list.Element + lru *list.List + maxSize int + ttl time.Duration +} + +type cacheEntry struct { + Key string + Suggestions []string + Timestamp time.Time +} + +const cacheScopeSeparator = "\x00" + +// NewAICompletionCache creates a new cache with specified size and TTL +func NewAICompletionCache(maxSize int, ttl time.Duration) *AICompletionCache { + return &AICompletionCache{ + entries: make(map[string]*list.Element), + lru: list.New(), + maxSize: maxSize, + ttl: ttl, + } +} + +// Get retrieves cached completions for the given input +func (c *AICompletionCache) Get(input string) ([]string, bool) { + key := normalizeCacheKey(input) + if key == "" { + return nil, false + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Try exact match first + el, ok := c.entries[key] + if !ok { + return nil, false + } + + entry := el.Value.(*cacheEntry) + if c.isExpired(entry) { + c.removeElement(el) + return nil, false + } + + c.lru.MoveToFront(el) + return cloneStringSlice(entry.Suggestions), true +} + +// GetScoped retrieves cached completions for the given input under a namespace. +// Scope is typically the active menu name, so caches don't mix across contexts. +func (c *AICompletionCache) GetScoped(scope string, input string) ([]string, bool) { + scopeKey := normalizeCacheKey(scope) + if scopeKey == "" { + return c.Get(input) + } + + inputKey := normalizeCacheKey(input) + if inputKey == "" { + return nil, false + } + + key := scopeKey + cacheScopeSeparator + inputKey + + c.mu.Lock() + defer c.mu.Unlock() + + el, ok := c.entries[key] + if !ok { + return nil, false + } + + entry := el.Value.(*cacheEntry) + if c.isExpired(entry) { + c.removeElement(el) + return nil, false + } + + c.lru.MoveToFront(el) + return cloneStringSlice(entry.Suggestions), true +} + +// GetPrefix retrieves cached completions matching the given prefix +func (c *AICompletionCache) GetPrefix(prefix string) ([]string, bool) { + normalizedPrefix := normalizeCacheKey(prefix) + if normalizedPrefix == "" { + return nil, false + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Find the longest matching prefix + var bestEntry *cacheEntry + var bestEl *list.Element + + for key, el := range c.entries { + if !strings.HasPrefix(normalizedPrefix, key) { + continue + } + + entry := el.Value.(*cacheEntry) + if c.isExpired(entry) { + c.removeElement(el) + continue + } + + if bestEntry == nil || len(key) > len(bestEntry.Key) { + bestEntry = entry + bestEl = el + } + } + + if bestEntry == nil { + return nil, false + } + + filtered := make([]string, 0, len(bestEntry.Suggestions)) + for _, suggestion := range bestEntry.Suggestions { + if strings.HasPrefix(strings.ToLower(suggestion), normalizedPrefix) { + filtered = append(filtered, suggestion) + } + } + if len(filtered) == 0 { + return nil, false + } + + c.lru.MoveToFront(bestEl) + return filtered, true +} + +// GetPrefixScoped retrieves cached completions matching the given prefix under a namespace. +func (c *AICompletionCache) GetPrefixScoped(scope string, prefix string) ([]string, bool) { + scopeKey := normalizeCacheKey(scope) + if scopeKey == "" { + return c.GetPrefix(prefix) + } + + normalizedPrefix := normalizeCacheKey(prefix) + if normalizedPrefix == "" { + return nil, false + } + + scopePrefix := scopeKey + cacheScopeSeparator + + c.mu.Lock() + defer c.mu.Unlock() + + var bestEntry *cacheEntry + var bestEl *list.Element + bestInputKeyLen := -1 + + for key, el := range c.entries { + if !strings.HasPrefix(key, scopePrefix) { + continue + } + + inputKey := strings.TrimPrefix(key, scopePrefix) + if inputKey == "" { + continue + } + + if !strings.HasPrefix(normalizedPrefix, inputKey) { + continue + } + + entry := el.Value.(*cacheEntry) + if c.isExpired(entry) { + c.removeElement(el) + continue + } + + if bestEntry == nil || len(inputKey) > bestInputKeyLen { + bestEntry = entry + bestEl = el + bestInputKeyLen = len(inputKey) + } + } + + if bestEntry == nil { + return nil, false + } + + filtered := make([]string, 0, len(bestEntry.Suggestions)) + for _, suggestion := range bestEntry.Suggestions { + if strings.HasPrefix(strings.ToLower(suggestion), normalizedPrefix) { + filtered = append(filtered, suggestion) + } + } + if len(filtered) == 0 { + return nil, false + } + + c.lru.MoveToFront(bestEl) + return filtered, true +} + +// Set stores completions in the cache +func (c *AICompletionCache) Set(input string, suggestions []string) { + key := normalizeCacheKey(input) + if key == "" || len(suggestions) == 0 { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + if el, ok := c.entries[key]; ok { + entry := el.Value.(*cacheEntry) + entry.Suggestions = cloneStringSlice(suggestions) + entry.Timestamp = time.Now() + c.lru.MoveToFront(el) + return + } + + entry := &cacheEntry{ + Key: key, + Suggestions: cloneStringSlice(suggestions), + Timestamp: time.Now(), + } + el := c.lru.PushFront(entry) + c.entries[key] = el + + c.evictLRU() +} + +// SetScoped stores completions in the cache under a namespace. +func (c *AICompletionCache) SetScoped(scope string, input string, suggestions []string) { + scopeKey := normalizeCacheKey(scope) + if scopeKey == "" { + c.Set(input, suggestions) + return + } + + inputKey := normalizeCacheKey(input) + if inputKey == "" || len(suggestions) == 0 { + return + } + + key := scopeKey + cacheScopeSeparator + inputKey + + c.mu.Lock() + defer c.mu.Unlock() + + if el, ok := c.entries[key]; ok { + entry := el.Value.(*cacheEntry) + entry.Suggestions = cloneStringSlice(suggestions) + entry.Timestamp = time.Now() + c.lru.MoveToFront(el) + return + } + + entry := &cacheEntry{ + Key: key, + Suggestions: cloneStringSlice(suggestions), + Timestamp: time.Now(), + } + el := c.lru.PushFront(entry) + c.entries[key] = el + + c.evictLRU() +} + +func (c *AICompletionCache) evictLRU() { + if c.maxSize <= 0 { + return + } + + for c.lru.Len() > c.maxSize { + el := c.lru.Back() + if el == nil { + return + } + c.removeElement(el) + } +} + +func (c *AICompletionCache) removeElement(el *list.Element) { + if el == nil { + return + } + entry := el.Value.(*cacheEntry) + delete(c.entries, entry.Key) + c.lru.Remove(el) +} + +func (c *AICompletionCache) isExpired(entry *cacheEntry) bool { + if entry == nil { + return true + } + if c.ttl <= 0 { + return false + } + return time.Since(entry.Timestamp) >= c.ttl +} + +// Clear removes all entries from the cache +func (c *AICompletionCache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + c.entries = make(map[string]*list.Element) + c.lru.Init() +} + +// Size returns the current number of entries +func (c *AICompletionCache) Size() int { + c.mu.Lock() + defer c.mu.Unlock() + return len(c.entries) +} + +// Cleanup removes expired entries +func (c *AICompletionCache) Cleanup() { + c.mu.Lock() + defer c.mu.Unlock() + + if c.ttl <= 0 { + return + } + + for el := c.lru.Back(); el != nil; { + prev := el.Prev() + entry := el.Value.(*cacheEntry) + if time.Since(entry.Timestamp) >= c.ttl { + c.removeElement(el) + } + el = prev + } +} + +func normalizeCacheKey(input string) string { + return strings.TrimSpace(strings.ToLower(input)) +} + +func cloneStringSlice(in []string) []string { + if len(in) == 0 { + return nil + } + out := make([]string, 0, len(in)) + for _, s := range in { + s = strings.TrimSpace(s) + if s == "" { + continue + } + out = append(out, s) + } + return out +} diff --git a/client/core/command_validator.go b/client/core/command_validator.go new file mode 100644 index 00000000..07ce3155 --- /dev/null +++ b/client/core/command_validator.go @@ -0,0 +1,378 @@ +package core + +import ( + "sort" + "strings" + "sync" + + "github.com/spf13/cobra" +) + +// CommandValidator validates AI-generated commands against registered commands +type CommandValidator struct { + mu sync.RWMutex + commandMap map[string]bool // command name -> exists + aliases map[string]string // alias -> canonical name + menuMap map[string]map[string]bool +} + +// NewCommandValidator creates a validator from a cobra root command +func NewCommandValidator(rootCmd *cobra.Command) *CommandValidator { + v := &CommandValidator{ + commandMap: make(map[string]bool), + aliases: make(map[string]string), + menuMap: make(map[string]map[string]bool), + } + if rootCmd != nil { + v.buildCommandMap(rootCmd, "") + } + return v +} + +// NewCommandValidatorWithMenu creates a validator from a cobra root command and records menu ownership. +func NewCommandValidatorWithMenu(rootCmd *cobra.Command, menu string) *CommandValidator { + v := NewCommandValidator(nil) + v.AddCommandsFromCobra(rootCmd, menu) + return v +} + +// buildCommandMap recursively builds a map of all available commands +func (v *CommandValidator) buildCommandMap(cmd *cobra.Command, prefix string) { + name := cmd.Name() + fullName := name + if prefix != "" { + fullName = prefix + " " + name + } + + // Register the command + v.commandMap[fullName] = true + v.commandMap[name] = true // Also register just the command name + + // Register aliases + for _, alias := range cmd.Aliases { + v.aliases[alias] = name + if prefix != "" { + v.aliases[prefix+" "+alias] = fullName + } + } + + // Process subcommands + for _, subCmd := range cmd.Commands() { + if !subCmd.Hidden { + v.buildCommandMap(subCmd, fullName) + } + } +} + +// AddCommandsFromCobra adds commands from a cobra root command and records menu ownership. +// Menu ownership is tracked for the root command and its direct subcommands (non-hidden). +func (v *CommandValidator) AddCommandsFromCobra(rootCmd *cobra.Command, menu string) { + if rootCmd == nil { + return + } + + v.mu.Lock() + defer v.mu.Unlock() + + v.buildCommandMap(rootCmd, "") + + if menu == "" { + return + } + + if v.menuMap[menu] == nil { + v.menuMap[menu] = make(map[string]bool) + } + + v.menuMap[menu][rootCmd.Name()] = true + for _, sub := range rootCmd.Commands() { + if sub == nil || sub.Hidden { + continue + } + v.menuMap[menu][sub.Name()] = true + } +} + +// AddCommandWithMenu manually adds a command to the validator and assigns it to a menu. +func (v *CommandValidator) AddCommandWithMenu(menu string, name string, aliases ...string) { + v.mu.Lock() + defer v.mu.Unlock() + + v.commandMap[name] = true + for _, alias := range aliases { + v.aliases[alias] = name + } + + if menu == "" { + return + } + if v.menuMap[menu] == nil { + v.menuMap[menu] = make(map[string]bool) + } + v.menuMap[menu][name] = true +} + +// GetCommandsForMenu returns commands registered for a specific menu. +// If menu is empty, it returns the union of commands across all menus. +func (v *CommandValidator) GetCommandsForMenu(menu string) []string { + v.mu.RLock() + defer v.mu.RUnlock() + + collect := make(map[string]bool) + + if menu == "" { + for _, cmds := range v.menuMap { + for cmd := range cmds { + collect[cmd] = true + } + } + } else if cmds, ok := v.menuMap[menu]; ok { + for cmd := range cmds { + collect[cmd] = true + } + } else { + // Unknown menu: fall back to all single-word commands. + for cmd := range v.commandMap { + if !strings.Contains(cmd, " ") { + collect[cmd] = true + } + } + } + + out := make([]string, 0, len(collect)) + for cmd := range collect { + out = append(out, cmd) + } + sort.Strings(out) + return out +} + +// IsCommandAllowedInMenu checks whether a command line belongs to a given menu. +// If the menu is unknown or empty, it allows all commands. +func (v *CommandValidator) IsCommandAllowedInMenu(menu string, commandLine string) bool { + menu = strings.TrimSpace(menu) + if menu == "" { + return true + } + + commandLine = strings.TrimSpace(commandLine) + if commandLine == "" { + return false + } + + parts := strings.Fields(commandLine) + if len(parts) == 0 { + return false + } + cmdName := parts[0] + + v.mu.RLock() + defer v.mu.RUnlock() + + cmds, ok := v.menuMap[menu] + if !ok || len(cmds) == 0 { + return true + } + + if cmds[cmdName] { + return true + } + + if canonical, ok := v.aliases[cmdName]; ok && cmds[canonical] { + return true + } + + return false +} + +// Validate checks if a command is valid +func (v *CommandValidator) Validate(command string) bool { + v.mu.RLock() + defer v.mu.RUnlock() + + command = strings.TrimSpace(command) + if command == "" { + return false + } + + // Extract the first word (command name) + parts := strings.Fields(command) + if len(parts) == 0 { + return false + } + + cmdName := parts[0] + + // Check exact match + if v.commandMap[cmdName] { + return true + } + + // Check alias + if _, ok := v.aliases[cmdName]; ok { + return true + } + + // Check full command with subcommand + if len(parts) >= 2 { + fullCmd := parts[0] + " " + parts[1] + if v.commandMap[fullCmd] { + return true + } + if _, ok := v.aliases[fullCmd]; ok { + return true + } + } + + return false +} + +// ValidateAndFix attempts to fix an invalid command +func (v *CommandValidator) ValidateAndFix(command string) (string, bool) { + command = strings.TrimSpace(command) + if command == "" { + return "", false + } + + // If already valid, return as-is + if v.Validate(command) { + return command, true + } + + parts := strings.Fields(command) + if len(parts) == 0 { + return "", false + } + + // Try to fix the first word + fixed := v.findSimilar(parts[0]) + if fixed != "" { + parts[0] = fixed + fixedCmd := strings.Join(parts, " ") + if v.Validate(fixedCmd) { + return fixedCmd, true + } + } + + // Try alias resolution + if canonical, ok := v.aliases[parts[0]]; ok { + parts[0] = canonical + return strings.Join(parts, " "), true + } + + return command, false +} + +// findSimilar finds a similar command using Levenshtein distance +func (v *CommandValidator) findSimilar(input string) string { + v.mu.RLock() + defer v.mu.RUnlock() + + input = strings.ToLower(input) + minDist := 3 // Max distance threshold + var similar string + + for cmd := range v.commandMap { + // Only compare single-word commands + if strings.Contains(cmd, " ") { + continue + } + dist := levenshteinDistance(input, strings.ToLower(cmd)) + if dist < minDist { + minDist = dist + similar = cmd + } + } + + return similar +} + +// GetAllCommands returns all registered command names +func (v *CommandValidator) GetAllCommands() []string { + v.mu.RLock() + defer v.mu.RUnlock() + + commands := make([]string, 0, len(v.commandMap)) + for cmd := range v.commandMap { + // Only return single-word commands to avoid duplication + if !strings.Contains(cmd, " ") { + commands = append(commands, cmd) + } + } + sort.Strings(commands) + return commands +} + +// GetAllCommandsWithSubcommands returns all registered commands including subcommands +func (v *CommandValidator) GetAllCommandsWithSubcommands() []string { + v.mu.RLock() + defer v.mu.RUnlock() + + commands := make([]string, 0, len(v.commandMap)) + for cmd := range v.commandMap { + commands = append(commands, cmd) + } + sort.Strings(commands) + return commands +} + +// AddCommand manually adds a command to the validator +func (v *CommandValidator) AddCommand(name string, aliases ...string) { + v.mu.Lock() + defer v.mu.Unlock() + + v.commandMap[name] = true + for _, alias := range aliases { + v.aliases[alias] = name + } +} + +// levenshteinDistance calculates the edit distance between two strings +func levenshteinDistance(s1, s2 string) int { + if len(s1) == 0 { + return len(s2) + } + if len(s2) == 0 { + return len(s1) + } + + // Create matrix + matrix := make([][]int, len(s1)+1) + for i := range matrix { + matrix[i] = make([]int, len(s2)+1) + matrix[i][0] = i + } + for j := range matrix[0] { + matrix[0][j] = j + } + + // Fill matrix + for i := 1; i <= len(s1); i++ { + for j := 1; j <= len(s2); j++ { + cost := 1 + if s1[i-1] == s2[j-1] { + cost = 0 + } + matrix[i][j] = min( + matrix[i-1][j]+1, // deletion + matrix[i][j-1]+1, // insertion + matrix[i-1][j-1]+cost, // substitution + ) + } + } + + return matrix[len(s1)][len(s2)] +} + +func min(a, b, c int) int { + if a < b { + if a < c { + return a + } + return c + } + if b < c { + return b + } + return c +} diff --git a/client/core/command_validator_test.go b/client/core/command_validator_test.go new file mode 100644 index 00000000..f54231f0 --- /dev/null +++ b/client/core/command_validator_test.go @@ -0,0 +1,126 @@ +package core + +import ( + "sort" + "testing" + + "github.com/spf13/cobra" +) + +func TestGetCommandsForMenu(t *testing.T) { + // Create mock client commands + clientRoot := &cobra.Command{Use: "client"} + clientRoot.AddCommand( + &cobra.Command{Use: "wizard", Short: "wizard command"}, + &cobra.Command{Use: "website", Short: "website command"}, + &cobra.Command{Use: "listener", Short: "listener command"}, + ) + + // Create mock implant commands + implantRoot := &cobra.Command{Use: "implant"} + implantRoot.AddCommand( + &cobra.Command{Use: "whoami", Short: "whoami command"}, + &cobra.Command{Use: "wmi_query", Short: "wmi query command"}, + &cobra.Command{Use: "ps", Short: "process list"}, + ) + + // Create validator with client menu + v := NewCommandValidatorWithMenu(clientRoot, "client") + // Add implant commands + v.AddCommandsFromCobra(implantRoot, "implant") + + // Test: Get client commands + clientCmds := v.GetCommandsForMenu("client") + sort.Strings(clientCmds) + + expectedClient := []string{"client", "listener", "website", "wizard"} + sort.Strings(expectedClient) + + if len(clientCmds) != len(expectedClient) { + t.Errorf("Expected %d client commands, got %d: %v", len(expectedClient), len(clientCmds), clientCmds) + } + + for i, cmd := range expectedClient { + if clientCmds[i] != cmd { + t.Errorf("Expected client command %q, got %q", cmd, clientCmds[i]) + } + } + + // Test: Get implant commands + implantCmds := v.GetCommandsForMenu("implant") + sort.Strings(implantCmds) + + expectedImplant := []string{"implant", "ps", "whoami", "wmi_query"} + sort.Strings(expectedImplant) + + if len(implantCmds) != len(expectedImplant) { + t.Errorf("Expected %d implant commands, got %d: %v", len(expectedImplant), len(implantCmds), implantCmds) + } + + for i, cmd := range expectedImplant { + if implantCmds[i] != cmd { + t.Errorf("Expected implant command %q, got %q", cmd, implantCmds[i]) + } + } + + // Test: Verify whoami is NOT in client menu + for _, cmd := range clientCmds { + if cmd == "whoami" { + t.Error("whoami should NOT be in client menu commands") + } + } + + // Test: Verify wizard is NOT in implant menu + for _, cmd := range implantCmds { + if cmd == "wizard" { + t.Error("wizard should NOT be in implant menu commands") + } + } + + // Test: Empty menu returns all commands + allCmds := v.GetCommandsForMenu("") + if len(allCmds) != len(expectedClient)+len(expectedImplant) { + t.Errorf("Expected %d total commands, got %d", len(expectedClient)+len(expectedImplant), len(allCmds)) + } +} + +func TestAddCommandWithMenu(t *testing.T) { + v := NewCommandValidator(nil) + + // Add commands with menu context + v.AddCommandWithMenu("client", "sessions", "ss") + v.AddCommandWithMenu("implant", "shell", "sh") + + // Verify client command + clientCmds := v.GetCommandsForMenu("client") + found := false + for _, cmd := range clientCmds { + if cmd == "sessions" { + found = true + break + } + } + if !found { + t.Error("sessions should be in client menu") + } + + // Verify implant command + implantCmds := v.GetCommandsForMenu("implant") + found = false + for _, cmd := range implantCmds { + if cmd == "shell" { + found = true + break + } + } + if !found { + t.Error("shell should be in implant menu") + } + + // Verify shell is NOT in client menu + for _, cmd := range clientCmds { + if cmd == "shell" { + t.Error("shell should NOT be in client menu") + } + } +} diff --git a/client/core/console.go b/client/core/console.go index 491a3aa0..3bd932d7 100644 --- a/client/core/console.go +++ b/client/core/console.go @@ -71,6 +71,11 @@ type Console struct { Helpers map[string]*cobra.Command MalManager *plugin.MalManager + + // AI Completion Engine + aiCompletionEngine *AICompletionEngine + aiCache *AICompletionCache + commandValidator *CommandValidator } func (c *Console) NewConsole() { @@ -89,6 +94,22 @@ func (c *Console) NewConsole() { implant.Prompt().Primary = c.GetPrompt implant.AddInterrupt(io.EOF, repl.ExitImplantMenu) // Ctrl-D implant.AddHistorySourceFile("history", filepath.Join(assets.GetRootAppDir(), "implant_history")) + + // Register AI prediction for next argument (double-tap Tab to accept) + iom.Shell().AIPredictNext = c.handleAIPredictNext + + // Register line hook to handle '?' prefix without space (e.g., '?hello' -> '?' 'hello') + iom.PreCmdRunLineHooks = append(iom.PreCmdRunLineHooks, func(args []string) ([]string, error) { + if len(args) > 0 && len(args[0]) > 1 && strings.HasPrefix(args[0], "?") { + // Split '?xxx' into '?' and 'xxx' + question := args[0][1:] + newArgs := make([]string, 0, len(args)+1) + newArgs = append(newArgs, "?", question) + newArgs = append(newArgs, args[1:]...) + return newArgs, nil + } + return args, nil + }) } func (c *Console) Start(bindCmds ...BindCmds) error { @@ -106,7 +127,10 @@ func (c *Console) Start(bindCmds ...BindCmds) error { c.App.Menu(consts.ClientMenu).Command = bindCmds[0](c)() c.App.Menu(consts.ImplantMenu).Command = bindCmds[1](c)() - // 所有命令注册完成后,安全地启动MCP服务器和Local RPC服务器 + // Initialize AI completion components after commands are registered + c.initAICompletion() + + // After all commands are registered, safely start MCP server and Local RPC server if c.Server != nil { c.InitMCPServer() c.InitLocalRPCServer() @@ -273,3 +297,107 @@ func (c *Console) AddCommandFuncHelper(cmdName string, funcName string, example }) } } + +func (c *Console) GetRecentHistory(limit int) []string { + if limit <= 0 || c == nil || c.App == nil { + return nil + } + + shell := c.App.Shell() + if shell == nil || shell.History == nil || shell.History.Current() == nil { + return nil + } + + hist := shell.History.Current() + count := hist.Len() + start := count - limit + if start < 0 { + start = 0 + } + + capacity := limit + if count-start < capacity { + capacity = count - start + } + history := make([]string, 0, capacity) + for i := start; i < count; i++ { + if line, err := hist.GetLine(i); err == nil && line != "" { + history = append(history, line) + } + } + + if len(history) > limit { + history = history[len(history)-limit:] + } + + return history +} + +func getValidAISettings() (*assets.AISettings, error) { + settings, err := assets.GetSetting() + if err != nil { + return nil, fmt.Errorf("failed to load settings: %w", err) + } + if settings == nil || settings.AI == nil || !settings.AI.Enable { + return nil, fmt.Errorf("AI not enabled. Use 'ai-config --enable --api-key ' to enable it") + } + if settings.AI.APIKey == "" { + return nil, fmt.Errorf("AI API key not configured. Use 'ai-config --api-key ' to set it") + } + + return settings.AI, nil +} + +// initAICompletion initializes the AI completion engine +func (c *Console) initAICompletion() { + settings, err := assets.GetSetting() + if err != nil || settings == nil || settings.AI == nil || !settings.AI.Enable { + return + } + + // Initialize cache (500 entries, 30 minute TTL) + c.aiCache = NewAICompletionCache(500, 30*time.Minute) + + // Initialize command validator from client menu + clientMenu := c.App.Menu(consts.ClientMenu).Command + if clientMenu != nil { + c.commandValidator = NewCommandValidatorWithMenu(clientMenu, consts.ClientMenu) + } + + // Also add implant menu commands to validator + implantMenu := c.App.Menu(consts.ImplantMenu).Command + if implantMenu != nil && c.commandValidator != nil { + c.commandValidator.AddCommandsFromCobra(implantMenu, consts.ImplantMenu) + } + + // Initialize AI client + aiClient := NewAIClient(settings.AI) + + // Create completion engine + c.aiCompletionEngine = NewAICompletionEngine(aiClient, c.aiCache, c.commandValidator) +} + +// handleAIPredictNext handles AI prediction for the next argument (double-tap Tab to accept) +func (c *Console) handleAIPredictNext(line string, history []string) (string, error) { + // Lazy initialize if not done yet + if c.aiCompletionEngine == nil { + c.initAICompletion() + } + + if c.aiCompletionEngine == nil { + return "", fmt.Errorf("AI prediction not available") + } + + // Use 2 second timeout for fast prediction + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + activeMenu := "" + if c != nil && c.App != nil { + if m := c.App.ActiveMenu(); m != nil { + activeMenu = m.Name() + } + } + + return c.aiCompletionEngine.PredictNextArgument(ctx, line, history, activeMenu) +} diff --git a/client/core/localrpc_test.go b/client/core/localrpc_test.go index 72d1508a..a6ba5ec0 100644 --- a/client/core/localrpc_test.go +++ b/client/core/localrpc_test.go @@ -3,7 +3,9 @@ package core import ( "context" "encoding/json" + "net" "testing" + "time" "github.com/chainreactors/IoM-go/proto/services/localrpc" "github.com/chainreactors/malice-network/client/plugin" @@ -18,9 +20,26 @@ const ( // setupRPCClient creates a gRPC client connection to the test RPC server func setupRPCClient(t *testing.T) (localrpc.CommandServiceClient, *grpc.ClientConn) { - conn, err := grpc.Dial(testRPCAddr, grpc.WithTransportCredentials(insecure.NewCredentials())) + t.Helper() + + // These are integration tests; skip when no local RPC server is running. + if c, err := net.DialTimeout("tcp", testRPCAddr, 250*time.Millisecond); err != nil { + t.Skipf("Skipping: local RPC server not reachable at %s: %v", testRPCAddr, err) + } else { + _ = c.Close() + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + conn, err := grpc.DialContext( + ctx, + testRPCAddr, + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithBlock(), + ) if err != nil { - t.Fatalf("Failed to connect to RPC server at %s: %v", testRPCAddr, err) + t.Skipf("Skipping: failed to connect to RPC server at %s: %v", testRPCAddr, err) } client := localrpc.NewCommandServiceClient(conn) diff --git a/client/plugin/manager.go b/client/plugin/manager.go index 68fc8679..78440e97 100644 --- a/client/plugin/manager.go +++ b/client/plugin/manager.go @@ -82,6 +82,11 @@ func (mm *MalManager) loadEmbeddedMals() { continue } + // Auto-register wizard templates from embedded resources before running plugin init scripts. + if n := registerWizardTemplatesFromEmbedFS(embedPlugin.Name, embedPlugin.RootPath, embedPlugin.FS); n > 0 { + logs.Log.Debugf("Registered %d wizard templates from embedded plugin %s\n", n, embedPlugin.Name) + } + // 运行插件 if err := embedPlugin.Run(); err != nil { logs.Log.Errorf("Failed to run embedded plugin %s: %v\n", levelName, err) @@ -105,6 +110,11 @@ func (mm *MalManager) loadEmbeddedMals() { func (mm *MalManager) loadExternalMals() { mm.globalPlugins = LoadGlobalLuaPlugin() + for _, plug := range mm.globalPlugins { + if n := registerWizardTemplatesFromDisk(plug.Name, plug.Path); n > 0 { + logs.Log.Debugf("Registered %d wizard templates from global plugin %s\n", n, plug.Name) + } + } for _, manifest := range GetPluginManifest() { _, err := mm.LoadExternalMal(manifest) @@ -138,6 +148,11 @@ func (mm *MalManager) LoadExternalMal(manifest *MalManiFest) (Plugin, error) { return nil, err } + // Auto-register wizard templates from plugin resources before running plugin init scripts. + if n := registerWizardTemplatesFromDisk(manifest.Name, filepath.Join(assets.GetMalsDir(), manifest.Name)); n > 0 { + logs.Log.Debugf("Registered %d wizard templates from external plugin %s\n", n, manifest.Name) + } + err = plugin.Run() if err != nil { return nil, err @@ -249,6 +264,10 @@ func (mm *MalManager) ReloadExternalMal(malName string) error { return fmt.Errorf("failed to create new plugin %s: %v", malName, err) } + if n := registerWizardTemplatesFromDisk(manifest.Name, filepath.Join(assets.GetMalsDir(), manifest.Name)); n > 0 { + logs.Log.Debugf("Registered %d wizard templates from reloaded plugin %s\n", n, manifest.Name) + } + err = newPlugin.Run() if err != nil { return fmt.Errorf("failed to run new plugin %s: %v", malName, err) diff --git a/client/plugin/vm.go b/client/plugin/vm.go index e5ff26cc..15ea72ab 100644 --- a/client/plugin/vm.go +++ b/client/plugin/vm.go @@ -2,14 +2,16 @@ package plugin import ( "fmt" + "strings" + "sync" + "time" + "github.com/chainreactors/logs" + "github.com/chainreactors/malice-network/client/wizard" "github.com/chainreactors/malice-network/helper/intermediate" "github.com/chainreactors/mals" lua "github.com/yuin/gopher-lua" "github.com/yuin/gopher-lua/parse" - "strings" - "sync" - "time" ) func NewLuaVM() *lua.LState { @@ -24,9 +26,20 @@ func NewLuaVM() *lua.LState { } // 注册所有内置函数 + // Ensure wizard functions are part of builtin definitions/help generation. + wizard.RegisterBuiltinFunctions() for name, fun := range intermediate.InternalFunctions.Package(intermediate.BuiltinPackage) { vm.SetGlobal(name, vm.NewFunction(mals.WrapFuncForLua(fun))) } + + // Setup wizard metatable and functions + wizard.SetupMetatable(vm) + wizardFns := make(map[string]lua.LGFunction) + wizard.RegisterLuaFunctions(wizardFns) + for name, fn := range wizardFns { + vm.SetGlobal(name, vm.NewFunction(fn)) + } + return vm } diff --git a/client/plugin/wizard_templates.go b/client/plugin/wizard_templates.go new file mode 100644 index 00000000..e355ab5c --- /dev/null +++ b/client/plugin/wizard_templates.go @@ -0,0 +1,105 @@ +package plugin + +import ( + "io/fs" + "os" + "path" + "path/filepath" + "strings" + + "github.com/chainreactors/logs" + "github.com/chainreactors/malice-network/client/wizard" +) + +const wizardSpecDir = "resources/wizards" + +// specWalker abstracts the file walking logic for different file systems +type specWalker struct { + pluginName string + registered int +} + +func (w *specWalker) processSpec(specPath, rel string, loadFn func(string) (*wizard.WizardSpec, error)) { + relNoExt := strings.TrimSuffix(rel, path.Ext(rel)) + if !isWizardSpecPath(rel) { + return + } + + spec, err := loadFn(specPath) + if err != nil { + logs.Log.Warnf("Failed to load wizard spec %s: %v\n", specPath, err) + return + } + + templateName := makeWizardTemplateName(w.pluginName, spec, relNoExt) + if err := wizard.RegisterTemplateFromSpec(templateName, spec); err != nil { + logs.Log.Warnf("Failed to register wizard template %s from %s: %v\n", templateName, specPath, err) + return + } + w.registered++ +} + +func registerWizardTemplatesFromEmbedFS(pluginName, pluginRoot string, f fs.FS) int { + root := path.Join(pluginRoot, wizardSpecDir) + if _, err := fs.Stat(f, root); err != nil { + return 0 + } + + w := &specWalker{pluginName: pluginName} + _ = fs.WalkDir(f, root, func(p string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() { + return nil + } + rel := strings.TrimPrefix(p, root+"/") + w.processSpec("embed://"+p, rel, wizard.LoadSpec) + return nil + }) + return w.registered +} + +func registerWizardTemplatesFromDisk(pluginName, pluginPath string) int { + root := filepath.Join(pluginPath, wizardSpecDir) + if info, err := os.Stat(root); err != nil || info == nil || !info.IsDir() { + return 0 + } + + w := &specWalker{pluginName: pluginName} + _ = filepath.WalkDir(root, func(p string, d fs.DirEntry, err error) error { + if err != nil || d.IsDir() { + return nil + } + rel, err := filepath.Rel(root, p) + if err != nil { + return nil + } + w.processSpec(p, filepath.ToSlash(rel), wizard.LoadSpec) + return nil + }) + return w.registered +} + +func isWizardSpecPath(p string) bool { + switch strings.ToLower(path.Ext(p)) { + case ".yaml", ".yml", ".json": + return true + default: + return false + } +} + +func makeWizardTemplateName(pluginName string, spec *wizard.WizardSpec, fallback string) string { + base := strings.TrimSpace(fallback) + if spec != nil && strings.TrimSpace(spec.ID) != "" { + base = strings.TrimSpace(spec.ID) + } + if base == "" { + base = "wizard" + } + base = strings.TrimPrefix(strings.TrimPrefix(base, "/"), "./") + + prefix := strings.TrimSpace(pluginName) + if prefix == "" || strings.HasPrefix(base, prefix+":") { + return base + } + return prefix + ":" + base +} diff --git a/client/plugin/wizard_templates_test.go b/client/plugin/wizard_templates_test.go new file mode 100644 index 00000000..8bef6836 --- /dev/null +++ b/client/plugin/wizard_templates_test.go @@ -0,0 +1,45 @@ +package plugin + +import ( + "os" + "path/filepath" + "testing" + + wizardfw "github.com/chainreactors/malice-network/client/wizard" +) + +func TestRegisterWizardTemplatesFromDisk(t *testing.T) { + tmp := t.TempDir() + wizDir := filepath.Join(tmp, wizardSpecDir) + if err := os.MkdirAll(wizDir, 0o755); err != nil { + t.Fatalf("mkdir wizards dir: %v", err) + } + + specPath := filepath.Join(wizDir, "priv_esc.yaml") + if err := os.WriteFile(specPath, []byte(` +title: Privilege Escalation +fields: + - name: method + title: Method + type: select + options: [uac, token] +`), 0o644); err != nil { + t.Fatalf("write spec: %v", err) + } + + pluginName := "testplug" + if n := registerWizardTemplatesFromDisk(pluginName, tmp); n != 1 { + t.Fatalf("expected 1 registered template, got %d", n) + } + + wiz, ok := wizardfw.GetTemplate("testplug:priv_esc") + if !ok || wiz == nil { + t.Fatalf("expected registered template testplug:priv_esc") + } + if wiz.Title == "" { + t.Fatalf("expected non-empty title") + } + if len(wiz.Fields) != 1 { + t.Fatalf("expected 1 field, got %d", len(wiz.Fields)) + } +} diff --git a/client/wizard/cobra.go b/client/wizard/cobra.go new file mode 100644 index 00000000..7a1388ed --- /dev/null +++ b/client/wizard/cobra.go @@ -0,0 +1,514 @@ +package wizard + +import ( + "bytes" + "encoding/csv" + "fmt" + "sort" + "strconv" + "strings" + + "github.com/spf13/cobra" + "github.com/spf13/pflag" +) + +// skipFlags defines flags that should not appear in wizard forms +var skipFlags = map[string]bool{ + "help": true, + "wizard": true, + "version": true, +} + +// CobraToWizard converts a cobra.Command's flags to a Wizard +// It supports grouping via ui:group annotations and edit mode (reading current flag values) +func CobraToWizard(cmd *cobra.Command) *Wizard { + wiz := NewWizard(wizardIDFromCommand(cmd), cmd.Short) + if cmd.Long != "" { + wiz.WithDescription(cmd.Long) + } + + // Collect all flags and group them + groups := make(map[string][]*pflag.Flag) + var ungroupedFlags []*pflag.Flag + groupOrder := make([]string, 0) + groupOrderSet := make(map[string]bool) + + cmd.Flags().VisitAll(func(flag *pflag.Flag) { + if shouldSkipFlag(flag) { + return + } + + groupName := getFlagGroup(flag) + if groupName != "" { + groups[groupName] = append(groups[groupName], flag) + if !groupOrderSet[groupName] { + groupOrder = append(groupOrder, groupName) + groupOrderSet[groupName] = true + } + } else { + ungroupedFlags = append(ungroupedFlags, flag) + } + }) + + // Sort groups by order annotation if available + sort.SliceStable(groupOrder, func(i, j int) bool { + orderI := getGroupOrder(groupOrder[i], groups[groupOrder[i]]) + orderJ := getGroupOrder(groupOrder[j], groups[groupOrder[j]]) + return orderI < orderJ + }) + + // If there are groups, use grouped mode + if len(groups) > 0 { + usedGroupIDs := make(map[string]bool) + + // Add ungrouped fields as "Basic" group first + if len(ungroupedFlags) > 0 { + sortFlagsByOrder(ungroupedFlags) + basicGroup := wiz.NewGroup(uniqueGroupID(usedGroupIDs, "general"), "General") + for _, flag := range ungroupedFlags { + field := flagToWizardField(flag) + basicGroup.AddField(field) + } + } + + // Add other groups in order + for _, groupName := range groupOrder { + flags := groups[groupName] + group := wiz.NewGroup(uniqueGroupID(usedGroupIDs, sanitizeGroupName(groupName)), groupName) + + // Sort flags within group by order annotation + sortFlagsByOrder(flags) + + for _, flag := range flags { + field := flagToWizardField(flag) + group.AddField(field) + } + } + } else { + // Flat mode: all fields without grouping + sortFlagsByOrder(ungroupedFlags) + for _, flag := range ungroupedFlags { + field := flagToWizardField(flag) + wiz.AddField(field) + } + } + + return wiz +} + +// shouldSkipFlag determines if a flag should be excluded from wizard +func shouldSkipFlag(flag *pflag.Flag) bool { + if skipFlags[flag.Name] { + return true + } + // Skip hidden flags + if flag.Hidden { + return true + } + return false +} + +// getFlagGroup extracts the group name from flag annotations +func getFlagGroup(flag *pflag.Flag) string { + if flag.Annotations == nil { + return "" + } + // Check ui:group annotation + if groups, ok := flag.Annotations["ui:group"]; ok && len(groups) > 0 { + return groups[0] + } + // Check group annotation (alternative) + if groups, ok := flag.Annotations["group"]; ok && len(groups) > 0 { + return groups[0] + } + return "" +} + +// sanitizeGroupName converts a display name to a valid identifier +func sanitizeGroupName(name string) string { + // Replace spaces and special chars with underscores + result := strings.ToLower(name) + result = strings.ReplaceAll(result, " ", "_") + result = strings.ReplaceAll(result, "-", "_") + return result +} + +func uniqueGroupID(used map[string]bool, base string) string { + idBase := sanitizeGroupName(strings.TrimSpace(base)) + if idBase == "" { + idBase = "group" + } + id := idBase + if !used[id] { + used[id] = true + return id + } + for i := 2; ; i++ { + id = fmt.Sprintf("%s_%d", idBase, i) + if !used[id] { + used[id] = true + return id + } + } +} + +func wizardIDFromCommand(cmd *cobra.Command) string { + if cmd == nil { + return "" + } + path := strings.TrimSpace(cmd.CommandPath()) + if path == "" { + path = cmd.Name() + } + // Keep it stable + filesystem-ish. + id := strings.ToLower(path) + id = strings.ReplaceAll(id, " ", "_") + id = strings.ReplaceAll(id, "-", "_") + id = strings.ReplaceAll(id, "/", "_") + return id +} + +// getGroupOrder returns the order value for a group (based on first flag's order) +func getGroupOrder(groupName string, flags []*pflag.Flag) int { + minOrder := 9999 + for _, flag := range flags { + if order := getFlagOrder(flag); order < minOrder { + minOrder = order + } + } + return minOrder +} + +// getFlagOrder gets the order value from flag annotations +func getFlagOrder(flag *pflag.Flag) int { + if flag.Annotations == nil { + return 9999 + } + if orders, ok := flag.Annotations["ui:order"]; ok && len(orders) > 0 { + if order, err := strconv.Atoi(orders[0]); err == nil { + return order + } + } + return 9999 +} + +// sortFlagsByOrder sorts flags by their ui:order annotation +func sortFlagsByOrder(flags []*pflag.Flag) { + sort.SliceStable(flags, func(i, j int) bool { + return getFlagOrder(flags[i]) < getFlagOrder(flags[j]) + }) +} + +// flagToWizardField converts a single pflag.Flag to WizardField +func flagToWizardField(flag *pflag.Flag) *WizardField { + field := &WizardField{ + Name: flag.Name, + Title: flag.Name, + Description: flag.Usage, + Required: isFlagRequired(flag), + } + + // Get current value for edit mode support + currentValue := flag.Value.String() + + // Determine field type and default value based on flag type + switch flag.Value.Type() { + case "bool": + field.Type = FieldConfirm + field.Default = currentValue == "true" + + case "int", "int8", "int16", "int32", "int64": + field.Type = FieldNumber + if val, err := strconv.ParseInt(currentValue, 10, 64); err == nil { + field.Default = int(val) + } else { + field.Default = 0 + } + + case "uint", "uint8", "uint16", "uint32", "uint64": + field.Type = FieldNumber + if val, err := strconv.ParseUint(currentValue, 10, 64); err == nil { + field.Default = int(val) + } else { + field.Default = 0 + } + + case "float32", "float64": + // Wizard "number" currently supports ints only; treat floats as string input with float validation. + field.Type = FieldInput + field.Default = currentValue + field.Validate = floatValidatorFromFlag(flag) + + default: + // pflag slice types stringify as "[...]" and Set() may append when already-changed. + // Treat slices as comma-separated string input and apply via SliceValue.Replace. + if sv, ok := flag.Value.(pflag.SliceValue); ok { + field.Type = FieldInput + field.Default = formatCSV(sv.GetSlice()) + field.Description = flag.Usage + " (comma-separated values)" + field.Validate = csvListValidator() + break + } + + field.Type = FieldInput + field.Default = currentValue + + // Check if should use textarea + if widget := getFlagWidget(flag); widget == "textarea" { + field.Type = FieldText + } + } + + // Check for enum options (converts to Select) + if options := getFlagOptions(flag); len(options) > 0 { + field.Type = FieldSelect + field.Options = options + // Auto-select first non-empty option if current value is empty + if currentValue == "" || currentValue == "(empty)" { + for _, opt := range options { + if opt != "" && opt != "(empty)" { + field.Default = opt + break + } + } + } else { + field.Default = currentValue + } + } + + return field +} + +// isFlagRequired determines if a flag is required +func isFlagRequired(flag *pflag.Flag) bool { + if flag.Annotations == nil { + return false + } + + // Check ui:required annotation + if required, ok := flag.Annotations["ui:required"]; ok && len(required) > 0 { + return required[0] == "true" + } + + // Check cobra's required flag annotation (set by MarkFlagRequired) + if required, ok := flag.Annotations["cobra_annotation_bash_completion_one_required_flag"]; ok { + return len(required) > 0 + } + + return false +} + +// getFlagWidget gets the widget type from flag annotations +func getFlagWidget(flag *pflag.Flag) string { + if flag.Annotations == nil { + return "" + } + if widget, ok := flag.Annotations["ui:widget"]; ok && len(widget) > 0 { + return widget[0] + } + return "" +} + +// getFlagOptions gets the options list from flag annotations +func getFlagOptions(flag *pflag.Flag) []string { + if flag.Annotations == nil { + return nil + } + if options, ok := flag.Annotations["ui:options"]; ok && len(options) > 0 { + return options + } + return nil +} + +func csvListValidator() func(string) error { + return func(s string) error { + _, err := parseCSVList(s) + if err != nil { + return err + } + return nil + } +} + +func floatValidatorFromFlag(flag *pflag.Flag) func(string) error { + min, max, ok := getFlagFloatRange(flag) + if ok { + return ValidateFloat(min, max) + } + return func(val string) error { + if strings.TrimSpace(val) == "" { + return nil + } + if _, err := strconv.ParseFloat(val, 64); err != nil { + return fmt.Errorf("invalid number: %s", val) + } + return nil + } +} + +func getFlagFloatRange(flag *pflag.Flag) (min, max float64, ok bool) { + if flag == nil || flag.Annotations == nil { + return 0, 0, false + } + mins, okMin := flag.Annotations["ui:min"] + maxs, okMax := flag.Annotations["ui:max"] + if !okMin || !okMax || len(mins) == 0 || len(maxs) == 0 { + return 0, 0, false + } + min, err1 := strconv.ParseFloat(strings.TrimSpace(mins[0]), 64) + max, err2 := strconv.ParseFloat(strings.TrimSpace(maxs[0]), 64) + if err1 != nil || err2 != nil { + return 0, 0, false + } + return min, max, true +} + +func formatCSV(vals []string) string { + b := &bytes.Buffer{} + w := csv.NewWriter(b) + _ = w.Write(vals) + w.Flush() + return strings.TrimSuffix(b.String(), "\n") +} + +func parseCSVList(s string) ([]string, error) { + s = strings.TrimSpace(s) + if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") && len(s) >= 2 { + s = strings.TrimSpace(s[1 : len(s)-1]) + } + if s == "" { + return []string{}, nil + } + r := csv.NewReader(strings.NewReader(s)) + r.FieldsPerRecord = -1 + records, err := r.Read() + if err != nil { + return nil, err + } + return records, nil +} + +// ApplyWizardResultToFlags applies wizard results back to cobra.Command flags +func ApplyWizardResultToFlags(cmd *cobra.Command, result *WizardResult) error { + for name, value := range result.ToMap() { + // Look up flag in command flags + flag := cmd.Flags().Lookup(name) + if flag == nil { + // If flags were not merged (e.g., helper calls outside cobra execution), try persistent flags too. + flag = cmd.PersistentFlags().Lookup(name) + } + if flag == nil { + // Skip unknown fields + continue + } + + changed, err := applyWizardValueToFlag(flag, value) + if err != nil { + return fmt.Errorf("failed to set flag %s: %w", name, err) + } + if changed { + flag.Changed = true + } + } + return nil +} + +func applyWizardValueToFlag(flag *pflag.Flag, value any) (bool, error) { + // Slice flags: always replace, never Set() (Set() may append when already-changed). + if sv, ok := flag.Value.(pflag.SliceValue); ok { + desired, err := coerceStringSlice(value) + if err != nil { + return false, err + } + if stringSliceEqual(sv.GetSlice(), desired) { + return false, nil + } + if err := sv.Replace(desired); err != nil { + return false, err + } + return true, nil + } + + desiredStr := coerceString(value) + + // Best-effort semantic equality to avoid flipping Flag.Changed when user accepted defaults. + switch flag.Value.Type() { + case "bool": + cur, err1 := strconv.ParseBool(flag.Value.String()) + des, err2 := strconv.ParseBool(desiredStr) + if err1 == nil && err2 == nil && cur == des { + return false, nil + } + case "int", "int8", "int16", "int32", "int64": + cur, err1 := strconv.ParseInt(flag.Value.String(), 10, 64) + des, err2 := strconv.ParseInt(strings.TrimSpace(desiredStr), 10, 64) + if err1 == nil && err2 == nil && cur == des { + return false, nil + } + case "uint", "uint8", "uint16", "uint32", "uint64": + cur, err1 := strconv.ParseUint(flag.Value.String(), 10, 64) + des, err2 := strconv.ParseUint(strings.TrimSpace(desiredStr), 10, 64) + if err1 == nil && err2 == nil && cur == des { + return false, nil + } + case "float32", "float64": + cur, err1 := strconv.ParseFloat(flag.Value.String(), 64) + des, err2 := strconv.ParseFloat(strings.TrimSpace(desiredStr), 64) + if err1 == nil && err2 == nil && cur == des { + return false, nil + } + default: + if flag.Value.String() == desiredStr { + return false, nil + } + } + + if err := flag.Value.Set(desiredStr); err != nil { + return false, err + } + return true, nil +} + +func coerceString(v any) string { + switch val := v.(type) { + case bool: + return strconv.FormatBool(val) + case int: + return strconv.Itoa(val) + case int64: + return strconv.FormatInt(val, 10) + case float64: + return strconv.FormatFloat(val, 'f', -1, 64) + case []string: + return strings.Join(val, ",") + case string: + return val + default: + return fmt.Sprintf("%v", v) + } +} + +func coerceStringSlice(v any) ([]string, error) { + switch val := v.(type) { + case []string: + out := make([]string, len(val)) + copy(out, val) + return out, nil + case string: + return parseCSVList(val) + default: + return parseCSVList(coerceString(v)) + } +} + +func stringSliceEqual(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/client/wizard/cobra_test.go b/client/wizard/cobra_test.go new file mode 100644 index 00000000..adb3bdd9 --- /dev/null +++ b/client/wizard/cobra_test.go @@ -0,0 +1,255 @@ +package wizard + +import ( + "testing" + + "github.com/spf13/cobra" +) + +func TestCobraToWizard(t *testing.T) { + // Create a test command with various flag types + cmd := &cobra.Command{ + Use: "test", + Short: "Test command", + Long: "This is a test command for wizard conversion", + } + + // Add flags of different types + cmd.Flags().String("name", "", "Name of the target") + cmd.Flags().Int("port", 8080, "Port number") + cmd.Flags().Bool("verbose", false, "Enable verbose output") + cmd.Flags().StringSlice("tags", nil, "Tags for the target") + cmd.Flags().Float64("timeout", 30.0, "Timeout in seconds") + + // Mark some as required + cmd.MarkFlagRequired("name") + + // Add group annotations + nameFlag := cmd.Flags().Lookup("name") + nameFlag.Annotations = map[string][]string{ + "ui:group": {"Basic"}, + "ui:order": {"1"}, + } + + portFlag := cmd.Flags().Lookup("port") + portFlag.Annotations = map[string][]string{ + "ui:group": {"Network"}, + "ui:order": {"2"}, + } + + // Convert to wizard + wiz := CobraToWizard(cmd) + + // Verify wizard was created + if wiz == nil { + t.Fatal("CobraToWizard returned nil") + } + + // Verify title + if wiz.Title != "Test command" { + t.Errorf("Expected title 'Test command', got '%s'", wiz.Title) + } + + // Verify description + if wiz.Description != "This is a test command for wizard conversion" { + t.Errorf("Expected description to be set, got '%s'", wiz.Description) + } + + // Verify groups were created (Basic, Network, and "General" for ungrouped) + if len(wiz.Groups) == 0 { + t.Error("Expected groups to be created") + } + + // Verify fields were converted + if len(wiz.Fields) != 5 { // name, port, verbose, tags, timeout (help is skipped) + t.Errorf("Expected 5 fields, got %d", len(wiz.Fields)) + } + + // Verify field types + nameField := wiz.GetField("name") + if nameField == nil { + t.Fatal("Expected 'name' field to exist") + } + if nameField.Type != FieldInput { + t.Errorf("Expected name field to be FieldInput, got %d", nameField.Type) + } + + portField := wiz.GetField("port") + if portField == nil { + t.Fatal("Expected 'port' field to exist") + } + if portField.Type != FieldNumber { + t.Errorf("Expected port field to be FieldNumber, got %d", portField.Type) + } + if portField.Default != 8080 { + t.Errorf("Expected port default to be 8080, got %v", portField.Default) + } + + verboseField := wiz.GetField("verbose") + if verboseField == nil { + t.Fatal("Expected 'verbose' field to exist") + } + if verboseField.Type != FieldConfirm { + t.Errorf("Expected verbose field to be FieldConfirm, got %d", verboseField.Type) + } + + timeoutField := wiz.GetField("timeout") + if timeoutField == nil { + t.Fatal("Expected 'timeout' field to exist") + } + if timeoutField.Type != FieldInput { + t.Errorf("Expected timeout field to be FieldInput (float), got %d", timeoutField.Type) + } + if def, ok := timeoutField.Default.(string); !ok || def == "" { + t.Errorf("Expected timeout default to be a non-empty string, got %T/%v", timeoutField.Default, timeoutField.Default) + } + + tagsField := wiz.GetField("tags") + if tagsField == nil { + t.Fatal("Expected 'tags' field to exist") + } + if tagsField.Type != FieldInput { + t.Errorf("Expected tags field to be FieldInput (slice), got %d", tagsField.Type) + } + if def, ok := tagsField.Default.(string); !ok || def != "" { + t.Errorf("Expected tags default to be empty string, got %T/%v", tagsField.Default, tagsField.Default) + } +} + +func TestCobraToWizardWithOptions(t *testing.T) { + cmd := &cobra.Command{ + Use: "select-test", + Short: "Test select options", + } + + cmd.Flags().String("format", "json", "Output format") + + formatFlag := cmd.Flags().Lookup("format") + formatFlag.Annotations = map[string][]string{ + "ui:options": {"json", "yaml", "xml"}, + } + + wiz := CobraToWizard(cmd) + if wiz == nil { + t.Fatal("CobraToWizard returned nil") + } + + formatField := wiz.GetField("format") + if formatField == nil { + t.Fatal("Expected 'format' field to exist") + } + + if formatField.Type != FieldSelect { + t.Errorf("Expected format field to be FieldSelect, got %d", formatField.Type) + } + + if len(formatField.Options) != 3 { + t.Errorf("Expected 3 options, got %d", len(formatField.Options)) + } +} + +func TestApplyWizardResultToFlags(t *testing.T) { + cmd := &cobra.Command{ + Use: "apply-test", + } + + cmd.Flags().String("name", "", "Name") + cmd.Flags().Int("port", 0, "Port") + cmd.Flags().Bool("verbose", false, "Verbose") + + result := NewWizardResult("test") + result.Set("name", "test-value") + result.Set("port", 9090) + result.Set("verbose", true) + + err := ApplyWizardResultToFlags(cmd, result) + if err != nil { + t.Fatalf("ApplyWizardResultToFlags failed: %v", err) + } + + // Verify values were applied + nameVal, _ := cmd.Flags().GetString("name") + if nameVal != "test-value" { + t.Errorf("Expected name to be 'test-value', got '%s'", nameVal) + } + + portVal, _ := cmd.Flags().GetInt("port") + if portVal != 9090 { + t.Errorf("Expected port to be 9090, got %d", portVal) + } + + verboseVal, _ := cmd.Flags().GetBool("verbose") + if !verboseVal { + t.Error("Expected verbose to be true") + } +} + +func TestApplyWizardResultToFlags_UnchangedDoesNotFlipChanged(t *testing.T) { + cmd := &cobra.Command{ + Use: "apply-unchanged-test", + } + + cmd.Flags().Int("port", 8080, "Port") + + result := NewWizardResult("test") + result.Set("port", 8080) + + err := ApplyWizardResultToFlags(cmd, result) + if err != nil { + t.Fatalf("ApplyWizardResultToFlags failed: %v", err) + } + + if cmd.Flags().Changed("port") { + t.Fatalf("expected port flag to remain unchanged when wizard value equals current") + } +} + +func TestApplyWizardResultToFlags_SliceReplaceNotAppend(t *testing.T) { + cmd := &cobra.Command{ + Use: "apply-slice-test", + } + + cmd.Flags().StringSlice("tags", []string{}, "Tags") + _ = cmd.Flags().Set("tags", "a") // simulate CLI-provided value (flag becomes changed) + + result := NewWizardResult("test") + result.Set("tags", "b,c") + + err := ApplyWizardResultToFlags(cmd, result) + if err != nil { + t.Fatalf("ApplyWizardResultToFlags failed: %v", err) + } + + tags, _ := cmd.Flags().GetStringSlice("tags") + if len(tags) != 2 || tags[0] != "b" || tags[1] != "c" { + t.Fatalf("expected tags to be replaced with [b c], got %#v", tags) + } +} + +func TestSkipFlags(t *testing.T) { + cmd := &cobra.Command{ + Use: "skip-test", + } + + cmd.Flags().String("name", "", "Name") + cmd.Flags().Bool("help", false, "Help") // Should be skipped + cmd.Flags().Bool("wizard", false, "Wizard") // Should be skipped + + wiz := CobraToWizard(cmd) + if wiz == nil { + t.Fatal("CobraToWizard returned nil") + } + + // Should only have 'name' field + if len(wiz.Fields) != 1 { + t.Errorf("Expected 1 field, got %d", len(wiz.Fields)) + } + + if wiz.GetField("help") != nil { + t.Error("Expected 'help' field to be skipped") + } + + if wiz.GetField("wizard") != nil { + t.Error("Expected 'wizard' field to be skipped") + } +} diff --git a/client/wizard/executor.go b/client/wizard/executor.go new file mode 100644 index 00000000..1d7681d9 --- /dev/null +++ b/client/wizard/executor.go @@ -0,0 +1,134 @@ +package wizard + +import ( + "fmt" + "net" + "strconv" + "strings" +) + +// Common validators for wizard fields + +// ValidateRequired returns a validator that checks if value is non-empty +func ValidateRequired(fieldName string) func(string) error { + return func(val string) error { + if strings.TrimSpace(val) == "" { + return fmt.Errorf("%s is required", fieldName) + } + return nil + } +} + +// ValidatePort returns a validator that checks if value is a valid port number +func ValidatePort() func(string) error { + return func(val string) error { + if val == "" { + return nil // Allow empty for optional fields + } + port, err := strconv.Atoi(val) + if err != nil { + return fmt.Errorf("invalid port number: %s", val) + } + if port < 1 || port > 65535 { + return fmt.Errorf("port must be between 1 and 65535, got %d", port) + } + return nil + } +} + +// ValidateHost returns a validator that checks if value is a valid host (IP or hostname) +// This is more permissive than ValidateIP and allows: +// - IPv4 addresses (e.g., "192.168.1.1", "0.0.0.0") +// - IPv6 addresses (e.g., "::1", "fe80::1") +// - Hostnames (e.g., "localhost", "example.com") +func ValidateHost() func(string) error { + return func(val string) error { + if val == "" { + return nil // Allow empty for optional fields + } + // Try parsing as IP first + if ip := net.ParseIP(val); ip != nil { + return nil + } + // Otherwise accept any non-empty string as hostname + // (actual DNS resolution happens at connection time) + return nil + } +} + +// ValidateIP returns a validator that checks if value is a valid IP address (IPv4 or IPv6) +func ValidateIP() func(string) error { + return func(val string) error { + if val == "" { + return nil // Allow empty for optional fields + } + if ip := net.ParseIP(val); ip == nil { + return fmt.Errorf("invalid IP address: %s", val) + } + return nil + } +} + +// ValidateRange returns a validator that checks if numeric value is in range +func ValidateRange(min, max int) func(string) error { + return func(val string) error { + if val == "" { + return nil // Allow empty for optional fields + } + num, err := strconv.Atoi(val) + if err != nil { + return fmt.Errorf("invalid number: %s", val) + } + if num < min || num > max { + return fmt.Errorf("value must be between %d and %d, got %d", min, max, num) + } + return nil + } +} + +// ValidateFloat returns a validator that checks if value is a valid float in range +func ValidateFloat(min, max float64) func(string) error { + return func(val string) error { + if val == "" { + return nil + } + num, err := strconv.ParseFloat(val, 64) + if err != nil { + return fmt.Errorf("invalid number: %s", val) + } + if num < min || num > max { + return fmt.Errorf("value must be between %.2f and %.2f, got %.2f", min, max, num) + } + return nil + } +} + +// ValidateOneOf returns a validator that checks if value is one of allowed values +func ValidateOneOf(allowed []string) func(string) error { + return func(val string) error { + if val == "" { + return nil + } + for _, a := range allowed { + if val == a { + return nil + } + } + return fmt.Errorf("value must be one of: %s", strings.Join(allowed, ", ")) + } +} + +// CombineValidators combines multiple validators into one +func CombineValidators(validators ...func(string) error) func(string) error { + return func(val string) error { + for _, v := range validators { + if v == nil { + continue + } + if err := v(val); err != nil { + return err + } + } + return nil + } +} diff --git a/client/wizard/grouped_form.go b/client/wizard/grouped_form.go new file mode 100644 index 00000000..e61b6c90 --- /dev/null +++ b/client/wizard/grouped_form.go @@ -0,0 +1,1000 @@ +package wizard + +import ( + "fmt" + "strconv" + "strings" + "sync" + + tea "github.com/charmbracelet/bubbletea" + "github.com/charmbracelet/huh" + "github.com/charmbracelet/lipgloss" +) + +var ( + // lipglossInitOnce ensures we only initialize lipgloss background detection once + // to avoid OSC terminal queries that can conflict with readline input handling. + lipglossInitOnce sync.Once +) + +// FieldKind represents the type of field in the form +type FieldKind int + +const ( + KindSelect FieldKind = iota + KindMultiSelect + KindInput + KindConfirm + KindNumber +) + +// Styles - package-level style definitions to avoid recreation +var ( + styleTabActive = lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color("0")). + Background(lipgloss.Color("212")). + Padding(0, 1) + styleTabInactive = lipgloss.NewStyle(). + Foreground(lipgloss.Color("250")). + Padding(0, 1) + styleTabCompleted = lipgloss.NewStyle(). + Foreground(lipgloss.Color("42")). + Padding(0, 1) + styleSeparator = lipgloss.NewStyle().Foreground(lipgloss.Color("240")) + styleError = lipgloss.NewStyle().Foreground(lipgloss.Color("9")).Bold(true) + styleHelp = lipgloss.NewStyle().Foreground(lipgloss.Color("240")) + + styleFocusedTitle = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("212")) + styleNormalTitle = lipgloss.NewStyle().Foreground(lipgloss.Color("250")) + styleDescription = lipgloss.NewStyle().Foreground(lipgloss.Color("240")).Italic(true) + styleSelectedOption = lipgloss.NewStyle(). + Bold(true). + Foreground(lipgloss.Color("0")). + Background(lipgloss.Color("212")). + Padding(0, 1) + styleUnselectedOption = lipgloss.NewStyle().Foreground(lipgloss.Color("250")).Padding(0, 1) + styleFocusedUnselected = lipgloss.NewStyle().Foreground(lipgloss.Color("255")).Padding(0, 1) + styleMultiSelectChecked = lipgloss.NewStyle().Foreground(lipgloss.Color("42")).Padding(0, 1) + styleInputFocused = lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("212")).Padding(0, 1) + styleInputBlurred = lipgloss.NewStyle().Foreground(lipgloss.Color("250")).Padding(0, 1) +) + +// FormField represents a field that can be displayed in the form +type FormField struct { + Name string + Title string + Description string + Kind FieldKind + Options []string // For Select/MultiSelect + Selected int // For Select: current selection index + MultiSelect map[int]bool // For MultiSelect: selected indices + InputValue string // For Input/Number + ConfirmVal bool // For Confirm + Required bool + Validate func(string) error + Value interface{} // Pointer to store result +} + +// GroupedWizardForm is a wizard form with Tab navigation for groups +type GroupedWizardForm struct { + groups []*FormGroup + groupIndex int // Current group being edited + + // Current field within group + fieldIndex int + cursor int // Cursor within field options + + inputMode bool + inputBuf string + inputCurPos int + + width int + height int + theme *huh.Theme + quitting bool + aborted bool + + errMsg string +} + +// FormGroup represents a group of fields +type FormGroup struct { + Name string + Title string + Description string + Fields []*FormField + Optional bool // If true, this group can be collapsed + Expanded bool // If true and Optional, show fields; otherwise collapsed +} + +// NewGroupedWizardForm creates a new grouped wizard form +func NewGroupedWizardForm(groups []*FormGroup) *GroupedWizardForm { + return &GroupedWizardForm{ + groups: groups, + groupIndex: 0, + fieldIndex: 0, + cursor: 0, + width: 80, + theme: huh.ThemeCharm(), + } +} + +// WithTheme sets the theme +func (f *GroupedWizardForm) WithTheme(theme *huh.Theme) *GroupedWizardForm { + f.theme = theme + return f +} + +// Init implements tea.Model +func (f *GroupedWizardForm) Init() tea.Cmd { + f.initCursorForField() + return nil +} + +// Update implements tea.Model +func (f *GroupedWizardForm) Update(msg tea.Msg) (tea.Model, tea.Cmd) { + switch msg := msg.(type) { + case tea.KeyMsg: + // Handle input mode separately + if f.inputMode { + return f.handleInputMode(msg) + } + + key := msg.String() + + // Check if current group is a collapsed optional group + group := f.currentGroup() + isCollapsedOptional := group != nil && group.Optional && !group.Expanded + + // Number keys 1-9 for group navigation + if len(key) == 1 && key[0] >= '1' && key[0] <= '9' { + groupNum := int(key[0] - '1') + if groupNum < len(f.groups) { + f.errMsg = "" + f.saveCurrentField() + f.groupIndex = groupNum + f.fieldIndex = 0 + f.initCursorForField() + return f, nil + } + } + + switch key { + case "ctrl+c", "esc": + f.aborted = true + f.quitting = true + return f, tea.Quit + + case "tab": + // Next group + f.errMsg = "" + f.saveCurrentField() + f.nextGroup() + + case "shift+tab": + // Previous group + f.errMsg = "" + f.saveCurrentField() + f.prevGroup() + + case "up", "k": + if isCollapsedOptional { + break // No field navigation in collapsed group + } + f.errMsg = "" + f.saveCurrentField() + f.prevField() + + case "down", "j": + if isCollapsedOptional { + break // No field navigation in collapsed group + } + f.errMsg = "" + f.saveCurrentField() + f.nextField() + + case "left", "h": + if isCollapsedOptional { + break + } + f.errMsg = "" + f.prevOption() + + case "right", "l": + if isCollapsedOptional { + break + } + f.errMsg = "" + f.nextOption() + + case " ": + f.errMsg = "" + // Handle collapsed optional group - expand it + if isCollapsedOptional { + group.Expanded = true + f.fieldIndex = 0 + f.initCursorForField() + break + } + field := f.currentField() + if field == nil { + break + } + if field.Kind == KindMultiSelect { + f.toggleSelection() + } else if field.Kind == KindConfirm { + f.cursor = 1 - f.cursor + f.saveCurrentField() + } + + case "ctrl+d": + return f.trySubmit() + + case "enter": + // Handle collapsed optional group - expand it + if isCollapsedOptional { + f.errMsg = "" + group.Expanded = true + f.fieldIndex = 0 + f.initCursorForField() + break + } + field := f.currentField() + if field == nil { + return f.trySubmit() + } + if field.Kind == KindInput || field.Kind == KindNumber { + f.errMsg = "" + f.inputMode = true + f.inputBuf = field.InputValue + f.inputCurPos = len(f.inputBuf) + } else { + return f.trySubmit() + } + + case "c": + // Collapse current optional group if expanded + if group != nil && group.Optional && group.Expanded { + f.errMsg = "" + group.Expanded = false + f.fieldIndex = 0 + } + + case "a": + if isCollapsedOptional { + break + } + if f.currentField() != nil && f.currentField().Kind == KindMultiSelect { + f.errMsg = "" + f.selectAll() + } + + case "n": + if isCollapsedOptional { + break + } + field := f.currentField() + if field != nil { + if field.Kind == KindMultiSelect { + f.errMsg = "" + f.deselectAll() + } else if field.Kind == KindConfirm { + f.errMsg = "" + f.cursor = 1 + f.saveCurrentField() + } + } + + case "y": + if isCollapsedOptional { + break + } + if f.currentField() != nil && f.currentField().Kind == KindConfirm { + f.errMsg = "" + f.cursor = 0 + f.saveCurrentField() + } + } + + case tea.WindowSizeMsg: + f.width = msg.Width + f.height = msg.Height + } + + return f, nil +} + +// handleInputMode handles key events when in text input mode +func (f *GroupedWizardForm) handleInputMode(msg tea.KeyMsg) (tea.Model, tea.Cmd) { + switch msg.String() { + case "ctrl+c", "esc": + f.inputMode = false + f.inputBuf = "" + f.errMsg = "" + + case "enter": + field := f.currentField() + if field == nil { + f.inputMode = false + return f, nil + } + candidate := f.inputBuf + old := field.InputValue + field.InputValue = candidate + if err := f.validateField(field); err != nil { + field.InputValue = old + f.errMsg = err.Error() + return f, nil + } + f.saveCurrentField() + f.inputMode = false + f.inputBuf = "" + f.errMsg = "" + // Move to next field + if f.fieldIndex < len(f.currentGroup().Fields)-1 { + f.nextField() + } + + case "ctrl+d": + field := f.currentField() + if field == nil { + f.inputMode = false + return f.trySubmit() + } + candidate := f.inputBuf + old := field.InputValue + field.InputValue = candidate + if err := f.validateField(field); err != nil { + field.InputValue = old + f.errMsg = err.Error() + return f, nil + } + f.saveCurrentField() + f.inputMode = false + f.inputBuf = "" + f.errMsg = "" + return f.trySubmit() + + case "backspace": + f.errMsg = "" + if len(f.inputBuf) > 0 { + f.inputBuf = f.inputBuf[:len(f.inputBuf)-1] + } + + default: + f.errMsg = "" + if len(msg.String()) == 1 { + f.inputBuf += msg.String() + } else if msg.Type == tea.KeySpace { + f.inputBuf += " " + } + } + + return f, nil +} + +// View implements tea.Model +func (f *GroupedWizardForm) View() string { + var sb strings.Builder + + // Tab bar - show required groups first, then optional groups + var tabs []string + for i, group := range f.groups { + label := fmt.Sprintf("%d.%s", i+1, group.Title) + + // Add indicator for optional groups + if group.Optional { + if group.Expanded { + label = fmt.Sprintf("%d.▼ %s", i+1, group.Title) + } else { + label = fmt.Sprintf("%d.▶ %s", i+1, group.Title) + } + } + + switch { + case i == f.groupIndex: + tabs = append(tabs, styleTabActive.Render(label)) + case group.Optional && !group.Expanded: + // Collapsed optional groups are "skipped": show them dimmed instead of as completed. + tabs = append(tabs, styleHelp.Render(label)) + case f.isGroupComplete(i): + tabs = append(tabs, styleTabCompleted.Render("✓ "+label)) + default: + tabs = append(tabs, styleTabInactive.Render(label)) + } + } + sb.WriteString(strings.Join(tabs, " ")) + sb.WriteString("\n") + + // Separator + sb.WriteString(styleSeparator.Render(strings.Repeat("─", minInt(f.width, 70)))) + sb.WriteString("\n\n") + + // Render current group + group := f.currentGroup() + if group == nil || len(group.Fields) == 0 { + sb.WriteString("(No fields in this group)\n") + } else if group.Optional && !group.Expanded { + // Collapsed optional group - show toggle prompt + sb.WriteString(styleDescription.Render(fmt.Sprintf(" %s (Optional)", group.Title))) + sb.WriteString("\n\n") + sb.WriteString(styleHelp.Render(" Press Enter or Space to expand, or Tab to skip")) + sb.WriteString("\n") + } else { + // Show all fields in current group + for i, field := range group.Fields { + sb.WriteString(f.renderField(field, i == f.fieldIndex)) + sb.WriteString("\n") + } + } + + // Error message + if strings.TrimSpace(f.errMsg) != "" { + sb.WriteString("\n") + sb.WriteString(styleError.Render("Error: " + f.errMsg)) + } + + // Help text + sb.WriteString("\n") + sb.WriteString(f.renderHelp()) + + return sb.String() +} + +// renderField renders a single field with all its options visible +func (f *GroupedWizardForm) renderField(field *FormField, isFocused bool) string { + var sb strings.Builder + + // Title with focus indicator + if isFocused { + sb.WriteString(styleFocusedTitle.Render("> " + field.Title)) + } else { + sb.WriteString(styleNormalTitle.Render(" " + field.Title)) + } + + // Description on same line if short + if field.Description != "" && len(field.Description) < 40 { + sb.WriteString(styleDescription.Render(" " + field.Description)) + } + sb.WriteString("\n") + + // Render options based on field kind + sb.WriteString(" ") + switch field.Kind { + case KindSelect: + sb.WriteString(f.renderSelectOptions(field, isFocused)) + case KindMultiSelect: + sb.WriteString(f.renderMultiSelectOptions(field, isFocused)) + case KindConfirm: + sb.WriteString(f.renderConfirmOptions(field, isFocused)) + case KindInput, KindNumber: + sb.WriteString(f.renderInputField(field, isFocused)) + } + + return sb.String() +} + +// selectOptionStyle returns the appropriate style based on focus and selection state +func selectOptionStyle(isFocused, isSelected bool) lipgloss.Style { + if isSelected { + return styleSelectedOption + } + if isFocused { + return styleFocusedUnselected + } + return styleUnselectedOption +} + +func (f *GroupedWizardForm) renderSelectOptions(field *FormField, isFocused bool) string { + parts := make([]string, 0, len(field.Options)) + for i, opt := range field.Options { + display := opt + if display == "" { + display = "(empty)" + } + style := selectOptionStyle(isFocused, i == field.Selected) + parts = append(parts, style.Render(display)) + } + return strings.Join(parts, " ") +} + +func (f *GroupedWizardForm) renderMultiSelectOptions(field *FormField, isFocused bool) string { + parts := make([]string, 0, len(field.Options)) + for i, opt := range field.Options { + marker := "○" + if field.MultiSelect[i] { + marker = "●" + } + display := fmt.Sprintf("%s %s", marker, opt) + isCursor := isFocused && i == f.cursor + + var style lipgloss.Style + switch { + case isCursor: + style = styleSelectedOption + case field.MultiSelect[i]: + style = styleMultiSelectChecked + case isFocused: + style = styleFocusedUnselected + default: + style = styleUnselectedOption + } + parts = append(parts, style.Render(display)) + } + return strings.Join(parts, " ") +} + +func (f *GroupedWizardForm) renderConfirmOptions(field *FormField, isFocused bool) string { + yesStyle := selectOptionStyle(isFocused, field.ConfirmVal) + noStyle := selectOptionStyle(isFocused, !field.ConfirmVal) + return yesStyle.Render("Yes") + " " + noStyle.Render("No") +} + +func (f *GroupedWizardForm) renderInputField(field *FormField, isFocused bool) string { + if isFocused && f.inputMode { + return styleInputFocused.Render("[" + f.inputBuf + "█]") + } + display := field.InputValue + if display == "" { + display = "(empty)" + } + if isFocused { + return styleInputFocused.Render("["+display+"]") + styleDescription.Render(" Enter to edit") + } + return styleInputBlurred.Render("[" + display + "]") +} + +func (f *GroupedWizardForm) renderHelp() string { + group := f.currentGroup() + + // Check if current group is a collapsed optional group + if group != nil && group.Optional && !group.Expanded { + return styleHelp.Render("Enter/Space: expand Tab: skip group 1-9: jump Ctrl+D: submit") + } + + // Check if current group is an expanded optional group + if group != nil && group.Optional && group.Expanded { + field := f.currentField() + baseHelp := "↑/↓: field c: collapse Tab: group " + if field == nil { + return styleHelp.Render(baseHelp + "Ctrl+D: submit") + } + switch field.Kind { + case KindMultiSelect: + return styleHelp.Render(baseHelp + "Space: toggle a: all Ctrl+D: submit") + case KindConfirm: + return styleHelp.Render(baseHelp + "←/→: toggle Ctrl+D: submit") + case KindInput, KindNumber: + if f.inputMode { + return styleHelp.Render("Enter: save Esc: cancel Ctrl+D: save & submit") + } + return styleHelp.Render(baseHelp + "Enter: edit Ctrl+D: submit") + default: + return styleHelp.Render(baseHelp + "←/→: select Ctrl+D: submit") + } + } + + field := f.currentField() + if field == nil { + return styleHelp.Render("Tab: next group Shift+Tab: prev group 1-9: jump to group Ctrl+D: submit") + } + + baseHelp := "↑/↓: field Tab/Shift+Tab: group 1-9: jump " + + switch field.Kind { + case KindMultiSelect: + return styleHelp.Render(baseHelp + "←/→: move Space: toggle a: all n: none Ctrl+D: submit") + case KindConfirm: + return styleHelp.Render(baseHelp + "←/→: toggle y: Yes n: No Ctrl+D: submit") + case KindInput, KindNumber: + if f.inputMode { + return styleHelp.Render("Enter: save Esc: cancel Ctrl+D: save & submit") + } + return styleHelp.Render(baseHelp + "Enter: edit Ctrl+D: submit") + default: + return styleHelp.Render(baseHelp + "←/→: select Ctrl+D: submit") + } +} + +// Helper methods + +func (f *GroupedWizardForm) currentGroup() *FormGroup { + if f.groupIndex >= 0 && f.groupIndex < len(f.groups) { + return f.groups[f.groupIndex] + } + return nil +} + +func (f *GroupedWizardForm) currentField() *FormField { + group := f.currentGroup() + if group == nil { + return nil + } + if f.fieldIndex >= 0 && f.fieldIndex < len(group.Fields) { + return group.Fields[f.fieldIndex] + } + return nil +} + +func (f *GroupedWizardForm) nextGroup() { + f.groupIndex++ + if f.groupIndex >= len(f.groups) { + f.groupIndex = 0 + } + f.fieldIndex = 0 + f.initCursorForField() +} + +func (f *GroupedWizardForm) prevGroup() { + f.groupIndex-- + if f.groupIndex < 0 { + f.groupIndex = len(f.groups) - 1 + } + f.fieldIndex = 0 + f.initCursorForField() +} + +func (f *GroupedWizardForm) nextField() { + group := f.currentGroup() + if group == nil { + return + } + f.fieldIndex++ + if f.fieldIndex >= len(group.Fields) { + f.fieldIndex = 0 + } + f.initCursorForField() +} + +func (f *GroupedWizardForm) prevField() { + group := f.currentGroup() + if group == nil { + return + } + f.fieldIndex-- + if f.fieldIndex < 0 { + f.fieldIndex = len(group.Fields) - 1 + } + f.initCursorForField() +} + +func (f *GroupedWizardForm) initCursorForField() { + field := f.currentField() + if field == nil { + f.cursor = 0 + return + } + switch field.Kind { + case KindSelect: + f.cursor = field.Selected + case KindConfirm: + if field.ConfirmVal { + f.cursor = 0 + } else { + f.cursor = 1 + } + default: + f.cursor = 0 + } +} + +// wrapIndex wraps index in range [0, max) with cycling +func wrapIndex(index, delta, max int) int { + if max <= 0 { + return 0 + } + return (index + delta + max) % max +} + +func (f *GroupedWizardForm) nextOption() { + field := f.currentField() + if field == nil { + return + } + switch field.Kind { + case KindSelect: + f.cursor = wrapIndex(f.cursor, 1, len(field.Options)) + field.Selected = f.cursor + f.saveCurrentField() + case KindMultiSelect: + f.cursor = wrapIndex(f.cursor, 1, len(field.Options)) + case KindConfirm: + f.cursor = 1 - f.cursor + f.saveCurrentField() + case KindInput, KindNumber: + if !f.inputMode { + f.saveCurrentField() + f.nextField() + } + } +} + +func (f *GroupedWizardForm) prevOption() { + field := f.currentField() + if field == nil { + return + } + switch field.Kind { + case KindSelect: + f.cursor = wrapIndex(f.cursor, -1, len(field.Options)) + field.Selected = f.cursor + f.saveCurrentField() + case KindMultiSelect: + f.cursor = wrapIndex(f.cursor, -1, len(field.Options)) + case KindConfirm: + f.cursor = 1 - f.cursor + f.saveCurrentField() + case KindInput, KindNumber: + if !f.inputMode { + f.saveCurrentField() + f.prevField() + } + } +} + +func (f *GroupedWizardForm) ensureMultiSelect(field *FormField) { + if field.MultiSelect == nil { + field.MultiSelect = make(map[int]bool) + } +} + +func (f *GroupedWizardForm) toggleSelection() { + field := f.currentField() + if field == nil { + return + } + f.ensureMultiSelect(field) + field.MultiSelect[f.cursor] = !field.MultiSelect[f.cursor] + f.saveCurrentField() +} + +func (f *GroupedWizardForm) selectAll() { + field := f.currentField() + if field == nil { + return + } + f.ensureMultiSelect(field) + for i := range field.Options { + field.MultiSelect[i] = true + } + f.saveCurrentField() +} + +func (f *GroupedWizardForm) deselectAll() { + field := f.currentField() + if field == nil { + return + } + field.MultiSelect = make(map[int]bool) + f.saveCurrentField() +} + +func (f *GroupedWizardForm) saveCurrentField() { + field := f.currentField() + if field == nil { + return + } + + switch field.Kind { + case KindSelect: + if ptr, ok := field.Value.(*string); ok && ptr != nil { + if field.Selected >= 0 && field.Selected < len(field.Options) { + *ptr = field.Options[field.Selected] + } + } + case KindMultiSelect: + if ptr, ok := field.Value.(*[]string); ok && ptr != nil { + var selected []string + for i, opt := range field.Options { + if field.MultiSelect[i] { + selected = append(selected, opt) + } + } + *ptr = selected + } + case KindConfirm: + field.ConfirmVal = (f.cursor == 0) + if ptr, ok := field.Value.(*bool); ok && ptr != nil { + *ptr = field.ConfirmVal + } + case KindInput, KindNumber: + if ptr, ok := field.Value.(*string); ok && ptr != nil { + *ptr = field.InputValue + } + } +} + +func (f *GroupedWizardForm) isGroupComplete(groupIdx int) bool { + if groupIdx < 0 || groupIdx >= len(f.groups) { + return false + } + group := f.groups[groupIdx] + + // Collapsed optional groups are considered "complete" (skipped) + if group.Optional && !group.Expanded { + return true + } + + for _, field := range group.Fields { + if err := f.validateField(field); err != nil { + return false + } + } + return true +} + +func (f *GroupedWizardForm) trySubmit() (tea.Model, tea.Cmd) { + f.saveCurrentField() + if err := f.validateAllFields(); err != nil { + return f, nil + } + f.quitting = true + return f, tea.Quit +} + +func (f *GroupedWizardForm) validateAllFields() error { + for gi, group := range f.groups { + // Skip collapsed optional groups (user chose to skip) + if group.Optional && !group.Expanded { + continue + } + + for fi, field := range group.Fields { + if err := f.validateField(field); err != nil { + f.errMsg = err.Error() + f.inputMode = false + f.inputBuf = "" + f.groupIndex = gi + f.fieldIndex = fi + f.initCursorForField() + return err + } + } + } + f.errMsg = "" + return nil +} + +// validateStringField validates string-like fields (Select, Input) +func (f *GroupedWizardForm) validateStringField(value string, field *FormField, label string) error { + if !field.Required && field.Validate == nil { + return nil + } + var required func(string) error + if field.Required { + required = requiredStringValidator(label) + } + return chainStringValidators(required, field.Validate)(value) +} + +// requiredError returns a formatted required error message +func requiredError(label string) error { + if label != "" { + return fmt.Errorf("%s is required", label) + } + return fmt.Errorf("value is required") +} + +func (f *GroupedWizardForm) validateField(field *FormField) error { + if field == nil { + return nil + } + + label := field.Title + if strings.TrimSpace(label) == "" { + label = field.Name + } + + switch field.Kind { + case KindSelect: + val := "" + if field.Selected >= 0 && field.Selected < len(field.Options) { + val = field.Options[field.Selected] + } + return f.validateStringField(val, field, label) + + case KindMultiSelect: + if !field.Required { + return nil + } + for _, selected := range field.MultiSelect { + if selected { + return nil + } + } + return requiredError(label) + + case KindInput: + return f.validateStringField(field.InputValue, field, label) + + case KindNumber: + s := strings.TrimSpace(field.InputValue) + if s == "" { + if field.Required { + return requiredError(label) + } + return nil + } + if _, err := strconv.Atoi(s); err != nil { + return fmt.Errorf("please enter a valid number") + } + if field.Validate != nil { + return field.Validate(s) + } + return nil + + case KindConfirm: + return nil + default: + return nil + } +} + +// Run executes the grouped form +func (f *GroupedWizardForm) Run() error { + // Prevent lipgloss from sending OSC terminal queries (like \x1b]11;?) + // which can conflict with readline's input handling and cause garbled output. + // We set HasDarkBackground once at startup to avoid runtime OSC queries. + lipglossInitOnce.Do(func() { + lipgloss.SetHasDarkBackground(true) + }) + + p := tea.NewProgram(f) + _, err := p.Run() + if err != nil { + return err + } + if f.aborted { + return fmt.Errorf("wizard aborted") + } + // Final save of all fields + for gi := range f.groups { + for fi := range f.groups[gi].Fields { + f.groupIndex = gi + f.fieldIndex = fi + f.saveCurrentField() + } + } + return nil +} + +// Aborted returns true if the user cancelled +func (f *GroupedWizardForm) Aborted() bool { + return f.aborted +} + +// requiredStringValidator creates a validator that checks for non-empty strings +func requiredStringValidator(label string) func(string) error { + return func(s string) error { + if strings.TrimSpace(s) == "" { + if label != "" { + return fmt.Errorf("%s is required", label) + } + return fmt.Errorf("value is required") + } + return nil + } +} + +// chainStringValidators chains multiple string validators together +func chainStringValidators(validators ...func(string) error) func(string) error { + return func(s string) error { + for _, v := range validators { + if v == nil { + continue + } + if err := v(s); err != nil { + return err + } + } + return nil + } +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/client/wizard/grouped_form_test.go b/client/wizard/grouped_form_test.go new file mode 100644 index 00000000..9f19bb74 --- /dev/null +++ b/client/wizard/grouped_form_test.go @@ -0,0 +1,239 @@ +package wizard + +import ( + "fmt" + "strings" + "testing" +) + +func TestGroupedWizardFormStructure(t *testing.T) { + // Get build_beacon wizard + w, ok := GetTemplate("build_beacon") + if !ok { + t.Fatal("build_beacon template not found") + } + + if !w.IsGrouped() { + t.Fatal("build_beacon should be grouped") + } + + // Print group structure + fmt.Println("=== Build Beacon Wizard Groups ===") + for i, g := range w.Groups { + fmt.Printf("\n[%d] %s - %s\n", i+1, g.Title, g.Description) + for j, f := range g.Fields { + fmt.Printf(" %d.%d %s (%v)\n", i+1, j+1, f.Title, f.Type) + } + } + fmt.Printf("\nTotal: %d groups, %d fields\n", len(w.Groups), len(w.Fields)) +} + +func TestFormGroupConversion(t *testing.T) { + // Create a simple grouped wizard + w := NewWizard("test", "Test Wizard"). + NewGroup("g1", "Group 1"). + WithDescription("First group"). + Select("field1", "Field 1", []string{"a", "b", "c"}).Field().EndGroup(). + Input("field2", "Field 2", "default").Field().EndGroup(). + End(). + NewGroup("g2", "Group 2"). + WithDescription("Second group"). + Confirm("field3", "Field 3", true).Field().EndGroup(). + Number("field4", "Field 4", 42).Field().EndGroup(). + End() + + if len(w.Groups) != 2 { + t.Errorf("Expected 2 groups, got %d", len(w.Groups)) + } + + if len(w.Groups[0].Fields) != 2 { + t.Errorf("Expected 2 fields in group 1, got %d", len(w.Groups[0].Fields)) + } + + if len(w.Groups[1].Fields) != 2 { + t.Errorf("Expected 2 fields in group 2, got %d", len(w.Groups[1].Fields)) + } + + if len(w.Fields) != 4 { + t.Errorf("Expected 4 total fields, got %d", len(w.Fields)) + } +} + +func TestGroupedFormInit(t *testing.T) { + groups := []*FormGroup{ + { + Name: "g1", + Title: "Group 1", + Description: "First group", + Fields: []*FormField{ + { + Name: "f1", + Title: "Field 1", + Kind: KindSelect, + Options: []string{"opt1", "opt2", "opt3"}, + }, + }, + }, + { + Name: "g2", + Title: "Group 2", + Description: "Second group", + Fields: []*FormField{ + { + Name: "f2", + Title: "Field 2", + Kind: KindInput, + }, + }, + }, + } + + form := NewGroupedWizardForm(groups) + form.Init() + + if form.groupIndex != 0 { + t.Errorf("Expected groupIndex 0, got %d", form.groupIndex) + } + + if form.fieldIndex != 0 { + t.Errorf("Expected fieldIndex 0, got %d", form.fieldIndex) + } +} + +func TestIsGroupComplete_RequiredSelectEmpty(t *testing.T) { + form := NewGroupedWizardForm([]*FormGroup{{ + Name: "g1", + Title: "Group 1", + Fields: []*FormField{{ + Name: "f1", + Title: "Field 1", + Kind: KindSelect, + Options: []string{""}, + Selected: 0, + Required: true, + }}, + }}) + + if form.isGroupComplete(0) { + t.Fatal("expected group to be incomplete when required select is empty") + } +} + +func TestOptionalGroupCollapsed(t *testing.T) { + // Test that collapsed optional groups are considered complete + form := NewGroupedWizardForm([]*FormGroup{ + { + Name: "required", + Title: "Required Group", + Optional: false, + Fields: []*FormField{{ + Name: "f1", + Title: "Field 1", + Kind: KindSelect, + Options: []string{"opt1", "opt2"}, + Selected: 0, + }}, + }, + { + Name: "optional", + Title: "Optional Group", + Optional: true, + Expanded: false, // Collapsed by default + Fields: []*FormField{{ + Name: "f2", + Title: "Field 2", + Kind: KindInput, + Required: true, // This is required but should be skipped when collapsed + }}, + }, + }) + + // Required group should be complete (select has default) + if !form.isGroupComplete(0) { + t.Fatal("expected required group to be complete") + } + + // Collapsed optional group should be considered complete (skipped) + if !form.isGroupComplete(1) { + t.Fatal("expected collapsed optional group to be complete") + } + + // Expand the optional group + form.groups[1].Expanded = true + + // Now optional group should not be complete (required field is empty) + if form.isGroupComplete(1) { + t.Fatal("expected expanded optional group with empty required field to be incomplete") + } +} + +func TestOptionalGroupInWizard(t *testing.T) { + // Test creating a wizard with optional groups + w := NewWizard("test", "Test Wizard"). + NewGroup("basic", "Basic Settings"). + WithDescription("Required configuration"). + Select("target", "Target OS", []string{"windows", "linux"}).Field().EndGroup(). + End(). + NewGroup("advanced", "Advanced Settings"). + WithDescription("Optional configuration"). + AsOptional(). + Input("custom_flag", "Custom Flag", "").Field().EndGroup(). + End() + + if len(w.Groups) != 2 { + t.Fatalf("Expected 2 groups, got %d", len(w.Groups)) + } + + // First group should not be optional + if w.Groups[0].Optional { + t.Error("First group should not be optional") + } + + // Second group should be optional + if !w.Groups[1].Optional { + t.Error("Second group should be optional") + } + + // Optional group should be collapsed by default + if w.Groups[1].Expanded { + t.Error("Optional group should be collapsed by default") + } +} + +func TestOptionalCollapsedGroupNotMarkedCompletedInTabs(t *testing.T) { + form := NewGroupedWizardForm([]*FormGroup{ + { + Name: "required", + Title: "Required Group", + Fields: []*FormField{{ + Name: "f1", + Title: "Field 1", + Kind: KindInput, + Required: true, + // Leave empty so required group isn't "complete". + InputValue: "", + }}, + }, + { + Name: "optional", + Title: "Optional Group", + Optional: true, + Expanded: false, // Collapsed by default + Fields: []*FormField{{ + Name: "f2", + Title: "Field 2", + Kind: KindInput, + Required: true, + // Leave empty; should be skipped when collapsed. + InputValue: "", + }}, + }, + }) + + // Current group is index 0. The optional group is collapsed and should be dimmed, + // not shown as "completed" with a checkmark in the tab bar. + view := form.View() + if strings.Contains(view, "✓") { + t.Fatalf("expected collapsed optional group not to be marked completed in tabs; got view:\n%s", view) + } +} diff --git a/client/wizard/lua.go b/client/wizard/lua.go new file mode 100644 index 00000000..105c6612 --- /dev/null +++ b/client/wizard/lua.go @@ -0,0 +1,733 @@ +package wizard + +import ( + "errors" + "strings" + "sync" + + "github.com/chainreactors/malice-network/helper/intermediate" + "github.com/chainreactors/mals" + lua "github.com/yuin/gopher-lua" +) + +const luaWizardTypeName = "wizard" + +// RegisterLuaFunctions registers wizard functions to the VM functions map +func RegisterLuaFunctions(vmFns map[string]lua.LGFunction) { + vmFns["wizard"] = luaWizardNew + vmFns["wizard_template"] = luaWizardTemplate + vmFns["wizard_templates"] = luaWizardListTemplates + vmFns["wizard_from_spec"] = luaWizardFromSpec + vmFns["wizard_from_file"] = luaWizardFromFile + vmFns["wizard_register_template"] = luaWizardRegisterTemplate +} + +// SetupMetatable sets up the wizard metatable in a Lua VM +func SetupMetatable(L *lua.LState) { + mt := L.NewTypeMetatable(luaWizardTypeName) + L.SetField(mt, "__index", L.SetFuncs(L.NewTable(), wizardMethods)) +} + +// wizardMethods contains all methods available on wizard userdata +var wizardMethods = map[string]lua.LGFunction{ + "input": wizardInput, + "text": wizardText, + "select": wizardSelect, + "multiselect": wizardMultiSelect, + "confirm": wizardConfirm, + "number": wizardNumber, + "filepath": wizardFilePath, + "description": wizardDescription, + "clone": wizardClone, + "run": wizardRun, + "get_field": wizardGetField, + "field_count": wizardFieldCount, +} + +// luaWizardNew creates a new wizard (Lua: wizard(id, title)) +func luaWizardNew(L *lua.LState) int { + id := L.CheckString(1) + title := L.OptString(2, "") + + wiz := NewWizard(id, title) + + ud := L.NewUserData() + ud.Value = wiz + L.SetMetatable(ud, L.GetTypeMetatable(luaWizardTypeName)) + L.Push(ud) + return 1 +} + +// luaWizardTemplate loads a predefined template (Lua: wizard_template(name)) +func luaWizardTemplate(L *lua.LState) int { + name := L.CheckString(1) + + wiz, ok := GetTemplate(name) + if !ok { + L.Push(lua.LNil) + L.Push(lua.LString("template not found: " + name)) + return 2 + } + + ud := L.NewUserData() + ud.Value = wiz + L.SetMetatable(ud, L.GetTypeMetatable(luaWizardTypeName)) + L.Push(ud) + return 1 +} + +// luaWizardListTemplates returns available template names (Lua: wizard_templates()) +func luaWizardListTemplates(L *lua.LState) int { + names := ListTemplates() + tbl := L.NewTable() + for i, name := range names { + tbl.RawSetInt(i+1, lua.LString(name)) + } + L.Push(tbl) + return 1 +} + +// luaWizardFromSpec builds a wizard from a spec table (Lua: wizard_from_spec(spec)) +func luaWizardFromSpec(L *lua.LState) int { + raw := mals.ConvertLuaValueToGo(L.CheckTable(1)) + specMap, ok := raw.(map[string]interface{}) + if !ok { + L.Push(lua.LNil) + L.Push(lua.LString("spec must be a table")) + return 2 + } + spec, err := SpecFromMap(specMap) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + wiz, err := NewWizardFromSpec(spec) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + ud := L.NewUserData() + ud.Value = wiz + L.SetMetatable(ud, L.GetTypeMetatable(luaWizardTypeName)) + L.Push(ud) + return 1 +} + +// luaWizardFromFile loads a wizard spec from file (Lua: wizard_from_file(path)) +func luaWizardFromFile(L *lua.LState) int { + path := L.CheckString(1) + wiz, err := NewWizardFromFile(path) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + ud := L.NewUserData() + ud.Value = wiz + L.SetMetatable(ud, L.GetTypeMetatable(luaWizardTypeName)) + L.Push(ud) + return 1 +} + +// luaWizardRegisterTemplate registers a template from a wizard or a spec table. +// Lua: +// +// wizard_register_template(name, wiz) +// wizard_register_template(name, specTable) +func luaWizardRegisterTemplate(L *lua.LState) int { + name := L.CheckString(1) + if name == "" { + L.Push(lua.LNil) + L.Push(lua.LString("template name is required")) + return 2 + } + + switch v := L.Get(2).(type) { + case *lua.LUserData: + wiz, ok := v.Value.(*Wizard) + if !ok { + L.Push(lua.LNil) + L.Push(lua.LString("second arg must be wizard userdata or spec table")) + return 2 + } + base := wiz.Clone() + RegisterTemplate(name, func() *Wizard { return base.Clone() }) + L.Push(lua.LTrue) + return 1 + case *lua.LTable: + raw := mals.ConvertLuaValueToGo(v) + specMap, ok := raw.(map[string]interface{}) + if !ok { + L.Push(lua.LNil) + L.Push(lua.LString("spec must be a table")) + return 2 + } + spec, err := SpecFromMap(specMap) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + if spec.ID == "" { + spec.ID = name + } + if err := RegisterTemplateFromSpec(name, spec); err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + L.Push(lua.LTrue) + return 1 + default: + L.Push(lua.LNil) + L.Push(lua.LString("second arg must be wizard userdata or spec table")) + return 2 + } +} + +// checkWizard extracts wizard from userdata +func checkWizard(L *lua.LState, n int) *Wizard { + ud := L.CheckUserData(n) + if v, ok := ud.Value.(*Wizard); ok { + return v + } + L.ArgError(n, "wizard expected") + return nil +} + +type luaFieldOptions struct { + Description string + Required bool +} + +// luaFieldArgs holds parsed arguments for field creation methods +type luaFieldArgs struct { + Name string + Title string + Default lua.LValue + Options luaFieldOptions + OptsIdx int +} + +// parseLuaFieldArgs extracts name, title, optional default value, and options table +// from Lua stack starting at position 2 (position 1 is self). +// defaultIdx is the position where the default value is expected (usually 4). +func parseLuaFieldArgs(L *lua.LState, defaultIdx int) luaFieldArgs { + args := luaFieldArgs{ + Name: L.CheckString(2), + Title: L.CheckString(3), + } + + if L.GetTop() >= defaultIdx { + v := L.Get(defaultIdx) + if v.Type() == lua.LTTable { + args.OptsIdx = defaultIdx + } else { + args.Default = v + if L.GetTop() >= defaultIdx+1 && L.Get(defaultIdx+1).Type() == lua.LTTable { + args.OptsIdx = defaultIdx + 1 + } + } + } + + args.Options = parseLuaFieldOptions(L, args.OptsIdx) + return args +} + +func parseLuaFieldOptions(L *lua.LState, idx int) luaFieldOptions { + if idx <= 0 || L.Get(idx).Type() != lua.LTTable { + return luaFieldOptions{} + } + tbl := L.CheckTable(idx) + + desc := lua.LVAsString(L.GetField(tbl, "description")) + if desc == "" { + desc = lua.LVAsString(L.GetField(tbl, "desc")) + } + required := false + if v := L.GetField(tbl, "required"); v != lua.LNil { + required = lua.LVAsBool(v) + } + + return luaFieldOptions{ + Description: desc, + Required: required, + } +} + +func luaStringSlice(L *lua.LState, tbl *lua.LTable) ([]string, error) { + out := make([]string, 0, tbl.Len()) + for i := 1; i <= tbl.Len(); i++ { + v := tbl.RawGetInt(i) + if v == lua.LNil { + continue + } + out = append(out, lua.LVAsString(v)) + } + return out, nil +} + +// wizardInput adds an input field (Lua: wiz:input(name, title, default)) +func wizardInput(L *lua.LState) int { + wiz := checkWizard(L, 1) + args := parseLuaFieldArgs(L, 4) + + defaultVal := "" + if args.Default != nil && args.Default != lua.LNil { + s, ok := args.Default.(lua.LString) + if !ok { + L.TypeError(4, lua.LTString) + return 0 + } + defaultVal = string(s) + } + + wiz.AddField(&WizardField{ + Name: args.Name, + Title: args.Title, + Description: args.Options.Description, + Type: FieldInput, + Default: defaultVal, + Required: args.Options.Required, + }) + + L.Push(L.Get(1)) + return 1 +} + +// wizardText adds a text field (Lua: wiz:text(name, title, default)) +func wizardText(L *lua.LState) int { + wiz := checkWizard(L, 1) + args := parseLuaFieldArgs(L, 4) + + defaultVal := "" + if args.Default != nil && args.Default != lua.LNil { + s, ok := args.Default.(lua.LString) + if !ok { + L.TypeError(4, lua.LTString) + return 0 + } + defaultVal = string(s) + } + + wiz.AddField(&WizardField{ + Name: args.Name, + Title: args.Title, + Description: args.Options.Description, + Type: FieldText, + Default: defaultVal, + Required: args.Options.Required, + }) + + L.Push(L.Get(1)) + return 1 +} + +// wizardSelect adds a select field (Lua: wiz:select(name, title, options)) +func wizardSelect(L *lua.LState) int { + wiz := checkWizard(L, 1) + name := L.CheckString(2) + title := L.CheckString(3) + + optionsTbl := L.CheckTable(4) + options, err := luaStringSlice(L, optionsTbl) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + defaultVal := "" + optsIdx := 0 + if L.GetTop() >= 5 { + if L.Get(5).Type() == lua.LTTable { + optsIdx = 5 + } else { + defaultVal = L.OptString(5, "") + if L.GetTop() >= 6 && L.Get(6).Type() == lua.LTTable { + optsIdx = 6 + } + } + } + opts := parseLuaFieldOptions(L, optsIdx) + wiz.AddField(&WizardField{ + Name: name, + Title: title, + Description: opts.Description, + Type: FieldSelect, + Options: options, + Default: defaultVal, + Required: opts.Required, + }) + + L.Push(L.Get(1)) + return 1 +} + +// wizardMultiSelect adds a multi-select field (Lua: wiz:multiselect(name, title, options)) +func wizardMultiSelect(L *lua.LState) int { + wiz := checkWizard(L, 1) + name := L.CheckString(2) + title := L.CheckString(3) + + optionsTbl := L.CheckTable(4) + options, err := luaStringSlice(L, optionsTbl) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + var defaults []string + optsIdx := 0 + if L.GetTop() >= 5 && L.Get(5).Type() == lua.LTTable { + t := L.CheckTable(5) + if L.GetField(t, "required") != lua.LNil || L.GetField(t, "description") != lua.LNil || L.GetField(t, "desc") != lua.LNil { + optsIdx = 5 + } else { + defaults, err = luaStringSlice(L, t) + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + if L.GetTop() >= 6 && L.Get(6).Type() == lua.LTTable { + optsIdx = 6 + } + } + } + opts := parseLuaFieldOptions(L, optsIdx) + + field := &WizardField{ + Name: name, + Title: title, + Description: opts.Description, + Type: FieldMultiSelect, + Options: options, + Required: opts.Required, + } + if defaults != nil { + field.Default = defaults + } + wiz.AddField(field) + + L.Push(L.Get(1)) + return 1 +} + +// wizardConfirm adds a confirm field (Lua: wiz:confirm(name, title, default)) +func wizardConfirm(L *lua.LState) int { + wiz := checkWizard(L, 1) + args := parseLuaFieldArgs(L, 4) + + defaultVal := false + if args.Default != nil && args.Default != lua.LNil { + b, ok := args.Default.(lua.LBool) + if !ok { + L.TypeError(4, lua.LTBool) + return 0 + } + defaultVal = bool(b) + } + + wiz.AddField(&WizardField{ + Name: args.Name, + Title: args.Title, + Description: args.Options.Description, + Type: FieldConfirm, + Default: defaultVal, + Required: args.Options.Required, + }) + + L.Push(L.Get(1)) + return 1 +} + +// wizardNumber adds a number field (Lua: wiz:number(name, title, default)) +func wizardNumber(L *lua.LState) int { + wiz := checkWizard(L, 1) + args := parseLuaFieldArgs(L, 4) + + defaultVal := 0 + if args.Default != nil && args.Default != lua.LNil { + n, ok := args.Default.(lua.LNumber) + if !ok { + L.TypeError(4, lua.LTNumber) + return 0 + } + defaultVal = int(n) + } + + wiz.AddField(&WizardField{ + Name: args.Name, + Title: args.Title, + Description: args.Options.Description, + Type: FieldNumber, + Default: defaultVal, + Required: args.Options.Required, + }) + + L.Push(L.Get(1)) + return 1 +} + +// wizardFilePath adds a file path field (Lua: wiz:filepath(name, title)) +func wizardFilePath(L *lua.LState) int { + wiz := checkWizard(L, 1) + args := parseLuaFieldArgs(L, 4) + + defaultVal := "" + if args.Default != nil && args.Default != lua.LNil { + s, ok := args.Default.(lua.LString) + if !ok { + L.TypeError(4, lua.LTString) + return 0 + } + defaultVal = string(s) + } + + wiz.AddField(&WizardField{ + Name: args.Name, + Title: args.Title, + Description: args.Options.Description, + Type: FieldFilePath, + Default: defaultVal, + Required: args.Options.Required, + }) + + L.Push(L.Get(1)) + return 1 +} + +// wizardDescription sets the wizard description (Lua: wiz:description(desc)) +func wizardDescription(L *lua.LState) int { + wiz := checkWizard(L, 1) + desc := L.CheckString(2) + + wiz.WithDescription(desc) + + L.Push(L.Get(1)) + return 1 +} + +// wizardClone creates a copy of the wizard (Lua: wiz:clone()) +func wizardClone(L *lua.LState) int { + wiz := checkWizard(L, 1) + clone := wiz.Clone() + + ud := L.NewUserData() + ud.Value = clone + L.SetMetatable(ud, L.GetTypeMetatable(luaWizardTypeName)) + L.Push(ud) + return 1 +} + +// wizardRun executes the wizard (Lua: wiz:run()) +func wizardRun(L *lua.LState) int { + wiz := checkWizard(L, 1) + + runner := NewRunner(wiz) + result, err := runner.Run() + + if err != nil { + L.Push(lua.LNil) + L.Push(lua.LString(err.Error())) + return 2 + } + + // Convert result to Lua table + resultTable := L.NewTable() + for k, v := range result.ToMap() { + resultTable.RawSetString(k, convertToLuaValue(L, v)) + } + + L.Push(resultTable) + return 1 +} + +// wizardGetField gets field info (Lua: wiz:get_field(name)) +func wizardGetField(L *lua.LState) int { + wiz := checkWizard(L, 1) + name := L.CheckString(2) + + for _, f := range wiz.Fields { + if f.Name == name { + tbl := L.NewTable() + tbl.RawSetString("name", lua.LString(f.Name)) + tbl.RawSetString("title", lua.LString(f.Title)) + tbl.RawSetString("description", lua.LString(f.Description)) + tbl.RawSetString("type", lua.LNumber(float64(f.Type))) + tbl.RawSetString("required", lua.LBool(f.Required)) + if f.Default != nil { + tbl.RawSetString("default", convertToLuaValue(L, f.Default)) + } + if len(f.Options) > 0 { + optTbl := L.NewTable() + for i, opt := range f.Options { + optTbl.RawSetInt(i+1, lua.LString(opt)) + } + tbl.RawSetString("options", optTbl) + } + L.Push(tbl) + return 1 + } + } + + L.Push(lua.LNil) + return 1 +} + +// wizardFieldCount returns the number of fields (Lua: wiz:field_count()) +func wizardFieldCount(L *lua.LState) int { + wiz := checkWizard(L, 1) + L.Push(lua.LNumber(len(wiz.Fields))) + return 1 +} + +// convertToLuaValue converts a Go value to Lua value +func convertToLuaValue(L *lua.LState, v interface{}) lua.LValue { + switch val := v.(type) { + case string: + return lua.LString(val) + case bool: + return lua.LBool(val) + case int: + return lua.LNumber(float64(val)) + case float64: + return lua.LNumber(val) + case []string: + tbl := L.NewTable() + for i, s := range val { + tbl.RawSetInt(i+1, lua.LString(s)) + } + return tbl + default: + return lua.LNil + } +} + +var registerBuiltinsOnce sync.Once + +// RegisterBuiltinFunctions registers wizard functions as builtin functions +// This allows the functions to be available in all Lua VMs +func RegisterBuiltinFunctions() { + registerBuiltinsOnce.Do(func() { + // Register wizard constructor + intermediate.RegisterFunction("wizard", func(id string, title string) (*Wizard, error) { + return NewWizard(id, title), nil + }) + + intermediate.AddHelper("wizard", &mals.Helper{ + Group: intermediate.ClientGroup, + Short: "Create a new interactive wizard form", + Input: []string{"id: wizard identifier", "title: wizard title"}, + Output: []string{"wizard"}, + Example: `local wiz = wizard("my_wizard", "My Wizard")`, + }) + + // Register template loader + intermediate.RegisterFunction("wizard_template", func(name string) (*Wizard, error) { + wiz, ok := GetTemplate(name) + if !ok { + return nil, nil + } + return wiz, nil + }) + + intermediate.AddHelper("wizard_template", &mals.Helper{ + Group: intermediate.ClientGroup, + Short: "Load a predefined wizard template", + Input: []string{"name: template name"}, + Output: []string{"wizard"}, + Example: `local wiz = wizard_template("listener_setup")`, + }) + + // Register template list + intermediate.RegisterFunction("wizard_templates", func() ([]string, error) { + return ListTemplates(), nil + }) + + intermediate.AddHelper("wizard_templates", &mals.Helper{ + Group: intermediate.ClientGroup, + Short: "List available wizard templates", + Output: []string{"template names"}, + Example: `local templates = wizard_templates()`, + }) + + // Config-driven helpers + intermediate.RegisterFunction("wizard_from_file", func(path string) (*Wizard, error) { + return NewWizardFromFile(path) + }) + + intermediate.AddHelper("wizard_from_file", &mals.Helper{ + Group: intermediate.ClientGroup, + Short: "Load a wizard from a JSON/YAML spec file", + Input: []string{"path: .json/.yaml/.yml wizard spec file"}, + Output: []string{"wizard"}, + Example: ` +-- Prefer loading from plugin resources so embedded/external plugins both work: +wiz = wizard_from_file(script_resource("wizards/priv_esc.yaml")) +`, + }) + + intermediate.RegisterFunction("wizard_from_spec", func(spec map[string]interface{}) (*Wizard, error) { + ws, err := SpecFromMap(spec) + if err != nil { + return nil, err + } + return NewWizardFromSpec(ws) + }) + + intermediate.AddHelper("wizard_from_spec", &mals.Helper{ + Group: intermediate.ClientGroup, + Short: "Build a wizard from a Lua spec table", + Input: []string{"spec: table"}, + Output: []string{"wizard"}, + Example: ` +wiz = wizard_from_spec({ + id = "my_wizard", + title = "My Wizard", + fields = { + { name = "host", title = "Host", type = "input", default = "0.0.0.0", required = true }, + }, +}) +`, + }) + + intermediate.RegisterFunction("wizard_register_template", func(name string, spec map[string]interface{}) (bool, error) { + if strings.TrimSpace(name) == "" { + return false, errors.New("template name is required") + } + ws, err := SpecFromMap(spec) + if err != nil { + return false, err + } + if ws.ID == "" { + ws.ID = name + } + if err := RegisterTemplateFromSpec(name, ws); err != nil { + return false, err + } + return true, nil + }) + + intermediate.AddHelper("wizard_register_template", &mals.Helper{ + Group: intermediate.ClientGroup, + Short: "Register a wizard template from a spec table", + Input: []string{"name: template name", "spec: table"}, + Output: []string{"ok: boolean"}, + Example: ` +wizard_register_template("priv_esc", { + title = "Privilege Escalation", + fields = { + { name = "method", title = "Method", type = "select", options = {"uac","token"} }, + }, +}) +`, + }) + }) +} diff --git a/client/wizard/result.go b/client/wizard/result.go new file mode 100644 index 00000000..e4ff456a --- /dev/null +++ b/client/wizard/result.go @@ -0,0 +1,136 @@ +package wizard + +import ( + "strconv" +) + +// WizardResult stores the results from a wizard run +type WizardResult struct { + WizardID string + Values map[string]interface{} +} + +// NewWizardResult creates a new result instance +func NewWizardResult(wizardID string) *WizardResult { + return &WizardResult{ + WizardID: wizardID, + Values: make(map[string]interface{}), + } +} + +// Set sets a value in the result +func (r *WizardResult) Set(name string, value interface{}) { + r.Values[name] = value +} + +// Get gets a raw value from the result +func (r *WizardResult) Get(name string) interface{} { + return r.Values[name] +} + +// GetString gets a string value from the result +func (r *WizardResult) GetString(name string) string { + if v, ok := r.Values[name]; ok { + switch val := v.(type) { + case string: + return val + case *string: + if val != nil { + return *val + } + } + } + return "" +} + +// GetBool gets a boolean value from the result +func (r *WizardResult) GetBool(name string) bool { + if v, ok := r.Values[name]; ok { + switch val := v.(type) { + case bool: + return val + case *bool: + if val != nil { + return *val + } + } + } + return false +} + +// GetInt gets an integer value from the result +func (r *WizardResult) GetInt(name string) int { + if v, ok := r.Values[name]; ok { + switch val := v.(type) { + case int: + return val + case *int: + if val != nil { + return *val + } + case string: + if i, err := strconv.Atoi(val); err == nil { + return i + } + case *string: + if val != nil { + if i, err := strconv.Atoi(*val); err == nil { + return i + } + } + } + } + return 0 +} + +// GetStrings gets a string slice from the result +func (r *WizardResult) GetStrings(name string) []string { + if v, ok := r.Values[name]; ok { + switch val := v.(type) { + case []string: + return val + case *[]string: + if val != nil { + return *val + } + } + } + return nil +} + +// ToMap returns all values as a map +func (r *WizardResult) ToMap() map[string]interface{} { + result := make(map[string]interface{}) + for k, v := range r.Values { + // Dereference pointers + switch val := v.(type) { + case *string: + if val != nil { + result[k] = *val + } else { + result[k] = "" + } + case *bool: + if val != nil { + result[k] = *val + } else { + result[k] = false + } + case *int: + if val != nil { + result[k] = *val + } else { + result[k] = 0 + } + case *[]string: + if val != nil { + result[k] = *val + } else { + result[k] = []string{} + } + default: + result[k] = v + } + } + return result +} diff --git a/client/wizard/runner.go b/client/wizard/runner.go new file mode 100644 index 00000000..ed004c06 --- /dev/null +++ b/client/wizard/runner.go @@ -0,0 +1,279 @@ +package wizard + +import ( + "fmt" + "strconv" + "strings" + + "github.com/charmbracelet/huh" + "github.com/charmbracelet/lipgloss" +) + +// Runner handles the execution of wizards +type Runner struct { + wizard *Wizard + theme *huh.Theme +} + +// NewRunner creates a new runner for the given wizard +func NewRunner(w *Wizard) *Runner { + return &Runner{ + wizard: w, + theme: huh.ThemeCharm(), + } +} + +// WithTheme sets a custom theme +func (r *Runner) WithTheme(theme *huh.Theme) *Runner { + r.theme = theme + return r +} + +// Run executes the wizard and returns the result +func (r *Runner) Run() (*WizardResult, error) { + result := NewWizardResult(r.wizard.ID) + + if r.wizard.IsGrouped() { + return r.runGrouped(result) + } + + // For non-grouped wizards, create a single group with all fields + formFields := make([]*FormField, 0, len(r.wizard.Fields)) + for _, f := range r.wizard.Fields { + ff := r.wizardFieldToFormField(f, result) + formFields = append(formFields, ff) + } + + formGroups := []*FormGroup{{ + Name: "main", + Title: r.wizard.Title, + Fields: formFields, + }} + + groupedForm := NewGroupedWizardForm(formGroups).WithTheme(r.theme) + if err := groupedForm.Run(); err != nil { + return nil, err + } + + r.finalizeResult(result) + return result, nil +} + +// runGrouped runs the wizard using GroupedWizardForm with Tab navigation +func (r *Runner) runGrouped(result *WizardResult) (*WizardResult, error) { + formGroups := make([]*FormGroup, 0, len(r.wizard.Groups)) + + for _, wg := range r.wizard.Groups { + formFields := make([]*FormField, 0, len(wg.Fields)) + + for _, f := range wg.Fields { + ff := r.wizardFieldToFormField(f, result) + formFields = append(formFields, ff) + } + + formGroups = append(formGroups, &FormGroup{ + Name: wg.Name, + Title: wg.Title, + Description: wg.Description, + Fields: formFields, + Optional: wg.Optional, + Expanded: wg.Expanded, + }) + } + + groupedForm := NewGroupedWizardForm(formGroups).WithTheme(r.theme) + + if err := groupedForm.Run(); err != nil { + return nil, err + } + + r.finalizeResult(result) + return result, nil +} + +// RunTwoPhase executes the wizard (kept for backward compatibility) +func (r *Runner) RunTwoPhase() (*WizardResult, error) { + return r.Run() +} + +// wizardFieldToFormField converts a WizardField to a FormField +func (r *Runner) wizardFieldToFormField(f *WizardField, result *WizardResult) *FormField { + ff := &FormField{ + Name: f.Name, + Title: f.Title, + Description: f.Description, + Required: f.Required, + Validate: f.Validate, + } + + switch f.Type { + case FieldSelect: + ff.Kind = KindSelect + ff.Options = f.Options + val := "" + if f.Default != nil { + val = fmt.Sprintf("%v", f.Default) + } + // Auto-select first non-empty option if default is empty + if val == "" && len(f.Options) > 0 { + for _, opt := range f.Options { + if opt != "" && opt != "(empty)" { + val = opt + break + } + } + // Fallback to first option if all are empty + if val == "" { + val = f.Options[0] + } + } + for i, opt := range f.Options { + if opt == val { + ff.Selected = i + break + } + } + result.Values[f.Name] = &val + ff.Value = &val + + case FieldMultiSelect: + ff.Kind = KindMultiSelect + ff.Options = f.Options + var vals []string + if f.Default != nil { + if defaults, ok := f.Default.([]string); ok { + vals = defaults + } + } + ff.MultiSelect = make(map[int]bool) + for _, v := range vals { + for i, opt := range f.Options { + if opt == v { + ff.MultiSelect[i] = true + break + } + } + } + result.Values[f.Name] = &vals + ff.Value = &vals + + case FieldConfirm: + ff.Kind = KindConfirm + val := false + if f.Default != nil { + if b, ok := f.Default.(bool); ok { + val = b + } + } + ff.ConfirmVal = val + result.Values[f.Name] = &val + ff.Value = &val + + case FieldInput, FieldText, FieldFilePath: + ff.Kind = KindInput + val := "" + if f.Default != nil { + val = fmt.Sprintf("%v", f.Default) + } + ff.InputValue = val + result.Values[f.Name] = &val + ff.Value = &val + + case FieldNumber: + ff.Kind = KindNumber + val := "" + if f.Default != nil { + switch v := f.Default.(type) { + case int: + val = strconv.Itoa(v) + case string: + val = v + } + } + ff.InputValue = val + result.Values[f.Name] = &val + ff.Value = &val + } + + return ff +} + +func (r *Runner) finalizeResult(result *WizardResult) { + for _, f := range r.wizard.Fields { + if f.Type != FieldNumber { + continue + } + raw, ok := result.Values[f.Name] + if !ok { + continue + } + switch val := raw.(type) { + case *string: + if val == nil { + result.Values[f.Name] = 0 + continue + } + s := strings.TrimSpace(*val) + if s == "" { + result.Values[f.Name] = 0 + continue + } + if n, err := strconv.Atoi(s); err == nil { + result.Values[f.Name] = n + } + case string: + s := strings.TrimSpace(val) + if s == "" { + result.Values[f.Name] = 0 + continue + } + if n, err := strconv.Atoi(s); err == nil { + result.Values[f.Name] = n + } + } + } +} + +// SelectOption represents an option in a select menu +type SelectOption struct { + Value string + Label string + Description string +} + +// RunSelect displays an interactive select menu and returns the selected value +func RunSelect(title string, options []SelectOption) (string, error) { + // Prevent lipgloss from sending OSC terminal queries (like \x1b]11;?) + // which can conflict with readline's input handling and cause garbled output. + lipglossInitOnce.Do(func() { + lipgloss.SetHasDarkBackground(true) + }) + + if len(options) == 0 { + return "", fmt.Errorf("no options provided") + } + + selected := options[0].Value + + huhOptions := make([]huh.Option[string], len(options)) + for i, opt := range options { + label := opt.Label + if opt.Description != "" { + label = fmt.Sprintf("%-12s - %s", opt.Label, opt.Description) + } + huhOptions[i] = huh.NewOption(label, opt.Value) + } + + selectField := huh.NewSelect[string](). + Title(title). + Options(huhOptions...). + Value(&selected) + + form := huh.NewForm(huh.NewGroup(selectField)).WithTheme(huh.ThemeCharm()) + + if err := form.Run(); err != nil { + return "", err + } + + return selected, nil +} diff --git a/client/wizard/spec.go b/client/wizard/spec.go new file mode 100644 index 00000000..f7f2e9a2 --- /dev/null +++ b/client/wizard/spec.go @@ -0,0 +1,321 @@ +package wizard + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "reflect" + "strconv" + "strings" + + "github.com/chainreactors/malice-network/helper/intl" + "gopkg.in/yaml.v3" +) + +// WizardSpec is a serializable wizard definition (JSON/YAML) for building reusable templates. +type WizardSpec struct { + ID string `json:"id" yaml:"id"` + Title string `json:"title" yaml:"title"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Fields []FieldSpec `json:"fields,omitempty" yaml:"fields,omitempty"` // Flat fields (legacy/simple) + Groups []GroupSpec `json:"groups,omitempty" yaml:"groups,omitempty"` // Grouped fields (new) +} + +// GroupSpec is a serializable group definition. +type GroupSpec struct { + Name string `json:"name" yaml:"name"` + Title string `json:"title" yaml:"title"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Optional bool `json:"optional,omitempty" yaml:"optional,omitempty"` + Fields []FieldSpec `json:"fields" yaml:"fields"` +} + +// FieldSpec is a serializable field definition. +type FieldSpec struct { + Name string `json:"name" yaml:"name"` + Title string `json:"title" yaml:"title"` + Description string `json:"description,omitempty" yaml:"description,omitempty"` + Type string `json:"type" yaml:"type"` + Default any `json:"default,omitempty" yaml:"default,omitempty"` + Options []string `json:"options,omitempty" yaml:"options,omitempty"` + Required bool `json:"required,omitempty" yaml:"required,omitempty"` +} + +// SpecFromMap converts a generic map (e.g. from Lua) into a WizardSpec. +func SpecFromMap(spec map[string]interface{}) (*WizardSpec, error) { + if spec == nil { + return nil, errors.New("spec is nil") + } + data, err := json.Marshal(spec) + if err != nil { + return nil, err + } + var out WizardSpec + if err := json.Unmarshal(data, &out); err != nil { + return nil, err + } + return &out, nil +} + +// LoadSpec loads a WizardSpec from a JSON/YAML file. +func LoadSpec(path string) (*WizardSpec, error) { + data, err := readSpecBytes(path) + if err != nil { + return nil, err + } + + var spec WizardSpec + switch strings.ToLower(filepath.Ext(path)) { + case ".json": + if err := json.Unmarshal(data, &spec); err != nil { + return nil, err + } + case ".yaml", ".yml": + if err := yaml.Unmarshal(data, &spec); err != nil { + return nil, err + } + default: + // Try YAML first for better UX (it is a superset for many configs). + if err := yaml.Unmarshal(data, &spec); err != nil { + if err2 := json.Unmarshal(data, &spec); err2 != nil { + return nil, fmt.Errorf("unsupported spec format (expected .json/.yaml/.yml): %w", err) + } + } + } + + return &spec, nil +} + +func readSpecBytes(path string) ([]byte, error) { + if strings.HasPrefix(path, "embed://") { + return intl.ReadEmbedResource(path) + } + return os.ReadFile(path) +} + +// NewWizardFromFile loads a WizardSpec and builds a Wizard instance. +func NewWizardFromFile(path string) (*Wizard, error) { + spec, err := LoadSpec(path) + if err != nil { + return nil, err + } + return NewWizardFromSpec(spec) +} + +// NewWizardFromSpec builds a Wizard from a WizardSpec. +func NewWizardFromSpec(spec *WizardSpec) (*Wizard, error) { + if spec == nil { + return nil, errors.New("spec is nil") + } + if strings.TrimSpace(spec.ID) == "" { + return nil, errors.New("spec.id is required") + } + + wiz := NewWizard(spec.ID, spec.Title).WithDescription(spec.Description) + + // Check if spec uses groups + if len(spec.Groups) > 0 { + for i, gs := range spec.Groups { + if strings.TrimSpace(gs.Name) == "" { + return nil, fmt.Errorf("groups[%d].name is required", i) + } + if strings.TrimSpace(gs.Title) == "" { + return nil, fmt.Errorf("groups[%d].title is required", i) + } + + group := wiz.NewGroup(gs.Name, gs.Title).WithDescription(gs.Description) + if gs.Optional { + group.AsOptional() + } + + for j, fs := range gs.Fields { + field, err := parseFieldSpec(fs, fmt.Sprintf("groups[%d].fields[%d]", i, j)) + if err != nil { + return nil, err + } + group.AddField(field) + } + } + } else if len(spec.Fields) > 0 { + // Legacy flat fields (backward compatible) + for i, fs := range spec.Fields { + field, err := parseFieldSpec(fs, fmt.Sprintf("fields[%d]", i)) + if err != nil { + return nil, err + } + wiz.AddField(field) + } + } + + return wiz, nil +} + +// parseFieldSpec parses a single FieldSpec into a WizardField +func parseFieldSpec(fs FieldSpec, path string) (*WizardField, error) { + if strings.TrimSpace(fs.Name) == "" { + return nil, fmt.Errorf("%s.name is required", path) + } + if strings.TrimSpace(fs.Title) == "" { + return nil, fmt.Errorf("%s.title is required", path) + } + + ft, err := parseFieldTypeName(fs.Type) + if err != nil { + return nil, fmt.Errorf("%s.type: %w", path, err) + } + + field := &WizardField{ + Name: fs.Name, + Title: fs.Title, + Description: fs.Description, + Type: ft, + Options: append([]string(nil), fs.Options...), + Required: fs.Required, + } + + if fs.Default != nil { + switch ft { + case FieldConfirm: + b, err := coerceBool(fs.Default) + if err != nil { + return nil, fmt.Errorf("%s.default: %w", path, err) + } + field.Default = b + case FieldNumber: + n, err := coerceInt(fs.Default) + if err != nil { + return nil, fmt.Errorf("%s.default: %w", path, err) + } + field.Default = n + case FieldMultiSelect: + ss, err := coerceStrings(fs.Default) + if err != nil { + return nil, fmt.Errorf("%s.default: %w", path, err) + } + field.Default = ss + default: + field.Default = fmt.Sprintf("%v", fs.Default) + } + } + + if ft == FieldSelect || ft == FieldMultiSelect { + if len(field.Options) == 0 { + return nil, fmt.Errorf("%s.options is required for %s", path, fs.Type) + } + } + + return field, nil +} + +// RegisterTemplateFromSpec registers a template backed by a WizardSpec. +func RegisterTemplateFromSpec(name string, spec *WizardSpec) error { + if strings.TrimSpace(name) == "" { + return errors.New("template name is required") + } + if spec == nil { + return errors.New("spec is nil") + } + + specCopy := *spec + if strings.TrimSpace(specCopy.ID) == "" { + specCopy.ID = name + } + + wiz, err := NewWizardFromSpec(&specCopy) + if err != nil { + return err + } + + RegisterTemplate(name, func() *Wizard { return wiz.Clone() }) + return nil +} + +func parseFieldTypeName(name string) (FieldType, error) { + switch strings.ToLower(strings.TrimSpace(name)) { + case "input": + return FieldInput, nil + case "text": + return FieldText, nil + case "select": + return FieldSelect, nil + case "multiselect", "multi_select", "multi-select": + return FieldMultiSelect, nil + case "confirm": + return FieldConfirm, nil + case "number", "int", "integer": + return FieldNumber, nil + case "filepath", "file_path", "file-path": + return FieldFilePath, nil + default: + return 0, fmt.Errorf("unknown field type: %q", name) + } +} + +func coerceBool(v any) (bool, error) { + switch val := v.(type) { + case bool: + return val, nil + case string: + switch strings.ToLower(strings.TrimSpace(val)) { + case "1", "true", "yes", "y", "on": + return true, nil + case "0", "false", "no", "n", "off": + return false, nil + default: + return false, fmt.Errorf("invalid bool: %q", val) + } + default: + return false, fmt.Errorf("invalid bool type: %T", v) + } +} + +func coerceInt(v any) (int, error) { + if v == nil { + return 0, fmt.Errorf("invalid int: nil") + } + + // Handle string specially for parsing + if s, ok := v.(string); ok { + n, err := strconv.Atoi(strings.TrimSpace(s)) + if err != nil { + return 0, fmt.Errorf("invalid int: %q", s) + } + return n, nil + } + + // Use reflect for all numeric types + rv := reflect.ValueOf(v) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return int(rv.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int(rv.Uint()), nil + case reflect.Float32, reflect.Float64: + return int(rv.Float()), nil + default: + return 0, fmt.Errorf("invalid int type: %T", v) + } +} + +func coerceStrings(v any) ([]string, error) { + switch val := v.(type) { + case []string: + out := make([]string, len(val)) + copy(out, val) + return out, nil + case []interface{}: + out := make([]string, 0, len(val)) + for i, item := range val { + s, ok := item.(string) + if !ok { + return nil, fmt.Errorf("invalid string at index %d: %T", i, item) + } + out = append(out, s) + } + return out, nil + default: + return nil, fmt.Errorf("invalid []string type: %T", v) + } +} diff --git a/client/wizard/spec_test.go b/client/wizard/spec_test.go new file mode 100644 index 00000000..465c68d7 --- /dev/null +++ b/client/wizard/spec_test.go @@ -0,0 +1,194 @@ +package wizard + +import ( + "testing" + + lua "github.com/yuin/gopher-lua" +) + +func TestNewWizardFromSpec(t *testing.T) { + spec := &WizardSpec{ + ID: "spec_wizard", + Title: "Spec Wizard", + Description: "From spec", + Fields: []FieldSpec{ + {Name: "host", Title: "Host", Type: "input", Default: "0.0.0.0", Required: true}, + {Name: "protocol", Title: "Protocol", Type: "select", Options: []string{"tcp", "http"}, Default: "http"}, + {Name: "modules", Title: "Modules", Type: "multiselect", Options: []string{"a", "b"}, Default: []interface{}{"a"}}, + {Name: "tls", Title: "TLS", Type: "confirm", Default: true}, + {Name: "port", Title: "Port", Type: "number", Default: float64(443)}, + }, + } + + wiz, err := NewWizardFromSpec(spec) + if err != nil { + t.Fatalf("NewWizardFromSpec failed: %v", err) + } + if wiz.ID != "spec_wizard" { + t.Fatalf("expected ID spec_wizard, got %q", wiz.ID) + } + if wiz.Description != "From spec" { + t.Fatalf("expected description %q, got %q", "From spec", wiz.Description) + } + if len(wiz.Fields) != 5 { + t.Fatalf("expected 5 fields, got %d", len(wiz.Fields)) + } + if wiz.Fields[0].Required != true { + t.Fatalf("expected host.required=true") + } + if _, ok := wiz.Fields[4].Default.(int); !ok { + t.Fatalf("expected port.default to be int, got %T", wiz.Fields[4].Default) + } +} + +func TestRegisterTemplateFromSpec_SetsID(t *testing.T) { + name := "spec_template_test" + t.Cleanup(func() { + templatesMu.Lock() + delete(Templates, name) + templatesMu.Unlock() + }) + + spec := &WizardSpec{ + Title: "Template", + Fields: []FieldSpec{ + {Name: "x", Title: "X", Type: "input"}, + }, + } + if err := RegisterTemplateFromSpec(name, spec); err != nil { + t.Fatalf("RegisterTemplateFromSpec failed: %v", err) + } + + wiz, ok := GetTemplate(name) + if !ok || wiz == nil { + t.Fatalf("expected to load registered template") + } + if wiz.ID != name { + t.Fatalf("expected template wizard ID to default to %q, got %q", name, wiz.ID) + } +} + +func TestNewWizardFromFile_EmbedPath(t *testing.T) { + wiz, err := NewWizardFromFile("embed://community/resources/testdata/wizard_spec_test.yaml") + if err != nil { + t.Fatalf("NewWizardFromFile(embed://) failed: %v", err) + } + if wiz.ID != "embed_wizard_spec_test" { + t.Fatalf("expected ID embed_wizard_spec_test, got %q", wiz.ID) + } + if wiz.Title == "" { + t.Fatalf("expected non-empty title") + } + if len(wiz.Fields) != 2 { + t.Fatalf("expected 2 fields, got %d", len(wiz.Fields)) + } + if !wiz.Fields[0].Required { + t.Fatalf("expected host.required=true") + } + if wiz.Fields[1].Type != FieldNumber { + t.Fatalf("expected port.type=FieldNumber, got %v", wiz.Fields[1].Type) + } + if _, ok := wiz.Fields[1].Default.(int); !ok { + t.Fatalf("expected port.default to be int, got %T", wiz.Fields[1].Default) + } +} + +func TestLuaWizardBuilderOptionsAndOrder(t *testing.T) { + L := lua.NewState() + defer L.Close() + + SetupMetatable(L) + fns := make(map[string]lua.LGFunction) + RegisterLuaFunctions(fns) + for name, fn := range fns { + L.SetGlobal(name, L.NewFunction(fn)) + } + + if err := L.DoString(` +wiz = wizard("lua_wiz", "Lua Wizard") +wiz:input("host", "Host", { required = true, desc = "host desc" }) +wiz:select("protocol", "Protocol", {"tcp", "http", "https"}, "http", { required = true }) +wiz:multiselect("mods", "Modules", {"a", "b"}, {"a"}, { required = true }) +`); err != nil { + t.Fatalf("lua script failed: %v", err) + } + + ud, ok := L.GetGlobal("wiz").(*lua.LUserData) + if !ok { + t.Fatalf("expected wiz userdata, got %T", L.GetGlobal("wiz")) + } + wiz, ok := ud.Value.(*Wizard) + if !ok { + t.Fatalf("expected *Wizard userdata value, got %T", ud.Value) + } + + if len(wiz.Fields) != 3 { + t.Fatalf("expected 3 fields, got %d", len(wiz.Fields)) + } + if wiz.Fields[0].Description != "host desc" || !wiz.Fields[0].Required { + t.Fatalf("expected input opts applied, got desc=%q required=%v", wiz.Fields[0].Description, wiz.Fields[0].Required) + } + if got := wiz.Fields[1].Options; len(got) != 3 || got[0] != "tcp" || got[1] != "http" || got[2] != "https" { + t.Fatalf("expected select option order preserved, got %#v", got) + } + if wiz.Fields[2].Default == nil { + t.Fatalf("expected multiselect defaults set") + } +} + +func TestLuaWizardFromSpecAndRegisterTemplate(t *testing.T) { + L := lua.NewState() + defer L.Close() + + SetupMetatable(L) + fns := make(map[string]lua.LGFunction) + RegisterLuaFunctions(fns) + for name, fn := range fns { + L.SetGlobal(name, L.NewFunction(fn)) + } + + name := "lua_spec_template_test" + t.Cleanup(func() { + templatesMu.Lock() + delete(Templates, name) + templatesMu.Unlock() + }) + + if err := L.DoString(` +wiz2 = wizard_from_spec({ + id = "from_spec", + title = "From Spec", + fields = { + { name = "port", title = "Port", type = "number", default = 80, required = true }, + }, +}) +wizard_register_template("` + name + `", { + title = "Registered Template", + fields = { + { name = "host", title = "Host", type = "input", default = "127.0.0.1" }, + }, +}) +`); err != nil { + t.Fatalf("lua script failed: %v", err) + } + + ud, ok := L.GetGlobal("wiz2").(*lua.LUserData) + if !ok { + t.Fatalf("expected wiz2 userdata, got %T", L.GetGlobal("wiz2")) + } + wiz, ok := ud.Value.(*Wizard) + if !ok { + t.Fatalf("expected *Wizard userdata value, got %T", ud.Value) + } + if wiz.ID != "from_spec" || len(wiz.Fields) != 1 { + t.Fatalf("unexpected wizard from spec: id=%q fields=%d", wiz.ID, len(wiz.Fields)) + } + + templ, ok := GetTemplate(name) + if !ok || templ == nil { + t.Fatalf("expected registered template") + } + if templ.ID != name { + t.Fatalf("expected template ID %q, got %q", name, templ.ID) + } +} diff --git a/client/wizard/templates.go b/client/wizard/templates.go new file mode 100644 index 00000000..76345b82 --- /dev/null +++ b/client/wizard/templates.go @@ -0,0 +1,406 @@ +package wizard + +import ( + "sort" + "sync" +) + +// WizardCategory represents a group of related wizards +type WizardCategory struct { + Name string // Category command name (e.g., "build") + Title string // Display title + Description string // Category description + Wizards []WizardEntry // Wizards in this category +} + +// WizardEntry represents a wizard within a category +type WizardEntry struct { + ID string // Short ID within category (e.g., "beacon") + FullID string // Full template ID (e.g., "build_beacon") + Description string // Description for selection menu +} + +// Categories defines the wizard groupings +var Categories = []WizardCategory{ + { + Name: "build", + Title: "Build", + Description: "Build implants and payloads", + Wizards: []WizardEntry{ + {ID: "beacon", FullID: "build_beacon", Description: "Build a beacon implant with full options"}, + {ID: "pulse", FullID: "build_pulse", Description: "Build stage-0 shellcode"}, + {ID: "prelude", FullID: "build_prelude", Description: "Build multi-stage payload"}, + {ID: "module", FullID: "build_module", Description: "Build custom module DLL"}, + }, + }, + { + Name: "pipeline", + Title: "Pipeline", + Description: "Configure communication pipelines", + Wizards: []WizardEntry{ + {ID: "tcp", FullID: "tcp_pipeline", Description: "Configure a TCP pipeline"}, + {ID: "http", FullID: "http_pipeline", Description: "Configure an HTTP pipeline"}, + {ID: "bind", FullID: "bind_pipeline", Description: "Configure a bind pipeline"}, + {ID: "rem", FullID: "rem_pipeline", Description: "Configure a REM pipeline"}, + }, + }, + { + Name: "cert", + Title: "Certificate", + Description: "Manage TLS certificates", + Wizards: []WizardEntry{ + {ID: "generate", FullID: "cert_generate", Description: "Generate a self-signed certificate"}, + {ID: "import", FullID: "cert_import", Description: "Import an existing certificate"}, + }, + }, + { + Name: "config", + Title: "Config", + Description: "Configure external services", + Wizards: []WizardEntry{ + {ID: "github", FullID: "github_config", Description: "Configure GitHub Actions build"}, + {ID: "notify", FullID: "notify_config", Description: "Configure notification channels"}, + }, + }, +} + +// StandaloneWizards are wizards that don't belong to a category +var StandaloneWizards = []WizardEntry{ + {ID: "listener", FullID: "listener_setup", Description: "Configure a new listener"}, + {ID: "profile", FullID: "profile_create", Description: "Create a new implant profile"}, + {ID: "infra", FullID: "infrastructure_setup", Description: "One-stop C2 infrastructure setup"}, +} + +// Templates is a registry of predefined wizard templates +var Templates = map[string]func() *Wizard{ + // Existing + "listener_setup": NewListenerSetupWizard, + "tcp_pipeline": NewTCPPipelineWizard, + "http_pipeline": NewHTTPPipelineWizard, + "profile_create": NewProfileCreateWizard, + // Build + "build_beacon": NewBuildBeaconWizard, + "build_pulse": NewBuildPulseWizard, + "build_prelude": NewBuildPreludeWizard, + "build_module": NewBuildModuleWizard, + // Pipeline + "bind_pipeline": NewBindPipelineWizard, + "rem_pipeline": NewRemPipelineWizard, + // Certificate + "cert_generate": NewCertGenerateWizard, + "cert_import": NewCertImportWizard, + // Config + "github_config": NewGithubConfigWizard, + "notify_config": NewNotifyConfigWizard, + // Composite + "infrastructure_setup": NewInfrastructureSetupWizard, +} + +// GetCategory returns a category by name +func GetCategory(name string) *WizardCategory { + for i := range Categories { + if Categories[i].Name == name { + return &Categories[i] + } + } + return nil +} + +// GetStandaloneWizard returns a standalone wizard by ID +func GetStandaloneWizard(id string) *WizardEntry { + for i := range StandaloneWizards { + if StandaloneWizards[i].ID == id { + return &StandaloneWizards[i] + } + } + return nil +} + +var templatesMu sync.RWMutex + +// NewListenerSetupWizard creates a wizard for listener configuration +func NewListenerSetupWizard() *Wizard { + return NewWizard("listener_setup", "Listener Setup"). + WithDescription("Configure a new listener"). + Input("name", "Listener Name", "").Field().Require().Desc("Unique identifier for this listener").End(). + Input("host", "Host Address", "0.0.0.0").Field().WithValidate(ValidateHost()).Desc("IP to bind, 0.0.0.0 = all interfaces").End(). + Select("protocol", "Protocol", []string{"tcp", "http", "https"}).Field().Require().Desc("Communication protocol type").End(). + Number("port", "Port", 443).Field().WithValidate(ValidatePort()).Desc("Port number (443=HTTPS, 80=HTTP, etc.)").End(). + Confirm("tls", "Enable TLS?", true).Field().Desc("Enable TLS encryption for secure communication").End() +} + +// NewTCPPipelineWizard creates a wizard for TCP pipeline setup +func NewTCPPipelineWizard() *Wizard { + return NewWizard("tcp_pipeline", "TCP Pipeline Setup"). + WithDescription("Configure a new TCP pipeline"). + Input("name", "Pipeline Name", "").Field().Desc("Unique identifier for this pipeline").End(). + Select("listener_id", "Listener ID", []string{""}).Field().Require().Desc("Parent listener to attach this pipeline").End(). + Input("host", "Host Address", "0.0.0.0").Field().WithValidate(ValidateHost()).Desc("IP to bind for implant connections").End(). + Number("port", "Port", 5001).Field().WithValidate(ValidatePort()).Desc("TCP port for implant callbacks").End(). + Confirm("tls", "Enable TLS?", false).Field().Desc("Enable TLS encryption on TCP connection").End() +} + +// NewHTTPPipelineWizard creates a wizard for HTTP pipeline setup +func NewHTTPPipelineWizard() *Wizard { + return NewWizard("http_pipeline", "HTTP Pipeline Setup"). + WithDescription("Configure a new HTTP pipeline"). + Input("name", "Pipeline Name", "").Field().Desc("Unique identifier for this pipeline").End(). + Select("listener_id", "Listener ID", []string{""}).Field().Require().Desc("Parent listener to attach this pipeline").End(). + Input("host", "Host Address", "0.0.0.0").Field().WithValidate(ValidateHost()).Desc("IP to bind for HTTP requests").End(). + Number("port", "Port", 443).Field().WithValidate(ValidatePort()).Desc("HTTP port (443 for HTTPS, 80 for HTTP)").End(). + Confirm("tls", "Enable TLS?", true).Field().Desc("Enable HTTPS (recommended for production)").End() +} + +// NewProfileCreateWizard creates a wizard for profile creation +func NewProfileCreateWizard() *Wizard { + return NewWizard("profile_create", "Create Profile"). + WithDescription("Create a new implant profile"). + Input("name", "Profile Name", "").Field().Require().Desc("Unique identifier for this profile").End(). + Select("pipeline", "Pipeline ID", []string{""}).Field().Require().Desc("Pipeline for C2 communication").End(). + Select("type", "Implant Type", []string{"beacon", "bind", "prelude"}).Field().Require().Desc("beacon=persistent, bind=reverse, prelude=staged").End(). + MultiSelect("modules", "Modules", []string{ + "base", + "sys_full", + "execute_exe", + "execute_dll", + "execute_bof", + "execute_shellcode", + "execute_assembly", + }).Field().Desc("Select modules to include in the implant").End() +} + +// NewBuildBeaconWizard creates a wizard for building beacon implant +func NewBuildBeaconWizard() *Wizard { + return NewWizard("build_beacon", "Build Beacon"). + WithDescription("Build a beacon implant with full options"). + // Group 1: Basic Configuration + NewGroup("basic", "Basic Configuration"). + WithDescription("Core build settings"). + Select("profile", "Profile Name", []string{""}).Field().Desc("Implant profile with pipeline settings").EndGroup(). + Select("target", "Build Target", []string{ + "x86_64-pc-windows-gnu", "i686-pc-windows-gnu", + "x86_64-pc-windows-msvc", "i686-pc-windows-msvc", + "x86_64-unknown-linux-musl", "i686-unknown-linux-musl", + "x86_64-apple-darwin", "aarch64-apple-darwin", + }).Field().Require().Desc("Target OS and architecture").EndGroup(). + Select("source", "Build Source", []string{"docker", "action", "saas"}).Field().Require().Desc("docker=local, action=GitHub, saas=cloud").EndGroup(). + Confirm("lib", "Build as Library (DLL/SO)?", false).Field().Desc("Build as DLL/SO instead of executable").EndGroup(). + End(). + // Group 2: Network Configuration + NewGroup("network", "Network Configuration"). + WithDescription("C2 connection settings"). + Select("addresses", "C2 Addresses", []string{""}).Field().Require().Desc("Server addresses for callbacks").EndGroup(). + Input("proxy", "Proxy URL", "").Field().Desc("HTTP proxy URL (optional)").EndGroup(). + Confirm("proxy_use_env", "Use Environment Proxy?", false).Field().Desc("Use system proxy settings").EndGroup(). + End(). + // Group 3: Communication Parameters + NewGroup("timing", "Communication Parameters"). + WithDescription("Timing and retry settings"). + Input("cron", "Cron Expression", "*/5 * * * * * *").Field().Desc("Callback schedule (*/5 = every 5 sec)").EndGroup(). + Input("jitter", "Jitter (0.0-1.0)", "0.2").Field().WithValidate(ValidateFloat(0, 1)).Desc("Random delay factor (0.0-1.0)").EndGroup(). + Number("init_retry", "Initial Retry Count", 3).Field().WithValidate(ValidateRange(0, 100)).Desc("Retry count on initial connection").EndGroup(). + Number("server_retry", "Server Retry Count", 3).Field().WithValidate(ValidateRange(0, 100)).Desc("Retry count per server address").EndGroup(). + Number("global_retry", "Global Retry Count", 3).Field().WithValidate(ValidateRange(0, 100)).Desc("Total retry count before exit").EndGroup(). + End(). + // Group 4: Encryption Configuration + NewGroup("crypto", "Encryption Configuration"). + WithDescription("Traffic encryption settings"). + Select("encryption", "Encryption Type", []string{"", "aes", "xor"}).Field().Desc("Traffic encryption method").EndGroup(). + Input("key", "Encryption Key (empty for auto)", "").Field().Desc("Custom key or empty for auto-generate").EndGroup(). + Confirm("secure", "Enable Secure Mode?", false).Field().Desc("Enhanced security features").EndGroup(). + End(). + // Group 5: Module Selection + NewGroup("modules", "Module Selection"). + WithDescription("Select modules to include"). + MultiSelect("modules", "Modules", []string{ + "nano", "full", "base", "extend", + "fs_full", "sys_full", "execute_full", "net_full", + }).Field().Desc("Built-in module packages to include").EndGroup(). + MultiSelect("third_modules", "3rd Party Modules", []string{"rem", "curl"}).Field().Desc("Additional third-party modules").EndGroup(). + End(). + // Group 6: Protection Configuration + NewGroup("protection", "Protection Configuration"). + WithDescription("Anti-analysis and guardrails"). + Confirm("anti_sandbox", "Enable Anti-Sandbox?", false).Field().Desc("Detect and evade sandbox environments").EndGroup(). + Input("guardrail_ips", "Guardrail IPs (comma-separated)", "").Field().Desc("Only run if target has these IPs").EndGroup(). + Input("guardrail_users", "Guardrail Usernames", "").Field().Desc("Only run for these usernames").EndGroup(). + Input("guardrail_servers", "Guardrail Server Names", "").Field().Desc("Only run on these server names").EndGroup(). + Input("guardrail_domains", "Guardrail Domains", "").Field().Desc("Only run in these domains").EndGroup(). + Confirm("ollvm", "Enable OLLVM Obfuscation?", false).Field().Desc("Apply OLLVM code obfuscation").EndGroup(). + End() +} + +// NewBuildPulseWizard creates a wizard for building pulse shellcode +func NewBuildPulseWizard() *Wizard { + return NewWizard("build_pulse", "Build Pulse"). + WithDescription("Build stage-0 shellcode"). + Select("target", "Build Target", []string{ + "x86_64-pc-windows-gnu", "i686-pc-windows-gnu", + "x86_64-pc-windows-msvc", "i686-pc-windows-msvc", + }).Field().Require().Desc("Target OS and architecture").End(). + Select("source", "Build Source", []string{"docker", "action", "saas"}).Field().Require().Desc("docker=local, action=GitHub, saas=cloud").End(). + Select("profile", "Profile Name", []string{""}).Field().Desc("Implant profile with pipeline settings").End(). + Select("address", "C2 Address", []string{""}).Field().Require().Desc("Server address for stage-1 download").End(). + Input("user_agent", "User-Agent", "").Field().Desc("Custom User-Agent header for HTTP requests").End(). + Select("beacon_artifact_id", "Beacon Artifact ID", []string{""}).Field().Desc("Pre-built beacon artifact to download").End(). + Input("path", "HTTP Path", "/pulse").Field().Desc("HTTP path for stage-1 download").End() +} + +// NewBuildPreludeWizard creates a wizard for building multi-stage payload +func NewBuildPreludeWizard() *Wizard { + return NewWizard("build_prelude", "Build Prelude"). + WithDescription("Build multi-stage payload"). + Select("target", "Build Target", []string{ + "x86_64-pc-windows-gnu", "i686-pc-windows-gnu", + "x86_64-pc-windows-msvc", "i686-pc-windows-msvc", + "x86_64-unknown-linux-musl", "i686-unknown-linux-musl", + "x86_64-apple-darwin", "aarch64-apple-darwin", + }).Field().Require().Desc("Target OS and architecture").End(). + FilePath("autorun", "Autorun ZIP File").Field().Desc("ZIP file with autorun scripts").End(). + Select("profile", "Profile Name", []string{""}).Field().Desc("Implant profile with pipeline settings").End(). + Select("source", "Build Source", []string{"docker", "action", "saas"}).Field().Require().Desc("docker=local, action=GitHub, saas=cloud").End() +} + +// NewBuildModuleWizard creates a wizard for building custom module DLL +func NewBuildModuleWizard() *Wizard { + return NewWizard("build_module", "Build Module"). + WithDescription("Build custom module DLL"). + Select("target", "Build Target", []string{ + "x86_64-pc-windows-gnu", "i686-pc-windows-gnu", + "x86_64-pc-windows-msvc", "i686-pc-windows-msvc", + }).Field().Require().Desc("Target OS and architecture").End(). + MultiSelect("modules", "Modules", []string{ + "nano", "full", "base", "extend", + "fs_full", "sys_full", "execute_full", "net_full", + }).Field().Desc("Built-in module packages to include").End(). + MultiSelect("third_modules", "3rd Party Modules", []string{"rem", "curl"}).Field().Desc("Additional third-party modules").End(). + Select("profile", "Profile Name", []string{""}).Field().Desc("Implant profile with pipeline settings").End(). + Select("source", "Build Source", []string{"docker", "action", "saas"}).Field().Require().Desc("docker=local, action=GitHub, saas=cloud").End() +} + +// NewBindPipelineWizard creates a wizard for bind pipeline setup +func NewBindPipelineWizard() *Wizard { + return NewWizard("bind_pipeline", "Bind Pipeline Setup"). + WithDescription("Configure a bind pipeline"). + Select("listener_id", "Listener ID", []string{""}).Field().Require().Desc("Parent listener for bind connections").End() +} + +// NewRemPipelineWizard creates a wizard for REM pipeline setup +func NewRemPipelineWizard() *Wizard { + return NewWizard("rem_pipeline", "REM Pipeline Setup"). + WithDescription("Configure a REM pipeline"). + Input("name", "Pipeline Name", "").Field().Desc("Unique identifier for this pipeline").End(). + Select("listener_id", "Listener ID", []string{""}).Field().Require().Desc("Parent listener to attach").End(). + Input("console", "Console URL (tcp://host:port)", "tcp://0.0.0.0:19966").Field().Desc("Remote console URL for REM module").End(). + Confirm("secure", "Enable Secure Mode?", false).Field().Desc("Enable encryption for REM traffic").End() +} + +// NewCertGenerateWizard creates a wizard for certificate generation +func NewCertGenerateWizard() *Wizard { + return NewWizard("cert_generate", "Generate Certificate"). + WithDescription("Generate a self-signed certificate"). + Input("cn", "Common Name (CN)", "").Field().Require().Desc("Domain name for the certificate").End(). + Input("o", "Organization (O)", "").Field().Desc("Organization name (optional)").End(). + Input("c", "Country (C)", "").Field().Desc("Two-letter country code (e.g., US, CN)").End(). + Input("l", "Locality (L)", "").Field().Desc("City or locality name").End(). + Input("ou", "Organizational Unit (OU)", "").Field().Desc("Department or division name").End(). + Input("st", "State/Province (ST)", "").Field().Desc("State or province name").End(). + Number("validity", "Validity (Days)", 365).Field().WithValidate(ValidateRange(1, 3650)).Desc("Certificate validity period in days").End() +} + +// NewCertImportWizard creates a wizard for certificate import +func NewCertImportWizard() *Wizard { + return NewWizard("cert_import", "Import Certificate"). + WithDescription("Import an existing certificate"). + FilePath("cert", "Certificate File").Field().Require().Desc("Path to certificate file (.crt, .pem)").End(). + FilePath("key", "Private Key File").Field().Require().Desc("Path to private key file (.key, .pem)").End(). + FilePath("ca_cert", "CA Certificate (optional)").Field().Desc("Path to CA certificate if needed").End() +} + +// NewGithubConfigWizard creates a wizard for GitHub Actions configuration +func NewGithubConfigWizard() *Wizard { + return NewWizard("github_config", "GitHub Configuration"). + WithDescription("Configure GitHub Actions build"). + Input("owner", "GitHub Owner/Org", "").Field().Require().Desc("GitHub username or organization").End(). + Input("repo", "Repository Name", "").Field().Require().Desc("Repository name for Actions builds").End(). + Input("token", "GitHub Token", "").Field().Require().Desc("Personal access token with repo scope").End(). + Input("workflow_file", "Workflow File", "").Field().Desc("Custom workflow filename (optional)").End() +} + +// NewNotifyConfigWizard creates a wizard for notification configuration +func NewNotifyConfigWizard() *Wizard { + return NewWizard("notify_config", "Notification Configuration"). + WithDescription("Configure notification channels"). + // Telegram + Confirm("telegram_enable", "Enable Telegram?", false).Field().Desc("Enable Telegram bot notifications").End(). + Input("telegram_token", "Telegram Bot Token", "").Field().Desc("Bot token from @BotFather").End(). + Input("telegram_chat_id", "Telegram Chat ID", "").Field().Desc("Chat/Channel ID to send messages").End(). + // DingTalk + Confirm("dingtalk_enable", "Enable DingTalk?", false).Field().Desc("Enable DingTalk robot notifications").End(). + Input("dingtalk_token", "DingTalk Token", "").Field().Desc("Robot access token").End(). + Input("dingtalk_secret", "DingTalk Secret", "").Field().Desc("Robot signing secret").End(). + // Lark + Confirm("lark_enable", "Enable Lark?", false).Field().Desc("Enable Lark/Feishu notifications").End(). + Input("lark_webhook", "Lark Webhook URL", "").Field().Desc("Webhook URL from Lark bot").End(). + // ServerChan + Confirm("serverchan_enable", "Enable ServerChan?", false).Field().Desc("Enable ServerChan push notifications").End(). + Input("serverchan_url", "ServerChan URL", "").Field().Desc("ServerChan send key URL").End(). + // PushPlus + Confirm("pushplus_enable", "Enable PushPlus?", false).Field().Desc("Enable PushPlus notifications").End(). + Input("pushplus_token", "PushPlus Token", "").Field().Desc("PushPlus user token").End(). + Input("pushplus_topic", "PushPlus Topic", "").Field().Desc("Topic name for group messaging").End() +} + +// NewInfrastructureSetupWizard creates a wizard for one-stop infrastructure setup +func NewInfrastructureSetupWizard() *Wizard { + return NewWizard("infrastructure_setup", "Infrastructure Setup"). + WithDescription("One-stop C2 infrastructure setup"). + // Listener + Input("listener_name", "Listener Name", "").Field().Require().Desc("Unique name for the listener").End(). + Input("listener_host", "Listener Host", "0.0.0.0").Field().WithValidate(ValidateHost()).Desc("IP to bind, 0.0.0.0 = all interfaces").End(). + Select("listener_protocol", "Protocol", []string{"tcp", "http", "https"}).Field().Require().Desc("Communication protocol type").End(). + Number("listener_port", "Listener Port", 443).Field().WithValidate(ValidatePort()).Desc("Port number (443=HTTPS, 80=HTTP)").End(). + Confirm("listener_tls", "Enable TLS?", true).Field().Desc("Enable TLS encryption").End(). + // Pipeline + Select("pipeline_type", "Pipeline Type", []string{"tcp", "http"}).Field().Require().Desc("Pipeline protocol type").End(). + Input("pipeline_name", "Pipeline Name", "").Field().Desc("Unique name for the pipeline").End(). + Input("pipeline_host", "Pipeline Host", "0.0.0.0").Field().WithValidate(ValidateHost()).Desc("IP to bind for implant connections").End(). + Number("pipeline_port", "Pipeline Port", 5001).Field().WithValidate(ValidatePort()).Desc("Port for implant callbacks").End(). + Confirm("pipeline_tls", "Enable Pipeline TLS?", false).Field().Desc("Enable TLS on pipeline connection").End(). + // Profile + Input("profile_name", "Profile Name", "").Field().Require().Desc("Unique name for the profile").End(). + Select("implant_type", "Implant Type", []string{"beacon", "bind", "prelude"}).Field().Require().Desc("beacon=persistent, bind=reverse, prelude=staged").End(). + MultiSelect("modules", "Modules", []string{ + "base", "sys_full", "execute_full", "net_full", + }).Field().Desc("Select modules to include").End() +} + +// GetTemplate returns a wizard template by name +func GetTemplate(name string) (*Wizard, bool) { + templatesMu.RLock() + fn, ok := Templates[name] + templatesMu.RUnlock() + if ok { + return fn(), true + } + return nil, false +} + +// ListTemplates returns all available template names +func ListTemplates() []string { + templatesMu.RLock() + names := make([]string, 0, len(Templates)) + for name := range Templates { + names = append(names, name) + } + templatesMu.RUnlock() + sort.Strings(names) + return names +} + +// RegisterTemplate registers a new wizard template +func RegisterTemplate(name string, factory func() *Wizard) { + templatesMu.Lock() + Templates[name] = factory + templatesMu.Unlock() +} diff --git a/client/wizard/wizard.go b/client/wizard/wizard.go new file mode 100644 index 00000000..febfb4e1 --- /dev/null +++ b/client/wizard/wizard.go @@ -0,0 +1,490 @@ +package wizard + +// FieldType represents the type of a wizard field +type FieldType int + +const ( + FieldInput FieldType = iota + FieldText + FieldSelect + FieldMultiSelect + FieldConfirm + FieldNumber + FieldFilePath +) + +// WizardField represents a single field in the wizard +type WizardField struct { + Name string + Title string + Description string + Type FieldType + Default interface{} + Options []string + Required bool + Validate func(string) error + + // OptionsProvider is called to dynamically populate Options before running + // The ctx parameter is typically *core.Console for accessing RPC + OptionsProvider func(ctx interface{}) []string + + // parent is a reference to the wizard for chaining + parent *Wizard + // groupParent is a reference to the group for chaining (if field belongs to a group) + groupParent *WizardGroup +} + +// SetRequired marks this field as required and returns the wizard for chaining +func (f *WizardField) SetRequired() *Wizard { + f.Required = true + return f.parent +} + +// Require marks this field as required and returns the field for chaining +func (f *WizardField) Require() *WizardField { + f.Required = true + return f +} + +// SetValidate sets a validation function and returns the wizard for chaining +func (f *WizardField) SetValidate(fn func(string) error) *Wizard { + f.Validate = fn + return f.parent +} + +// Validate sets a validation function and returns the field for chaining +func (f *WizardField) WithValidate(fn func(string) error) *WizardField { + f.Validate = fn + return f +} + +// SetRequiredWithValidate marks field as required and sets validation +func (f *WizardField) SetRequiredWithValidate(fn func(string) error) *Wizard { + f.Required = true + f.Validate = fn + return f.parent +} + +// SetDescription sets the field description and returns the wizard for chaining +func (f *WizardField) SetDescription(desc string) *Wizard { + f.Description = desc + return f.parent +} + +// Desc sets the field description and returns the field for chaining +func (f *WizardField) Desc(desc string) *WizardField { + f.Description = desc + return f +} + +// SetOptionsProvider sets a function to dynamically populate options +func (f *WizardField) SetOptionsProvider(provider func(ctx interface{}) []string) *Wizard { + f.OptionsProvider = provider + return f.parent +} + +// End returns the parent wizard for chaining after field configuration +func (f *WizardField) End() *Wizard { + return f.parent +} + +// EndGroup returns the parent group for chaining after field configuration +// Use this when the field was added to a group +func (f *WizardField) EndGroup() *WizardGroup { + return f.groupParent +} + +// WizardGroup represents a logical group of fields +type WizardGroup struct { + Name string // Group identifier (e.g., "basic", "network") + Title string // Display title (e.g., "Basic Settings") + Description string // Group description + Fields []*WizardField // Fields in this group + Optional bool // If true, this group can be skipped (collapsed by default) + Expanded bool // If true and Optional, show fields; otherwise collapsed + parent *Wizard // Reference to parent wizard +} + +// Wizard is the main wizard structure +type Wizard struct { + ID string + Title string + Description string + Fields []*WizardField // Flat list (for backward compatibility) + Groups []*WizardGroup // Grouped fields for pagination +} + +// NewWizard creates a new wizard instance +func NewWizard(id, title string) *Wizard { + return &Wizard{ + ID: id, + Title: title, + Fields: make([]*WizardField, 0), + Groups: make([]*WizardGroup, 0), + } +} + +// WithDescription sets the wizard description +func (w *Wizard) WithDescription(desc string) *Wizard { + w.Description = desc + return w +} + +// AddField adds a field to the wizard +func (w *Wizard) AddField(field *WizardField) *Wizard { + field.parent = w + w.Fields = append(w.Fields, field) + return w +} + +// Field returns the last added field for configuration chaining +func (w *Wizard) Field() *WizardField { + if len(w.Fields) == 0 { + return nil + } + return w.Fields[len(w.Fields)-1] +} + +// IsGrouped returns true if the wizard uses grouped fields +func (w *Wizard) IsGrouped() bool { + return len(w.Groups) > 0 +} + +// NewGroup creates a new group and adds it to the wizard +func (w *Wizard) NewGroup(name, title string) *WizardGroup { + group := &WizardGroup{ + Name: name, + Title: title, + Fields: make([]*WizardField, 0), + parent: w, + } + w.Groups = append(w.Groups, group) + return group +} + +// Group returns a group by name +func (w *Wizard) Group(name string) *WizardGroup { + for _, g := range w.Groups { + if g.Name == name { + return g + } + } + return nil +} + +// WithDescription sets the group description and returns the group for chaining +func (g *WizardGroup) WithDescription(desc string) *WizardGroup { + g.Description = desc + return g +} + +// AsOptional marks this group as optional (collapsed by default) +func (g *WizardGroup) AsOptional() *WizardGroup { + g.Optional = true + g.Expanded = false + return g +} + +// SetExpanded sets the expanded state for optional groups +func (g *WizardGroup) SetExpanded(expanded bool) *WizardGroup { + g.Expanded = expanded + return g +} + +// End returns the parent wizard for switching to another group +func (g *WizardGroup) End() *Wizard { + return g.parent +} + +// AddField adds a field to this group (and also to wizard's flat list) +func (g *WizardGroup) AddField(field *WizardField) *WizardGroup { + field.parent = g.parent + field.groupParent = g + g.Fields = append(g.Fields, field) + g.parent.Fields = append(g.parent.Fields, field) + return g +} + +// Field returns the last added field in this group for configuration chaining +func (g *WizardGroup) Field() *WizardField { + if len(g.Fields) == 0 { + return nil + } + return g.Fields[len(g.Fields)-1] +} + +// Input adds an input field to the group +func (g *WizardGroup) Input(name, title string, defaultVal string) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldInput, + Default: defaultVal, + }) +} + +// InputWithDesc adds an input field with description to the group +func (g *WizardGroup) InputWithDesc(name, title, desc string, defaultVal string) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Description: desc, + Type: FieldInput, + Default: defaultVal, + }) +} + +// Text adds a multi-line text field to the group +func (g *WizardGroup) Text(name, title string, defaultVal string) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldText, + Default: defaultVal, + }) +} + +// Select adds a select field to the group +func (g *WizardGroup) Select(name, title string, options []string) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldSelect, + Options: options, + }) +} + +// SelectWithDefault adds a select field with default value to the group +func (g *WizardGroup) SelectWithDefault(name, title string, options []string, defaultVal string) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldSelect, + Options: options, + Default: defaultVal, + }) +} + +// MultiSelect adds a multi-select field to the group +func (g *WizardGroup) MultiSelect(name, title string, options []string) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldMultiSelect, + Options: options, + }) +} + +// MultiSelectWithDefault adds a multi-select field with default values to the group +func (g *WizardGroup) MultiSelectWithDefault(name, title string, options []string, defaults []string) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldMultiSelect, + Options: options, + Default: defaults, + }) +} + +// Confirm adds a confirm field to the group +func (g *WizardGroup) Confirm(name, title string, defaultVal bool) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldConfirm, + Default: defaultVal, + }) +} + +// Number adds a number input field to the group +func (g *WizardGroup) Number(name, title string, defaultVal int) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldNumber, + Default: defaultVal, + }) +} + +// FilePath adds a file path picker field to the group +func (g *WizardGroup) FilePath(name, title string) *WizardGroup { + return g.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldFilePath, + }) +} + +// GetField returns a field by name +func (w *Wizard) GetField(name string) *WizardField { + for _, f := range w.Fields { + if f.Name == name { + return f + } + } + return nil +} + +// Input adds an input field +func (w *Wizard) Input(name, title string, defaultVal string) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldInput, + Default: defaultVal, + }) +} + +// InputWithDesc adds an input field with description +func (w *Wizard) InputWithDesc(name, title, desc string, defaultVal string) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Description: desc, + Type: FieldInput, + Default: defaultVal, + }) +} + +// Text adds a multi-line text field +func (w *Wizard) Text(name, title string, defaultVal string) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldText, + Default: defaultVal, + }) +} + +// Select adds a select field +func (w *Wizard) Select(name, title string, options []string) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldSelect, + Options: options, + }) +} + +// SelectWithDefault adds a select field with default value +func (w *Wizard) SelectWithDefault(name, title string, options []string, defaultVal string) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldSelect, + Options: options, + Default: defaultVal, + }) +} + +// MultiSelect adds a multi-select field +func (w *Wizard) MultiSelect(name, title string, options []string) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldMultiSelect, + Options: options, + }) +} + +// MultiSelectWithDefault adds a multi-select field with default values +func (w *Wizard) MultiSelectWithDefault(name, title string, options []string, defaults []string) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldMultiSelect, + Options: options, + Default: defaults, + }) +} + +// Confirm adds a confirm field +func (w *Wizard) Confirm(name, title string, defaultVal bool) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldConfirm, + Default: defaultVal, + }) +} + +// Number adds a number input field +func (w *Wizard) Number(name, title string, defaultVal int) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldNumber, + Default: defaultVal, + }) +} + +// FilePath adds a file path picker field +func (w *Wizard) FilePath(name, title string) *Wizard { + return w.AddField(&WizardField{ + Name: name, + Title: title, + Type: FieldFilePath, + }) +} + +// Clone creates a copy of the wizard +func (w *Wizard) Clone() *Wizard { + clone := &Wizard{ + ID: w.ID, + Title: w.Title, + Description: w.Description, + Fields: make([]*WizardField, 0, len(w.Fields)), + Groups: make([]*WizardGroup, 0, len(w.Groups)), + } + + // Build a map from original field to cloned field for group referencing + fieldMap := make(map[*WizardField]*WizardField) + + // Clone all fields + for _, f := range w.Fields { + fieldCopy := *f + fieldCopy.parent = clone + if f.Options != nil { + fieldCopy.Options = append([]string(nil), f.Options...) + } + if defaults, ok := f.Default.([]string); ok { + fieldCopy.Default = append([]string(nil), defaults...) + } + clone.Fields = append(clone.Fields, &fieldCopy) + fieldMap[f] = &fieldCopy + } + + // Clone groups and reference cloned fields + for _, g := range w.Groups { + groupCopy := &WizardGroup{ + Name: g.Name, + Title: g.Title, + Description: g.Description, + Fields: make([]*WizardField, 0, len(g.Fields)), + Optional: g.Optional, + Expanded: g.Expanded, + parent: clone, + } + for _, f := range g.Fields { + if clonedField, ok := fieldMap[f]; ok { + groupCopy.Fields = append(groupCopy.Fields, clonedField) + } + } + clone.Groups = append(clone.Groups, groupCopy) + } + + return clone +} + +// PrepareOptions calls OptionsProvider for all fields that have one, +// populating their Options dynamically. The ctx parameter is passed to providers. +func (w *Wizard) PrepareOptions(ctx interface{}) { + for _, f := range w.Fields { + if f.OptionsProvider != nil { + opts := f.OptionsProvider(ctx) + if len(opts) > 0 { + f.Options = opts + } + } + } +} diff --git a/client/wizard/wizard_test.go b/client/wizard/wizard_test.go new file mode 100644 index 00000000..4d68b6b2 --- /dev/null +++ b/client/wizard/wizard_test.go @@ -0,0 +1,176 @@ +package wizard + +import ( + "testing" +) + +func TestWizardBuilder(t *testing.T) { + wiz := NewWizard("test_wizard", "Test Wizard") + wiz.WithDescription("Test description"). + Input("name", "Enter name", "default_name"). + Select("option", "Select option", []string{"a", "b", "c"}). + Number("count", "Enter count", 10). + Confirm("proceed", "Proceed?", true) + + if wiz.ID != "test_wizard" { + t.Errorf("Expected ID 'test_wizard', got '%s'", wiz.ID) + } + + if wiz.Title != "Test Wizard" { + t.Errorf("Expected Title 'Test Wizard', got '%s'", wiz.Title) + } + + if wiz.Description != "Test description" { + t.Errorf("Expected Description 'Test description', got '%s'", wiz.Description) + } + + if len(wiz.Fields) != 4 { + t.Errorf("Expected 4 fields, got %d", len(wiz.Fields)) + } + + // Check field types + expectedTypes := []FieldType{FieldInput, FieldSelect, FieldNumber, FieldConfirm} + for i, f := range wiz.Fields { + if f.Type != expectedTypes[i] { + t.Errorf("Field %d: Expected type %d, got %d", i, expectedTypes[i], f.Type) + } + } +} + +func TestWizardResult(t *testing.T) { + result := NewWizardResult("test") + + // Test string value + strVal := "test_string" + result.Values["str"] = &strVal + if result.GetString("str") != "test_string" { + t.Errorf("GetString failed") + } + + // Test bool value + boolVal := true + result.Values["bool"] = &boolVal + if !result.GetBool("bool") { + t.Errorf("GetBool failed") + } + + // Test int value (from string) + intStrVal := "42" + result.Values["int"] = &intStrVal + if result.GetInt("int") != 42 { + t.Errorf("GetInt failed, got %d", result.GetInt("int")) + } +} + +func TestWizardTemplates(t *testing.T) { + templates := ListTemplates() + if len(templates) == 0 { + t.Error("Expected some templates, got none") + } + + // Test getting a known template + wiz, ok := GetTemplate("listener_setup") + if !ok { + t.Error("Expected to find 'listener_setup' template") + } + if wiz == nil { + t.Error("Template wizard is nil") + } +} + +func TestAllTemplatesRegistered(t *testing.T) { + expectedTemplates := []string{ + // Existing + "listener_setup", + "tcp_pipeline", + "http_pipeline", + "profile_create", + // Build + "build_beacon", + "build_pulse", + "build_prelude", + "build_module", + // Pipeline + "bind_pipeline", + "rem_pipeline", + // Certificate + "cert_generate", + "cert_import", + // Config + "github_config", + "notify_config", + // Composite + "infrastructure_setup", + } + + templates := ListTemplates() + if len(templates) != len(expectedTemplates) { + t.Errorf("Expected %d templates, got %d", len(expectedTemplates), len(templates)) + } + + for _, name := range expectedTemplates { + wiz, ok := GetTemplate(name) + if !ok { + t.Errorf("Template '%s' not found", name) + continue + } + if wiz == nil { + t.Errorf("Template '%s' is nil", name) + continue + } + if wiz.ID != name { + t.Errorf("Template '%s' has wrong ID: %s", name, wiz.ID) + } + if wiz.Title == "" { + t.Errorf("Template '%s' has empty title", name) + } + if len(wiz.Fields) == 0 { + t.Errorf("Template '%s' has no fields", name) + } + } +} + +func TestWizardClone(t *testing.T) { + wiz := NewWizard("original", "Original"). + Input("field1", "Field 1", ""). + Select("field2", "Field 2", []string{"a", "b"}) + + clone := wiz.Clone() + + if clone.ID != wiz.ID { + t.Error("Clone ID mismatch") + } + + if len(clone.Fields) != len(wiz.Fields) { + t.Error("Clone fields count mismatch") + } + + // Modify clone and ensure original is unchanged + clone.ID = "modified" + if wiz.ID == "modified" { + t.Error("Modifying clone affected original") + } + + // Verify parent rebinding: clone's fields should point to clone, not original + for i, f := range clone.Fields { + if f.parent != clone { + t.Errorf("Clone field %d has wrong parent (expected clone, got original)", i) + } + } + + // Verify original's fields still point to original + for i, f := range wiz.Fields { + if f.parent != wiz { + t.Errorf("Original field %d has wrong parent after clone", i) + } + } + + // Test that chaining on clone affects clone, not original + clone.Field().SetRequired() + if wiz.Fields[len(wiz.Fields)-1].Required { + t.Error("SetRequired on clone affected original") + } + if !clone.Fields[len(clone.Fields)-1].Required { + t.Error("SetRequired on clone did not set Required") + } +} diff --git a/external/console/ai_complete.go b/external/console/ai_complete.go new file mode 100644 index 00000000..f2e4284d --- /dev/null +++ b/external/console/ai_complete.go @@ -0,0 +1,8 @@ +package console + +// Suggestion represents a single AI-generated command suggestion. +// Kept for backwards compatibility with client code. +type Suggestion struct { + Command string // The suggested command + Description string // Description of what the command does +} diff --git a/external/console/console.go b/external/console/console.go index 9fa3f45c..1702dec6 100644 --- a/external/console/console.go +++ b/external/console/console.go @@ -255,3 +255,15 @@ func (c *Console) activeMenu() *Menu { // Else return the default menu. return c.menus[""] } + +// AICommandGenerator is the callback function type for AI-powered command generation. +// It receives the current input line (natural language) and command history, +// and returns a generated command string. +type AICommandGenerator func(line string, history []string) (string, error) + +// SetAICommandGenerator sets the AI command generation callback function. +// When Alt+A is pressed, this callback will be invoked to convert +// natural language input to a command. +func (c *Console) SetAICommandGenerator(fn AICommandGenerator) { + c.shell.AIGenerateCommand = fn +} diff --git a/external/console/go.mod b/external/console/go.mod index ec25231d..f654a17e 100644 --- a/external/console/go.mod +++ b/external/console/go.mod @@ -21,3 +21,5 @@ require ( gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) + +replace github.com/reeflective/readline => ../readline diff --git a/external/readline/ai_prediction_test.go b/external/readline/ai_prediction_test.go new file mode 100644 index 00000000..866fb1e1 --- /dev/null +++ b/external/readline/ai_prediction_test.go @@ -0,0 +1,74 @@ +package readline + +import "testing" + +func TestAIPredictionInsertText(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + currentLine string + prediction string + want string + }{ + { + name: "ArgumentBoundary_PreservesNoExtraSpace", + currentLine: "build beacon --os ", + prediction: "windows", + want: "windows", + }, + { + name: "ArgumentBoundary_TrimsLeadingWhitespace", + currentLine: "build beacon --os ", + prediction: " \twindows", + want: "windows", + }, + { + name: "MidToken_CompletesCommandName", + currentLine: "wiz", + prediction: "wizard", + want: "ard", + }, + { + name: "MidToken_CompletesArgumentValue", + currentLine: "build beacon --os w", + prediction: "windows", + want: "indows", + }, + { + name: "MidToken_ExactMatchYieldsEmpty", + currentLine: "wizard", + prediction: "wizard", + want: "", + }, + { + name: "NextToken_WhenPredictionDoesNotMatchTokenPrefix", + currentLine: "wizard", + prediction: "--help", + want: " --help", + }, + { + name: "TabBoundary_TreatedAsWhitespace", + currentLine: "connect\t", + prediction: "127.0.0.1", + want: "127.0.0.1", + }, + { + name: "EmptyLine_NoLeadingSpace", + currentLine: "", + prediction: "wizard", + want: "wizard", + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := aiPredictionInsertText(tt.currentLine, tt.prediction); got != tt.want { + t.Fatalf("aiPredictionInsertText(%q, %q) = %q, want %q", tt.currentLine, tt.prediction, got, tt.want) + } + }) + } +} diff --git a/external/readline/completion.go b/external/readline/completion.go index 67bb2536..605f00e4 100644 --- a/external/readline/completion.go +++ b/external/readline/completion.go @@ -2,6 +2,8 @@ package readline import ( "fmt" + "strings" + "time" "github.com/reeflective/readline/internal/color" "github.com/reeflective/readline/internal/completion" @@ -31,22 +33,33 @@ func (rl *Shell) completionCommands() commands { // // Attempt completion on the current word. -// Currently identitical to menu-complete. func (rl *Shell) completeWord() { rl.History.SkipSave() + // If there's a local suggestion and we're not in completion menu, accept it + if rl.hasLocalSuggestion() && !rl.completer.IsActive() { + if rl.acceptLocalSuggestion() { + return + } + } + // This completion function should attempt to insert the first // valid completion found, without printing the actual list. if !rl.completer.IsActive() { rl.startMenuComplete(rl.commandCompletion) if rl.Config.GetBool("menu-complete-display-prefix") { + // Trigger local suggestion for next argument + rl.startDelayedLocalSuggestion() return } } rl.completer.Select(1, 0) rl.completer.SkipDisplay() + + // Trigger local suggestion for next argument (after we selected a candidate). + rl.startDelayedLocalSuggestion() } // List possible completions for the current word. @@ -258,3 +271,583 @@ func (rl *Shell) historyCompletion(forward, filterLine, substring bool) { } } } + +// +// AI Prediction (inline ghost text) ------------------------------------------------------- +// + +// startDelayedAIPrediction starts the AI prediction timer +func (rl *Shell) startDelayedAIPrediction() { + if rl.AIPredictNext == nil { + return + } + + // Snapshot line + history on the main loop goroutine. + line, _ := rl.completer.Line() + if line == nil { + return + } + currentLine := string(*line) + if currentLine == "" { + return + } + + var historyLines []string + histSrc := rl.History.Current() + if histSrc != nil { + histLen := histSrc.Len() + start := 0 + if histLen > 20 { + start = histLen - 20 + } + for i := start; i < histLen; i++ { + if cmd, err := histSrc.GetLine(i); err == nil && cmd != "" { + historyLines = append(historyLines, cmd) + } + } + } + + rl.aiPredictionMu.Lock() + defer rl.aiPredictionMu.Unlock() + + // Cancel any existing timer + if rl.aiPredictionTimer != nil { + rl.aiPredictionTimer.Stop() + rl.aiPredictionTimer = nil + } + + rl.aiPredictionSeq++ + seq := rl.aiPredictionSeq + + delay := 250 * time.Millisecond + if strings.HasSuffix(currentLine, " ") || strings.HasSuffix(currentLine, "\t") { + delay = 80 * time.Millisecond + } + + if rl.Config != nil { + if strings.HasSuffix(currentLine, " ") || strings.HasSuffix(currentLine, "\t") { + if ms := rl.Config.GetInt("ai-prediction-delay-space"); ms > 0 { + delay = time.Duration(ms) * time.Millisecond + } + } else { + if ms := rl.Config.GetInt("ai-prediction-delay"); ms > 0 { + delay = time.Duration(ms) * time.Millisecond + } + } + } + + rl.aiPredictionTimer = time.AfterFunc(delay, func() { + rl.triggerAIPrediction(seq, currentLine, historyLines) + }) +} + +// triggerAIPrediction triggers AI-powered prediction asynchronously +func (rl *Shell) triggerAIPrediction(seq uint64, currentLine string, historyLines []string) { + rl.aiPredictionMu.Lock() + if seq != rl.aiPredictionSeq { + rl.aiPredictionMu.Unlock() + return + } + rl.aiPredictionMu.Unlock() + + if currentLine == "" { + return + } + + // Drop stale predictions if the line has changed since scheduling. + line, _ := rl.completer.Line() + if line == nil || string(*line) != currentLine { + return + } + + // Call AI prediction + prediction, err := rl.AIPredictNext(currentLine, historyLines) + if err != nil || prediction == "" { + return + } + + // Drop stale predictions if the line changed while the request was running. + line, _ = rl.completer.Line() + if line == nil || string(*line) != currentLine { + return + } + + // Store prediction + prediction = strings.TrimLeft(prediction, " \t") + insertText := aiPredictionInsertText(currentLine, prediction) + if insertText == "" { + return + } + fullSuggestion := currentLine + insertText + rl.aiPredictionMu.Lock() + if seq != rl.aiPredictionSeq { + rl.aiPredictionMu.Unlock() + return + } + rl.aiPrediction = fullSuggestion + rl.aiPredictionLine = currentLine + rl.aiPredictionMu.Unlock() + + // Show prediction as inline ghost text (fish-style). + rl.SetAISuggestion(fullSuggestion) + + // If the main readline loop is currently blocked waiting for input, + // proactively refresh so the suggestion appears immediately. + // This avoids waiting for the next keypress. + if rl.Keys != nil && rl.Keys.IsWaiting() && !rl.Keys.IsReading() { + if rl.Config != nil && !rl.Config.GetBool("ai-prediction-auto-refresh") { + return + } + + rl.Display.RefreshLine() + } +} + +// acceptAIPrediction accepts the current AI prediction and inserts it +func (rl *Shell) acceptAIPrediction() bool { + rl.aiPredictionMu.Lock() + suggestion := rl.aiPrediction + rl.aiPredictionMu.Unlock() + + if suggestion == "" { + return false + } + + line, _ := rl.completer.Line() + if line == nil { + return false + } + currentLine := string(*line) + if currentLine == "" || !strings.HasPrefix(suggestion, currentLine) || len(suggestion) <= len(currentLine) { + rl.ClearAIPrediction() + return false + } + + // We only render the ghost text at end-of-line; keep acceptance consistent. + if rl.cursor.Pos() != rl.line.Len() { + return false + } + + // Accept any virtually inserted completion candidate and exit completion mode. + completion.UpdateInserted(rl.completer) + + // Insert prediction + suffix := suggestion[len(currentLine):] + for _, r := range suffix { + rl.line.Insert(rl.cursor.Pos(), r) + rl.cursor.Inc() + } + + rl.ClearAIPrediction() + + // Refresh display + rl.Display.Refresh() + + return true +} + +// ClearAIPrediction clears any pending AI prediction +func (rl *Shell) ClearAIPrediction() { + rl.aiPredictionMu.Lock() + defer rl.aiPredictionMu.Unlock() + + // Invalidate any in-flight predictions/timers. + rl.aiPredictionSeq++ + + if rl.aiPredictionTimer != nil { + rl.aiPredictionTimer.Stop() + rl.aiPredictionTimer = nil + } + rl.aiPrediction = "" + rl.aiPredictionLine = "" + + // Clear inline suggestion + rl.Display.ClearAISuggestion() +} + +// GetAIPrediction returns the current AI prediction (for display purposes) +func (rl *Shell) GetAIPrediction() string { + // Only consider predictions when the cursor is at end-of-line, since the + // ghost text is rendered there. + if rl.cursor.Pos() != rl.line.Len() { + return "" + } + + rl.aiPredictionMu.Lock() + suggestion := rl.aiPrediction + rl.aiPredictionMu.Unlock() + if suggestion == "" { + return "" + } + + line, _ := rl.completer.Line() + if line == nil { + return "" + } + currentLine := string(*line) + if currentLine == "" || !strings.HasPrefix(suggestion, currentLine) || len(suggestion) <= len(currentLine) { + return "" + } + + return suggestion[len(currentLine):] +} + +// aiPredictionInsertText returns the suffix to insert/display for a given +// predicted next argument/value and the current input line. +// +// The predictor is instructed to return a single "next argument/value". In +// practice that can also mean a completion of the current token (e.g. when the +// user already typed a prefix). In that case we should only insert the +// remaining suffix rather than adding a new space-delimited token. +func aiPredictionInsertText(currentLine, prediction string) string { + prediction = strings.TrimLeft(prediction, " \t") + if prediction == "" { + return "" + } + if currentLine == "" { + return prediction + } + + // If we're already at an argument boundary, insert the prediction as-is. + lastChar := currentLine[len(currentLine)-1] + if lastChar == ' ' || lastChar == '\t' { + return prediction + } + + // Otherwise, try to interpret the prediction as a completion of the current token. + tokenStart := strings.LastIndexAny(currentLine, " \t") + 1 + if tokenStart < 0 || tokenStart > len(currentLine) { + tokenStart = 0 + } + + tokenPrefix := currentLine[tokenStart:] + if tokenPrefix != "" && strings.HasPrefix(prediction, tokenPrefix) { + return prediction[len(tokenPrefix):] + } + + // Fallback: treat as the next token and add a separating space. + return " " + prediction +} + +// applyPendingAICompletion is kept for compatibility with async completion hooks. +// It currently ensures stale AI predictions are discarded when the input line changes. +func (rl *Shell) applyPendingAICompletion(refresh bool) bool { + // Ensure inline suggestions don't leak when the line is empty. + line, _ := rl.completer.Line() + if line != nil && len(*line) == 0 { + rl.clearLocalSuggestion() + rl.ClearAIPrediction() + if refresh { + rl.Display.Refresh() + } + return true + } + + rl.aiPredictionMu.Lock() + suggestion := rl.aiPrediction + rl.aiPredictionMu.Unlock() + + if suggestion == "" { + return false + } + + line, _ = rl.completer.Line() + if line == nil { + return false + } + + currentLine := string(*line) + + // Clear when the suggestion no longer matches the current input (or when the + // user already typed it all), to avoid stale ghost text reappearing later. + if currentLine == "" || !strings.HasPrefix(suggestion, currentLine) || len(suggestion) <= len(currentLine) { + rl.ClearAIPrediction() + if refresh { + rl.Display.Refresh() + } + + return true + } + + return false +} + +// +// Local Suggestion (fast completion-based suggestions without AI) -------------------------- +// + +// startDelayedLocalSuggestion starts a timer to compute local suggestions after a short delay. +// This provides debouncing to avoid computing suggestions on every keystroke. +func (rl *Shell) startDelayedLocalSuggestion() { + line, _ := rl.completer.Line() + if line == nil || len(*line) == 0 { + rl.clearLocalSuggestion() + return + } + currentLine := string(*line) + + rl.localSuggestionMu.Lock() + defer rl.localSuggestionMu.Unlock() + + // Cancel any existing timer + if rl.localSuggestionTimer != nil { + rl.localSuggestionTimer.Stop() + rl.localSuggestionTimer = nil + } + + rl.localSuggestionSeq++ + seq := rl.localSuggestionSeq + + // Use a short delay (50ms) for debouncing + delay := 50 * time.Millisecond + if strings.HasSuffix(currentLine, " ") || strings.HasSuffix(currentLine, "\t") { + delay = 30 * time.Millisecond // Faster after space + } + + rl.localSuggestionTimer = time.AfterFunc(delay, func() { + rl.computeLocalSuggestion(seq, currentLine) + }) +} + +// computeLocalSuggestion computes and displays a local suggestion. +// Priority: completion candidates > history match +func (rl *Shell) computeLocalSuggestion(seq uint64, currentLine string) { + rl.localSuggestionMu.Lock() + if seq != rl.localSuggestionSeq { + rl.localSuggestionMu.Unlock() + return + } + rl.localSuggestionMu.Unlock() + + // Check if the line has changed since scheduling + line, _ := rl.completer.Line() + if line == nil || string(*line) != currentLine { + return + } + + var suggestion string + + // Priority 1: Try to get suggestion from completion system + suggestion = rl.getCompletionSuggestion(currentLine) + + // Priority 2: If no completion, try history match + if suggestion == "" { + suggestion = rl.getHistorySuggestion(currentLine) + } + + if suggestion == "" || suggestion == currentLine { + rl.clearLocalSuggestion() + if rl.Keys != nil && rl.Keys.IsWaiting() && !rl.Keys.IsReading() { + rl.Display.RefreshLine() + } + return + } + + // Store and display suggestion + rl.localSuggestionMu.Lock() + if seq != rl.localSuggestionSeq { + rl.localSuggestionMu.Unlock() + return + } + rl.localSuggestion = suggestion + rl.localSuggestionLine = currentLine + // Use the existing AI suggestion display mechanism + rl.SetAISuggestion(suggestion) + rl.localSuggestionMu.Unlock() + + // Refresh display if the main loop is waiting for input + if rl.Keys != nil && rl.Keys.IsWaiting() && !rl.Keys.IsReading() { + rl.Display.RefreshLine() + } +} + +// getCompletionSuggestion gets a suggestion from the completion system. +// Returns the first matching candidate or the common prefix of multiple candidates. +func (rl *Shell) getCompletionSuggestion(currentLine string) string { + if rl.Completer == nil { + return "" + } + + // Get completions for current line + line := []rune(currentLine) + cursor := len(line) + comps := rl.Completer(line, cursor) + + if len(comps.values) == 0 { + return "" + } + + // Get the current word prefix + prefix := "" + if comps.PREFIX != "" { + prefix = comps.PREFIX + } else { + // Calculate prefix from word boundary + pos := len(line) - 1 + for pos >= 0 { + c := line[pos] + if c == ' ' || c == '\t' { + break + } + pos-- + } + prefix = string(line[pos+1:]) + } + + // Filter candidates that match the prefix + var matchingValues []string + ignoreCase := rl.Config != nil && rl.Config.GetBool("completion-ignore-case") + for _, v := range comps.values { + value := v.Value + matchPrefix := prefix + if ignoreCase { + value = strings.ToLower(value) + matchPrefix = strings.ToLower(matchPrefix) + } + if strings.HasPrefix(value, matchPrefix) { + matchingValues = append(matchingValues, v.Value) + } + } + + if len(matchingValues) == 0 { + return "" + } + + prefixLen := len([]rune(prefix)) + if prefixLen > len(line) { + return "" + } + + // If only one candidate, return it + if len(matchingValues) == 1 { + // Build full line with completion + lineWithoutPrefix := string(line[:len(line)-prefixLen]) + return lineWithoutPrefix + matchingValues[0] + } + + // Multiple candidates: compute common prefix + commonPrefix := longestCommonPrefix(matchingValues) + if len([]rune(commonPrefix)) > prefixLen { + lineWithoutPrefix := string(line[:len(line)-prefixLen]) + return lineWithoutPrefix + commonPrefix + } + + return "" +} + +// longestCommonPrefix returns the longest common prefix of a slice of strings. +func longestCommonPrefix(strs []string) string { + if len(strs) == 0 { + return "" + } + prefix := []rune(strs[0]) + for _, s := range strs[1:] { + runes := []rune(s) + i := 0 + for i < len(prefix) && i < len(runes) && prefix[i] == runes[i] { + i++ + } + prefix = prefix[:i] + if len(prefix) == 0 { + break + } + } + return string(prefix) +} + +// getHistorySuggestion gets a suggestion from command history. +func (rl *Shell) getHistorySuggestion(currentLine string) string { + suggested := string(rl.History.Suggest(rl.line)) + if suggested == "" || suggested == currentLine { + return "" + } + if !strings.HasPrefix(suggested, currentLine) { + return "" + } + return suggested +} + +// hasLocalSuggestion returns true if there's a valid local suggestion for the current line. +func (rl *Shell) hasLocalSuggestion() bool { + rl.localSuggestionMu.Lock() + defer rl.localSuggestionMu.Unlock() + + if rl.localSuggestion == "" { + return false + } + + line, _ := rl.completer.Line() + if line == nil { + return false + } + currentLine := string(*line) + if currentLine == "" { + return false + } + + return strings.HasPrefix(rl.localSuggestion, currentLine) && + len(rl.localSuggestion) > len(currentLine) +} + +// acceptLocalSuggestion accepts the current local suggestion and inserts it. +func (rl *Shell) acceptLocalSuggestion() bool { + rl.localSuggestionMu.Lock() + suggestion := rl.localSuggestion + rl.localSuggestionMu.Unlock() + + if suggestion == "" { + return false + } + + line, _ := rl.completer.Line() + if line == nil { + return false + } + currentLine := string(*line) + if currentLine == "" { + rl.clearLocalSuggestion() + return false + } + + if !strings.HasPrefix(suggestion, currentLine) || len(suggestion) <= len(currentLine) { + rl.clearLocalSuggestion() + return false + } + + // Only accept when cursor is at end of line + if rl.cursor.Pos() != rl.line.Len() { + return false + } + + // Accept any virtually inserted completion candidate and exit completion mode. + completion.UpdateInserted(rl.completer) + + // Insert the suggestion suffix + suffix := suggestion[len(currentLine):] + for _, r := range suffix { + rl.line.Insert(rl.cursor.Pos(), r) + rl.cursor.Inc() + } + + rl.clearLocalSuggestion() + rl.Display.Refresh() + + return true +} + +// clearLocalSuggestion clears the current local suggestion. +func (rl *Shell) clearLocalSuggestion() { + rl.localSuggestionMu.Lock() + defer rl.localSuggestionMu.Unlock() + + if rl.localSuggestionTimer != nil { + rl.localSuggestionTimer.Stop() + rl.localSuggestionTimer = nil + } + rl.localSuggestionSeq++ + rl.localSuggestion = "" + rl.localSuggestionLine = "" + + // Clear the displayed suggestion + rl.Display.ClearAISuggestion() +} diff --git a/external/readline/emacs.go b/external/readline/emacs.go index 446d530c..d3868e29 100644 --- a/external/readline/emacs.go +++ b/external/readline/emacs.go @@ -165,6 +165,13 @@ func (rl *Shell) emacsEditingMode() { func (rl *Shell) forwardChar() { startPos := rl.cursor.Pos() + // At end of line: accept AI suggestion if available (fish-style) + if rl.cursor.Pos() >= rl.line.Len()-1 && rl.GetAIPrediction() != "" { + if rl.acceptAIPrediction() { + return + } + } + // Only exception where we actually don't forward a character. if rl.Config.GetBool("history-autosuggest") && rl.cursor.Pos() >= rl.line.Len()-1 { rl.autosuggestAccept() @@ -321,6 +328,8 @@ func (rl *Shell) clearDisplay() { func (rl *Shell) endOfFile() { switch rl.line.Len() { case 0: + rl.clearLocalSuggestion() + rl.ClearAIPrediction() rl.Display.AcceptLine() rl.History.Accept(false, false, io.EOF) default: @@ -450,6 +459,9 @@ func (rl *Shell) selfInsert() { rl.cursor.InsertAt(quoted...) rl.cursor.Move(-1 * len(quoted)) rl.cursor.Move(length) + + // Trigger local suggestion after character insertion + rl.startDelayedLocalSuggestion() } func (rl *Shell) bracketedPasteBegin() { @@ -1244,6 +1256,8 @@ func (rl *Shell) abort() { } // If no line was active, + rl.clearLocalSuggestion() + rl.ClearAIPrediction() rl.Display.AcceptLine() rl.History.Accept(false, false, ErrInterrupt) } @@ -1544,6 +1558,8 @@ func (rl *Shell) editAndExecuteCommand() { // Update our line and return it the caller. rl.line.Set(edited...) + rl.clearLocalSuggestion() + rl.ClearAIPrediction() rl.Display.AcceptLine() rl.History.Accept(false, false, nil) } diff --git a/external/readline/history.go b/external/readline/history.go index c6d6bbff..7722487f 100644 --- a/external/readline/history.go +++ b/external/readline/history.go @@ -629,6 +629,10 @@ func (rl *Shell) autosuggestDisable() { // func (rl *Shell) acceptLineWith(infer, hold bool) { + // Stop any pending inline suggestion timers to avoid background refresh after accept. + rl.clearLocalSuggestion() + rl.ClearAIPrediction() + // If we are currently using the incremental-search buffer, // we should cancel this mode so as to run the rest of this // function on (with) the input line itself, not the minibuffer. diff --git a/external/readline/inputrc/config.go b/external/readline/inputrc/config.go index 4872e858..507b7dd7 100644 --- a/external/readline/inputrc/config.go +++ b/external/readline/inputrc/config.go @@ -43,6 +43,19 @@ func NewDefaultConfig(opts ...ConfigOption) *Config { Binds: DefaultBinds(), Funcs: make(map[string]func(string, string) error), } + + // Application/library-specific options (not part of GNU readline/bash inputrc). + // These are added here so they can be configured at runtime (e.g. via `set`). + if _, ok := cfg.Vars["ai-prediction-delay"]; !ok { + cfg.Vars["ai-prediction-delay"] = 250 + } + if _, ok := cfg.Vars["ai-prediction-delay-space"]; !ok { + cfg.Vars["ai-prediction-delay-space"] = 80 + } + if _, ok := cfg.Vars["ai-prediction-auto-refresh"]; !ok { + cfg.Vars["ai-prediction-auto-refresh"] = true + } + for _, o := range opts { o(cfg) } diff --git a/external/readline/internal/completion/engine.go b/external/readline/internal/completion/engine.go index be271e3f..bcdc6f43 100644 --- a/external/readline/internal/completion/engine.go +++ b/external/readline/internal/completion/engine.go @@ -300,6 +300,11 @@ func (e *Engine) Line() (*core.Line, *core.Cursor) { return e.line, e.cursor } +// BaseLine returns the underlying, non-virtually-completed input line and cursor. +func (e *Engine) BaseLine() (*core.Line, *core.Cursor) { + return e.line, e.cursor +} + // Autocomplete generates the correct completions in autocomplete mode. // We don't do it when we are currently in the completion keymap, // since that means completions have already been computed. diff --git a/external/readline/internal/core/keys.go b/external/readline/internal/core/keys.go index 7fe3f13f..f62c11a7 100644 --- a/external/readline/internal/core/keys.go +++ b/external/readline/internal/core/keys.go @@ -41,16 +41,16 @@ type Keys struct { // WaitAvailableKeys waits until an input key is either read from standard input, // or directly returns if the key stack still/already has available keys. -func WaitAvailableKeys(keys *Keys, cfg *inputrc.Config) { +func WaitAvailableKeys(keys *Keys, cfg *inputrc.Config) bool { keys.cfg = cfg if len(keys.buf) > 0 && !keys.mustWait { - return + return false } // The macro engine might have fed some keys if len(keys.macroKeys) > 0 { - return + return false } keys.mutex.Lock() @@ -70,7 +70,7 @@ func WaitAvailableKeys(keys *Keys, cfg *inputrc.Config) { // send by ourselves, because we pause reading. keyBuf, err := keys.readInputFiltered() if err != nil && errors.Is(err, io.EOF) { - return + return false } if len(keyBuf) == 0 { @@ -94,7 +94,7 @@ func WaitAvailableKeys(keys *Keys, cfg *inputrc.Config) { keys.mutex.RUnlock() } - return + return false } } @@ -337,6 +337,14 @@ func (k *Keys) IsReading() bool { return k.reading } +// IsWaiting returns true if the key reader is currently blocked waiting for +// input (i.e. inside WaitAvailableKeys). +func (k *Keys) IsWaiting() bool { + k.mutex.RLock() + defer k.mutex.RUnlock() + return k.waiting +} + func (k *Keys) extractCursorPos(keys []byte) (cursor, remain []byte) { if !rxRcvCursorPos.Match(keys) { return cursor, keys diff --git a/external/readline/internal/core/keys_unix.go b/external/readline/internal/core/keys_unix.go index 901f1961..556fb35f 100644 --- a/external/readline/internal/core/keys_unix.go +++ b/external/readline/internal/core/keys_unix.go @@ -51,19 +51,19 @@ func (k *Keys) GetCursorPos() (x, y int) { return disable() } - // Attempt to locate cursor response in it. - match = rxRcvCursorPos.FindAllStringSubmatch(string(cursor), 1) - - // If there is something but not cursor answer, its user input. - if len(match) == 0 && len(cursor) > 0 { + // Strip any cursor response and push remaining bytes back as user input. + cursorResp, remain := k.extractCursorPos(cursor) + if len(remain) > 0 { k.mutex.RLock() - k.buf = append(k.buf, cursor...) + k.buf = append(k.buf, remain...) k.mutex.RUnlock() + } + if len(cursorResp) == 0 { continue } - // And if empty, then we should abort. + match = rxRcvCursorPos.FindAllStringSubmatch(string(cursorResp), 1) if len(match) == 0 { return disable() } diff --git a/external/readline/internal/display/engine.go b/external/readline/internal/display/engine.go index d88cd10d..64c13a67 100644 --- a/external/readline/internal/display/engine.go +++ b/external/readline/internal/display/engine.go @@ -2,6 +2,7 @@ package display import ( "fmt" + "strings" "github.com/reeflective/readline/inputrc" "github.com/reeflective/readline/internal/color" @@ -33,6 +34,9 @@ type Engine struct { compRows int primaryPrinted bool + // AI inline suggestion (fish-style) + aiSuggestion string + // UI components keys *core.Keys line *core.Line @@ -66,6 +70,21 @@ func Init(e *Engine, highlighter func([]rune) string) { e.highlighter = highlighter } +// SetAISuggestion sets the AI inline suggestion to display after the cursor. +func (e *Engine) SetAISuggestion(suggestion string) { + e.aiSuggestion = suggestion +} + +// ClearAISuggestion clears the AI inline suggestion. +func (e *Engine) ClearAISuggestion() { + e.aiSuggestion = "" +} + +// GetAISuggestion returns the current AI inline suggestion. +func (e *Engine) GetAISuggestion() string { + return e.aiSuggestion +} + // Refresh recomputes and redisplays the entire readline interface, except // the first lines of the primary prompt when the latter is a multiline one. func (e *Engine) Refresh() { @@ -98,6 +117,22 @@ func (e *Engine) Refresh() { fmt.Print(term.ShowCursor) } +// RefreshHelpers redraws the hint and completion sections without querying the +// terminal for a new cursor position. It relies on the last computed display +// coordinates, so it should only be used when the input line hasn't changed. +func (e *Engine) RefreshHelpers() { + fmt.Print(term.HideCursor) + + // Go to the last line of input (displayHelpers assumes that). + term.MoveCursorUp(e.cursorRow) + term.MoveCursorDown(e.lineRows) + + e.displayHelpers() + e.cursorHintToLineStart() + e.lineStartToCursorPos() + fmt.Print(term.ShowCursor) +} + // PrintPrimaryPrompt redraws the primary prompt. // There are relatively few cases where you want to use this. // It is currently only used when using clear-screen commands. @@ -222,16 +257,58 @@ func (e *Engine) computeCoordinates(suggested bool) { e.cursorCol, e.cursorRow = core.CoordinatesCursor(e.cursor, e.startCols) - // Get the number of rows used by the line, and the end line X pos. - if e.opts.GetBool("history-autosuggest") && suggested { - e.lineCol, e.lineRows = core.CoordinatesLine(&e.suggested, e.startCols) - } else { - e.lineCol, e.lineRows = core.CoordinatesLine(e.line, e.startCols) + // Get the number of rows used by the line (including inline suggestions), and the end line X pos. + lineForCoords := e.line + + cursorAtEOL := e.cursor != nil && e.line != nil && e.cursor.Pos() == e.line.Len() + currentLine := string(*e.line) + if suggested && cursorAtEOL && e.aiSuggestion != "" && strings.HasPrefix(e.aiSuggestion, currentLine) && len(e.aiSuggestion) > len(currentLine) { + aiLine := core.Line([]rune(e.aiSuggestion)) + lineForCoords = &aiLine + } else if suggested && e.opts.GetBool("history-autosuggest") { + lineForCoords = &e.suggested } + e.lineCol, e.lineRows = core.CoordinatesLine(lineForCoords, e.startCols) + e.primaryPrinted = false } +// computeCoordinatesNoCursorQuery recomputes display coordinates without querying the +// terminal for cursor position (and without reprinting the prompt). It relies on the +// last known prompt indentation and starting row, so it should only be used when the +// prompt hasn't changed and the cursor is still within the current input line. +func (e *Engine) computeCoordinatesNoCursorQuery(suggested bool) { + // Get the new input line and auto-suggested one. + e.line, e.cursor = e.completer.Line() + if e.completer.IsInserting() { + e.suggested = *e.line + } else { + e.suggested = e.histories.Suggest(e.line) + } + + // If we don't have valid indentation coordinates, fall back to 0. + if e.startCols < 0 { + e.startCols = 0 + } + + e.cursorCol, e.cursorRow = core.CoordinatesCursor(e.cursor, e.startCols) + + // Get the number of rows used by the line (including inline suggestions), and the end line X pos. + lineForCoords := e.line + + cursorAtEOL := e.cursor != nil && e.line != nil && e.cursor.Pos() == e.line.Len() + currentLine := string(*e.line) + if suggested && cursorAtEOL && e.aiSuggestion != "" && strings.HasPrefix(e.aiSuggestion, currentLine) && len(e.aiSuggestion) > len(currentLine) { + aiLine := core.Line([]rune(e.aiSuggestion)) + lineForCoords = &aiLine + } else if suggested && e.opts.GetBool("history-autosuggest") { + lineForCoords = &e.suggested + } + + e.lineCol, e.lineRows = core.CoordinatesLine(lineForCoords, e.startCols) +} + func (e *Engine) displayLine() { var line string @@ -251,8 +328,14 @@ func (e *Engine) displayLine() { // Apply visual selections highlighting if any line = e.highlightLine([]rune(line), *e.selection) - // Get the subset of the suggested line to print. - if len(e.suggested) > e.line.Len() && e.opts.GetBool("history-autosuggest") { + // AI inline suggestion (fish-style, higher priority) + currentLine := string(*e.line) + cursorAtEOL := e.cursor != nil && e.line != nil && e.cursor.Pos() == e.line.Len() + if cursorAtEOL && e.aiSuggestion != "" && strings.HasPrefix(e.aiSuggestion, currentLine) && len(e.aiSuggestion) > len(currentLine) { + suffix := e.aiSuggestion[len(currentLine):] + line += color.Dim + color.Fmt(color.Fg+"242") + suffix + color.Reset + } else if len(e.suggested) > e.line.Len() && e.opts.GetBool("history-autosuggest") { + // Get the subset of the suggested line to print (history autosuggest) line += color.Dim + color.Fmt(color.Fg+"242") + string(e.suggested[e.line.Len():]) + color.Reset } @@ -270,6 +353,29 @@ func (e *Engine) displayLine() { } } +// RefreshLine redraws the input line and helper sections without reprinting the prompt +// or querying the terminal for cursor position. This is useful when refreshing from +// a background goroutine (e.g. async AI suggestions) while the main loop is waiting +// for user input. +func (e *Engine) RefreshLine() { + fmt.Print(term.HideCursor) + + // Move to the input line start (just after the prompt) and clear everything below. + e.CursorToLineStart() + fmt.Print(term.ClearScreenBelow) + + // Recompute coordinates based on the last known prompt offset. + e.computeCoordinatesNoCursorQuery(true) + + // Redraw line and helpers, then restore cursor. + e.displayLine() + e.displayMultilinePrompts() + e.displayHelpers() + e.cursorHintToLineStart() + e.lineStartToCursorPos() + fmt.Print(term.ShowCursor) +} + func (e *Engine) displayMultilinePrompts() { // If we have more than one line, write the columns. if e.line.Lines() > 1 { diff --git a/external/readline/readline.go b/external/readline/readline.go index d510a11b..422c294b 100644 --- a/external/readline/readline.go +++ b/external/readline/readline.go @@ -79,6 +79,9 @@ func (rl *Shell) Readline() (string, error) { // been consumed but did not match any command. core.FlushUsed(rl.Keys) + // Apply any async AI completion/prediction results before redisplay. + rl.applyPendingAICompletion(false) + // Since we always update helpers after being asked to read // for user input again, we do it before actually reading it. rl.Display.Refresh() @@ -127,6 +130,9 @@ func (rl *Shell) Readline() (string, error) { // init gathers all steps to perform at the beginning of readline loop. func (rl *Shell) init() { + rl.clearLocalSuggestion() + rl.ClearAIPrediction() + // Reset core editor components. core.FlushUsed(rl.Keys) rl.line.Set() diff --git a/external/readline/shell.go b/external/readline/shell.go index 428e2a41..17e5816e 100644 --- a/external/readline/shell.go +++ b/external/readline/shell.go @@ -2,6 +2,8 @@ package readline import ( "fmt" + "sync" + "time" "github.com/reeflective/readline/inputrc" "github.com/reeflective/readline/internal/completion" @@ -61,6 +63,30 @@ type Shell struct { // It takes the readline line ([]rune) and cursor pos as parameters, // and returns completions with their associated metadata/settings. Completer func(line []rune, cursor int) Completions + + // AIGenerateCommand is a callback function that converts natural language + // input to a command. Used by the ai-generate-command action (Alt+Q). + AIGenerateCommand func(line string, history []string) (string, error) + + // AISmartComplete is a callback for AI-powered smart completion. + // Returns multiple command suggestions for Tab completion. + AISmartComplete func(line string, history []string) ([]string, error) + + // AI prediction state (for inline ghost text) + AIPredictNext func(line string, history []string) (string, error) // Predicts next argument + aiPrediction string // Current prediction text + aiPredictionLine string // Line snapshot for the current prediction + aiPredictionMu sync.Mutex + aiPredictionTimer *time.Timer + aiPredictionSeq uint64 // Monotonically increasing, used to drop stale requests. + lastTabTime time.Time // For double-tab detection + + // Local suggestion state (fast completion-based suggestions without AI) + localSuggestion string // Current local suggestion text (full line) + localSuggestionLine string // Line snapshot when suggestion was computed + localSuggestionMu sync.Mutex // Protects local suggestion state + localSuggestionTimer *time.Timer // Debounce timer + localSuggestionSeq uint64 // Sequence number to drop stale requests } // NewShell returns a readline shell instance initialized with a default @@ -137,6 +163,21 @@ func (rl *Shell) Cursor() *core.Cursor { return rl.cursor } // selections used to change/select multiple parts of the line at once. func (rl *Shell) Selection() *core.Selection { return rl.selection } +// SetAISuggestion sets an AI inline suggestion to display after the cursor (fish-style). +func (rl *Shell) SetAISuggestion(suggestion string) { + rl.Display.SetAISuggestion(suggestion) +} + +// ClearAISuggestion clears the AI inline suggestion. +func (rl *Shell) ClearAISuggestion() { + rl.Display.ClearAISuggestion() +} + +// GetAISuggestion returns the current AI inline suggestion. +func (rl *Shell) GetAISuggestion() string { + return rl.Display.GetAISuggestion() +} + // Printf prints a formatted string below the current line and redisplays the prompt // and input line (and possibly completions/hints if active) below the logged string. // A newline is added to the message so that the prompt is correctly refreshed below. diff --git a/external/readline/vim.go b/external/readline/vim.go index 84423183..1e2250b8 100644 --- a/external/readline/vim.go +++ b/external/readline/vim.go @@ -238,6 +238,13 @@ func (rl *Shell) viAddEol() { // Move forward one character, without changing lines. func (rl *Shell) viForwardChar() { + // At end of line: accept AI suggestion if available (fish-style) + if rl.cursor.Pos() == rl.line.Len()-1 && rl.GetAIPrediction() != "" { + if rl.acceptAIPrediction() { + return + } + } + // Only exception where we actually don't forward a character. if rl.Config.GetBool("history-autosuggest") && rl.cursor.Pos() == rl.line.Len()-1 { rl.autosuggestAccept() diff --git a/go.mod b/go.mod index 924b321a..a260c6fe 100644 --- a/go.mod +++ b/go.mod @@ -13,8 +13,10 @@ require ( github.com/chainreactors/rem v0.2.4 github.com/chainreactors/tui v0.0.0-20250825071101-9e61744e554f github.com/chainreactors/utils v0.0.0-20241209140746-65867d2f78b2 - github.com/charmbracelet/bubbletea v1.3.4 + github.com/charmbracelet/bubbletea v1.3.6 github.com/charmbracelet/glamour v0.8.0 + github.com/charmbracelet/huh v0.8.0 + github.com/charmbracelet/lipgloss v1.1.0 github.com/corpix/uarand v0.2.0 github.com/dustin/go-humanize v1.0.1 github.com/evertras/bubble-table v0.17.2 @@ -51,12 +53,12 @@ require ( github.com/docker/distribution v2.8.2+incompatible // indirect github.com/docker/docker v24.0.9+incompatible golang.org/x/crypto v0.33.0 - golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63 + golang.org/x/exp v0.0.0-20231006140011-7918f672742d golang.org/x/net v0.40.0 // indirect - golang.org/x/sync v0.14.0 // indirect + golang.org/x/sync v0.15.0 // indirect golang.org/x/sys v0.33.0 // indirect golang.org/x/term v0.29.0 // indirect - golang.org/x/text v0.22.0 + golang.org/x/text v0.23.0 golang.org/x/time v0.9.0 // indirect google.golang.org/grpc v1.57.2 google.golang.org/protobuf v1.34.1 @@ -73,16 +75,17 @@ require ( github.com/aymerick/douceur v0.2.0 // indirect github.com/blinkbean/dingtalk v1.1.3 // indirect github.com/carapace-sh/carapace-shlex v1.0.1 // indirect + github.com/catppuccin/go v0.3.0 // indirect github.com/cbroglie/mustache v1.4.0 // indirect github.com/chainreactors/fingers v0.0.0-20240702104653-a66e34aa41df // indirect github.com/chainreactors/go-metrics v0.0.0-20220926021830-24787b7a10f8 // indirect github.com/chainreactors/proxyclient v1.0.2 // indirect - github.com/charmbracelet/bubbles v0.20.0 // indirect + github.com/charmbracelet/bubbles v0.21.1-0.20250623103423-23b8fd6302d7 // indirect github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect github.com/charmbracelet/harmonica v0.2.0 // indirect - github.com/charmbracelet/lipgloss v1.1.0 // indirect - github.com/charmbracelet/x/ansi v0.8.0 // indirect - github.com/charmbracelet/x/cellbuf v0.0.13-0.20250311204145-2c3ea96c31dd // indirect + github.com/charmbracelet/x/ansi v0.9.3 // indirect + github.com/charmbracelet/x/cellbuf v0.0.13 // indirect + github.com/charmbracelet/x/exp/strings v0.0.0-20240722160745-212f7b056ed0 // indirect github.com/charmbracelet/x/term v0.2.1 // indirect github.com/cjoudrey/gluahttp v0.0.0-20201111170219-25003d9adfa9 // indirect github.com/davecgh/go-spew v1.1.1 // indirect @@ -120,7 +123,7 @@ require ( github.com/mattn/go-runewidth v0.0.16 // indirect github.com/mattn/go-sqlite3 v1.14.24 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect - github.com/miekg/dns v1.1.67 // indirect + github.com/mitchellh/hashstructure/v2 v2.0.2 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/moby/term v0.5.0 // indirect github.com/montanaflynn/stats v0.7.1 // indirect @@ -155,8 +158,6 @@ require ( github.com/yuin/gluamapper v0.0.0-20150323120927-d836955830e7 // indirect github.com/yuin/goldmark v1.7.4 // indirect github.com/yuin/goldmark-emoji v1.0.3 // indirect - golang.org/x/mod v0.24.0 // indirect - golang.org/x/tools v0.33.0 // indirect golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20230530153820-e85fd2cbaebc // indirect gotest.tools/v3 v3.5.1 // indirect diff --git a/helper/cryptography/cryptography.go b/helper/cryptography/cryptography.go index 82609eed..ac58e259 100644 --- a/helper/cryptography/cryptography.go +++ b/helper/cryptography/cryptography.go @@ -2,12 +2,10 @@ package cryptography import ( "bytes" - "crypto/ed25519" "crypto/hmac" "crypto/rand" "crypto/sha256" "encoding/base64" - "encoding/json" "errors" "filippo.io/age" "fmt" @@ -41,6 +39,16 @@ var ( // and we can ensure there will only ever be a single recipient, // we can just ignore add/remove it at runtime to safe space. agePrefix = []byte("age-encryption.org/v1\n-> X25519 ") + + cachedAgeServerKeyPairOnce sync.Once + cachedAgeServerKeyPair *AgeKeyPair + cachedAgeServerKeyPairErr error + + cachedMinisignServerKeyOnce sync.Once + cachedMinisignServerKey *minisign.PrivateKey + cachedMinisignServerKeyErr error + + ageKeyExReplay sync.Map ) // deriveKeyFrom - Derives a key from input data using SHA256 @@ -121,7 +129,7 @@ func AgeEncrypt(recipientPublicKey string, plaintext []byte) ([]byte, error) { if err := stream.Close(); err != nil { return nil, err } - return buf.Bytes(), nil + return bytes.TrimPrefix(buf.Bytes(), agePrefix), nil } // AgeDecrypt - Decrypt using Curve 25519 + ChaCha20Poly1305 @@ -132,11 +140,19 @@ func AgeDecrypt(recipientPrivateKey string, ciphertext []byte) ([]byte, error) { return nil, err } - // 直接使用 ciphertext,Age 库会自动处理 grease recipients + // Accept both full age ciphertext and ciphertext with the standard prefix trimmed. + if !bytes.HasPrefix(ciphertext, []byte("age-encryption.org/v1")) { + prefixed := make([]byte, 0, len(agePrefix)+len(ciphertext)) + prefixed = append(prefixed, agePrefix...) + prefixed = append(prefixed, ciphertext...) + ciphertext = prefixed + } + + // Age library automatically handles grease recipients buf := bytes.NewBuffer(ciphertext) stream, err := age.Decrypt(buf, identity) if err != nil { - // 如果解密失败,尝试添加调试信息 + // If decryption fails, add debug info return nil, fmt.Errorf("age decrypt failed (ciphertext size: %d bytes): %w", len(ciphertext), err) } @@ -150,11 +166,8 @@ func AgeDecrypt(recipientPrivateKey string, ciphertext []byte) ([]byte, error) { // AgeKeyPairFromImplant - Decrypt the session key from an implant func AgeKeyExFromImplant(serverPrivateKey string, implantPrivateKey string, ciphertext []byte) ([]byte, error) { - // TODO - Store the hash of the implant's key exchange to prevent replay attacks - // Check for replay attacks - //if err := db.CheckKeyExReplay(ciphertext); err != nil { - // return nil, ErrDecryptFailed - //} + ciphertextDigest := sha256.Sum256(ciphertext) + b64Digest := base64.RawStdEncoding.EncodeToString(ciphertextDigest[:]) // Decrypt the message plaintext, err := AgeDecrypt(serverPrivateKey, ciphertext) @@ -176,6 +189,10 @@ func AgeKeyExFromImplant(serverPrivateKey string, implantPrivateKey string, ciph if !hmac.Equal(mac.Sum(nil), plaintext[:sha256Size]) { return nil, ErrDecryptFailed } + + if _, ok := ageKeyExReplay.LoadOrStore(b64Digest, true); ok { + return nil, ErrReplayAttack + } return plaintext[sha256Size:], nil } @@ -272,45 +289,14 @@ func serverSignRawBuf(buf []byte) []byte { // AgeServerKeyPair - Get teh server's ECC key pair func AgeServerKeyPair() *AgeKeyPair { - // TODO - get key value from db - //data, err := db.GetKeyValue(serverAgeKeyPairKey) - // test - data, err := json.Marshal(&AgeKeyPair{}) - //if err == db.ErrRecordNotFound { - // keyPair, err := generateServerKeyPair() - // if err != nil { - // panic(err) - // } - // return keyPair - //} - keyPair := &AgeKeyPair{} - err = json.Unmarshal([]byte(data), keyPair) - if err != nil { - panic(err) - } - return keyPair -} - -func generateServerKeyPair() (*AgeKeyPair, error) { - keyPair, err := RandomAgeKeyPair() - if err != nil { - return nil, err - } - data, err := json.Marshal(keyPair) - // test - data = []byte(string(data)) - if err != nil { - return nil, err + // TODO: load from persistent storage; fall back to an in-memory keypair. + cachedAgeServerKeyPairOnce.Do(func() { + cachedAgeServerKeyPair, cachedAgeServerKeyPairErr = RandomAgeKeyPair() + }) + if cachedAgeServerKeyPairErr != nil { + panic(cachedAgeServerKeyPairErr) } - // TODO - set key value in db - // err = db.SetKeyValue(serverAgeKeyPairKey, string(data)) - return keyPair, err -} - -// minisignPrivateKey - This is here so we can marshal to/from JSON -type minisignPrivateKey struct { - ID uint64 `json:"id"` - PrivateKey []byte `json:"private_key"` + return cachedAgeServerKeyPair } // MinisignServerPublicKey - Get the server's minisign public key string @@ -331,47 +317,19 @@ func MinisignServerSign(message []byte) string { // MinisignServerPrivateKey - Get the server's minisign key pair func MinisignServerPrivateKey() *minisign.PrivateKey { - // TODO - get key value from db - // test - data, err := json.Marshal(&AgeKeyPair{}) - //data, err := db.GetKeyValue(serverMinisignPrivateKey) - //if err == db.ErrRecordNotFound { - // privateKey, err := generateServerMinisignPrivateKey() - // if err != nil { - // panic(err) - // } - // return privateKey - //} - privateKey := &minisignPrivateKey{} - err = json.Unmarshal([]byte(data), privateKey) - if err != nil { - panic(err) - } - rawBytes := [ed25519.PrivateKeySize]byte{} - copy(rawBytes[:], privateKey.PrivateKey) - return &minisign.PrivateKey{ - RawID: privateKey.ID, - RawBytes: rawBytes, - } -} - -func generateServerMinisignPrivateKey() (*minisign.PrivateKey, error) { - _, privateKey, err := minisign.GenerateKey(rand.Reader) - if err != nil { - return nil, err - } - data, _ := json.Marshal(&minisignPrivateKey{ - ID: privateKey.ID(), - PrivateKey: privateKey.Bytes(), + // TODO: load from persistent storage; fall back to an in-memory keypair. + cachedMinisignServerKeyOnce.Do(func() { + _, privateKey, err := minisign.GenerateKey(rand.Reader) + if err != nil { + cachedMinisignServerKeyErr = err + return + } + cachedMinisignServerKey = &privateKey }) - // test - data = []byte(data) - // TODO - set key value in db - //err = db.SetKeyValue(serverMinisignPrivateKey, string(data)) - if err != nil { - return nil, err + if cachedMinisignServerKeyErr != nil { + panic(cachedMinisignServerKeyErr) } - return &privateKey, err + return cachedMinisignServerKey } func RandomInRange(min, max uint32) uint32 { diff --git a/helper/cryptography/cryptography_test.go b/helper/cryptography/cryptography_test.go index 7cbcabde..53194341 100644 --- a/helper/cryptography/cryptography_test.go +++ b/helper/cryptography/cryptography_test.go @@ -3,7 +3,6 @@ package cryptography import ( "bytes" "crypto/rand" - "fmt" insecureRand "math/rand" "os" "sync" @@ -71,18 +70,20 @@ func TestAgeEncrypt(t *testing.T) { if err != nil { t.Fatal(err) } - fmt.Println(encrypted) - if !bytes.Equal([]byte(data), encrypted) { - t.Fatalf("Sample does not match decrypted data") + if bytes.Equal([]byte(data), encrypted) { + t.Fatalf("Ciphertext should not match plaintext") } } func TestAgeDecrypt(t *testing.T) { data := []byte{97, 103, 101, 45, 101, 110, 99, 114, 121, 112, 116, 105, 111, 110, 46, 111, 114, 103, 47, 118, 49, 10, 45, 62, 32, 88, 50, 53, 53, 49, 57, 32, 112, 115, 88, 48, 103, 104, 65, 84, 68, 120, 77, 111, 97, 84, 87, 77, 48, 47, 83, 117, 119, 50, 80, 107, 114, 52, 66, 43, 88, 105, 89, 75, 54, 112, 81, 122, 112, 43, 86, 104, 116, 103, 85, 10, 51, 82, 89, 50, 54, 116, 119, 70, 111, 108, 101, 70, 121, 66, 110, 57, 66, 101, 47, 121, 69, 79, 102, 99, 119, 76, 56, 107, 111, 115, 57, 55, 52, 115, 117, 110, 52, 56, 108, 48, 119, 69, 69, 10, 45, 62, 32, 66, 45, 103, 114, 101, 97, 115, 101, 32, 116, 61, 63, 42, 123, 75, 42, 32, 44, 47, 10, 66, 103, 111, 85, 119, 76, 83, 69, 120, 74, 120, 74, 87, 85, 109, 71, 118, 53, 73, 51, 120, 70, 121, 76, 43, 113, 52, 57, 97, 117, 50, 86, 74, 118, 108, 89, 47, 75, 110, 98, 66, 65, 49, 108, 72, 56, 48, 48, 52, 112, 98, 89, 47, 71, 69, 99, 89, 53, 52, 10, 45, 45, 45, 32, 56, 106, 115, 65, 101, 57, 97, 69, 108, 110, 116, 50, 109, 67, 99, 103, 122, 82, 48, 113, 53, 116, 55, 118, 57, 90, 86, 98, 112, 90, 85, 85, 83, 77, 71, 55, 89, 50, 79, 86, 104, 88, 81, 10, 226, 0, 72, 213, 103, 70, 169, 21, 148, 223, 128, 36, 70, 193, 95, 18, 97, 75, 179, 247, 222, 134, 200, 37, 24, 71, 167, 217, 5, 2, 143, 49, 50, 111, 245, 43, 73, 220, 140, 30, 133, 253, 34, 169, 28, 42, 179, 41, 170, 121, 110, 133, 51, 13, 184, 144, 192, 157, 152, 232, 20, 247, 130, 113, 201, 129, 233, 236, 222, 218, 132, 55, 199, 115, 246, 2, 208, 37, 248, 92, 110, 250, 188, 82, 162, 169, 104, 254, 34, 150, 212, 237, 208, 206, 202, 69, 32, 21, 74, 112, 195, 59, 0, 161, 192, 219, 139, 233, 197, 157, 177, 174, 7, 84, 168, 28, 125, 18, 148, 94, 225, 173, 98, 197, 239, 250, 240, 252, 1, 139, 146, 64, 22, 247, 199, 12, 237, 63, 195, 64, 157, 168, 82, 35, 64, 253, 114, 176, 11, 216, 112, 187, 212, 217, 28, 249, 67, 33, 131, 22, 87, 246, 79, 52, 91, 107, 143, 210, 77, 150, 104, 48, 7, 86, 165, 103, 13, 188, 228, 193, 194, 246, 184, 85, 121, 73, 54, 177, 66, 145, 103, 47, 96, 134, 133, 85, 187, 66, 123, 141, 198, 182, 49, 195, 73, 71, 29, 152, 166, 176, 69, 124, 177, 249, 0, 242, 169, 169, 151, 64, 188, 45, 45, 109, 252, 215, 94, 188, 112, 245, 5, 182, 50, 42, 203, 55, 133, 166, 160, 209, 159, 127, 167, 132, 222, 84, 108, 108, 19, 237, 154, 20, 109, 118, 175, 120, 75, 216, 206, 41, 246, 68, 110, 190, 132, 138, 151, 202, 203, 118, 232, 245, 158, 57, 159, 191, 188, 94, 173, 76, 214, 55, 75, 62, 94, 66, 185, 3, 42, 193, 217, 142, 136, 219, 175, 116, 107, 148, 157, 165, 210, 216, 71, 206, 237, 83, 106, 236, 52, 216, 124, 216, 13, 168, 53, 137, 180, 197, 156, 55, 156, 185, 70, 189, 47, 71, 160, 204, 158, 49, 16, 238, 127, 191, 31, 252, 229, 210, 227, 7, 151, 157, 146, 168, 115, 56, 223, 6, 253, 44, 170, 49, 236, 217, 55, 187, 248, 224, 222, 162, 181, 46, 225, 189, 197, 98, 251, 135, 185, 180, 138, 71, 218, 247, 96, 71, 91, 158, 186, 158, 86, 229, 226, 82, 3, 5, 237, 177, 176, 132, 17, 97, 227, 49, 217, 7, 195, 149, 130, 114, 36, 76, 64, 134, 254, 21, 116, 249, 103, 250, 111, 154, 249, 176, 209, 62, 65, 254, 216, 50, 113, 61, 53, 43, 36, 224, 244, 101, 181, 186, 198, 27, 74, 63, 146, 119, 108, 98, 236, 16, 156, 44, 60, 132, 173, 82, 31, 205, 167, 186, 249, 2, 123, 68, 86, 94, 80, 112, 165, 116, 76, 87, 25, 116, 2, 250, 212, 231, 254, 14, 130, 18, 175, 10, 198, 204, 178, 73, 68, 214, 6, 30, 16, 251, 243, 199, 47, 125, 212, 110, 36, 80, 5, 42, 253, 33, 27, 179, 50, 53, 130, 152, 75, 0, 79, 84, 160, 179, 238, 179, 203, 248, 183, 103, 83, 53, 18, 181, 80, 120, 171, 110, 142, 68, 58, 52, 220, 163, 44, 205, 124, 215, 86, 101, 6, 83, 177, 250, 183, 115, 213, 236, 226, 185, 143, 251, 73, 71, 117, 34, 57, 122, 236, 150, 230, 40, 219, 122, 237, 35, 116, 7, 88, 190, 205, 124, 42, 147, 135, 252, 194, 156, 188, 228, 102, 238, 162, 127, 12, 204, 8, 56, 119, 201, 158, 225, 15, 140, 149, 187, 207, 64, 210, 35, 96, 18, 165, 22, 54, 170, 199, 51, 49, 154, 215, 220, 3, 153, 109, 91, 145, 237, 136, 74, 12, 207, 195, 25, 152, 108, 175, 9, 185, 194, 50, 117, 31, 181, 79, 77, 45, 147, 39, 80, 49, 80, 153, 118, 42, 199, 74, 207, 111, 0, 107, 14, 12, 171, 240, 186, 52, 73, 25, 133, 5, 91, 165, 44, 207, 37, 142, 177, 104, 23, 71, 234, 80, 110, 254, 110, 199, 162, 204, 194, 193, 28, 149, 222, 47, 26, 204, 186, 192, 23, 204, 166, 194, 14, 58, 20, 102, 233, 123, 128, 205, 122, 206, 25, 96, 254, 101, 55, 83, 113, 117, 77, 207, 34, 166, 231, 253, 191, 218, 177, 24, 227, 92, 9, 166, 228, 217, 238, 7, 66, 65, 218, 202, 91, 225, 203, 183, 29, 87, 168, 76, 255, 186, 204, 199, 245, 85, 90, 149, 38, 208, 70, 31, 28, 202, 92, 7, 106, 158, 50, 186, 23, 179, 29, 85, 234, 104, 245, 21, 186, 167, 37, 50, 10, 184, 119, 246, 96, 62, 201, 43, 125, 128, 239, 79, 163, 5, 116, 45, 149, 27, 147, 181, 121, 243, 143, 31, 193, 21, 91, 5, 107, 179, 114, 159, 161, 66, 47, 52, 24, 103, 249, 242, 140, 12, 17, 96, 8, 116, 222, 56, 117, 126, 83, 184, 22, 186, 190, 175, 226, 160, 97, 18, 222, 193, 84, 245, 29, 195, 81, 228, 140, 223, 123, 218, 124, 245, 214, 6, 131, 253, 194, 134, 169, 45, 4, 158, 192, 175, 71, 205, 207, 31, 32, 141, 53, 117, 170, 218, 15, 72, 102, 211, 105} - _, err := AgeDecrypt("AGE-SECRET-KEY-1G0VT6PZP0P3CHK9HR0W8J7EF04DWP9TWH07MR27CCFVXR8HDJJTQU2DFRN", data) - if err == nil { + plaintext, err := AgeDecrypt("AGE-SECRET-KEY-1G0VT6PZP0P3CHK9HR0W8J7EF04DWP9TWH07MR27CCFVXR8HDJJTQU2DFRN", data) + if err != nil { t.Fatal(err) } + if len(plaintext) == 0 { + t.Fatal("decrypted plaintext is empty") + } } func TestAgeTamperEncryptDecrypt(t *testing.T) { diff --git a/server/build/srdi.go b/server/build/srdi.go index cbb3b07a..a2b3e853 100644 --- a/server/build/srdi.go +++ b/server/build/srdi.go @@ -48,7 +48,7 @@ func ObjcopyPulse(path, platform, arch string) ([]byte, error) { } if err != nil { - return nil, fmt.Errorf("objcopy failed to extract shellcode %s") + return nil, fmt.Errorf("objcopy failed to extract shellcode: %w", err) } // Read the extracted binary shellcode diff --git a/server/config.yaml b/server/config.yaml index a16fb0c8..5e3fb0ac 100644 --- a/server/config.yaml +++ b/server/config.yaml @@ -86,6 +86,6 @@ server: webhook_url: null saas: enable: true - token: null url: https://build.chainreactors.red + token: YOUR_TOKEN_HERE diff --git a/server/internal/certutils/ca.go b/server/internal/certutils/ca.go index b3819146..222ebadb 100644 --- a/server/internal/certutils/ca.go +++ b/server/internal/certutils/ca.go @@ -78,8 +78,8 @@ func SaveCertificateAuthority(caType int, cert []byte, key []byte) { // CAs get written to the filesystem since we control the names and makes them // easier to move around/backup - certFilePath := filepath.Join(storageDir, fmt.Sprintf("%s-ca-cert.pem", caType)) - keyFilePath := filepath.Join(storageDir, fmt.Sprintf("%s-ca-key.pem", caType)) + certFilePath := filepath.Join(storageDir, fmt.Sprintf("%d-ca-cert.pem", caType)) + keyFilePath := filepath.Join(storageDir, fmt.Sprintf("%d-ca-key.pem", caType)) err := ioutil.WriteFile(certFilePath, cert, 0600) if err != nil { diff --git a/server/rpc/rpc-certificate.go b/server/rpc/rpc-certificate.go index f5740e87..f9e09224 100644 --- a/server/rpc/rpc-certificate.go +++ b/server/rpc/rpc-certificate.go @@ -3,6 +3,7 @@ package rpc import ( "context" "fmt" + "strings" "github.com/chainreactors/IoM-go/consts" "github.com/chainreactors/IoM-go/proto/client/clientpb" @@ -22,23 +23,45 @@ func (rpc *Server) GenerateSelfCert(ctx context.Context, req *clientpb.Pipeline) return nil, fmt.Errorf("pipeline %s tls config is nil", req.Name) } - if !req.Tls.Enable { - return &clientpb.Empty{}, nil + pipelineName := strings.TrimSpace(req.Name) + attachToPipeline := pipelineName != "" + // Standalone certificate management: allow generating/importing certs without binding to a pipeline. + if !attachToPipeline { + if req.Tls.Cert != nil && req.Tls.Cert.Cert != "" { + certModel, err := db.SaveCertFromTLS(req.Tls, "") + if err != nil { + return nil, err + } + return rpc.publishCertEvent(certModel) + } + + tls, err := certutils.GenerateSelfTLS("", req.Tls.CertSubject) + if err != nil { + return nil, err + } + req.Tls = tls + + certModel, err := db.SaveCertFromTLS(req.Tls, "") + if err != nil { + return nil, err + } + return rpc.publishCertEvent(certModel) } - if req.Name == "" { - return nil, fmt.Errorf("pipeline name is required to generate certificate") + // Pipeline-bound certificate generation: only act when TLS is enabled. + if !req.Tls.Enable { + return &clientpb.Empty{}, nil } if req.Tls.Cert != nil && req.Tls.Cert.Cert != "" { - certModel, err := db.SaveCertFromTLS(req.Tls, req.Name) + certModel, err := db.SaveCertFromTLS(req.Tls, pipelineName) if err != nil { return nil, err } return rpc.publishCertEvent(certModel) } - certModel, err := db.FindPipelineCert(req.Name, req.ListenerId) + certModel, err := db.FindPipelineCert(pipelineName, req.ListenerId) if err != nil { return nil, err } @@ -53,7 +76,7 @@ func (rpc *Server) GenerateSelfCert(ctx context.Context, req *clientpb.Pipeline) } req.Tls = tls - certModel, err = db.SaveCertFromTLS(req.Tls, req.Name) + certModel, err = db.SaveCertFromTLS(req.Tls, pipelineName) if err != nil { return nil, err } diff --git a/server/rpc/rpc-file.go b/server/rpc/rpc-file.go index 5a8577b8..08662ee6 100644 --- a/server/rpc/rpc-file.go +++ b/server/rpc/rpc-file.go @@ -255,7 +255,7 @@ func (rpc *Server) Download(ctx context.Context, req *implantpb.DownloadRequest) chunkFile := filepath.Join(tempDir, fmt.Sprintf("%d.chunk", downloadResp.Cur)) err = os.WriteFile(chunkFile, downloadResp.Content, 0644) if err != nil { - logs.Log.Errorf("failed to save chunk %d: %w", downloadResp.Cur, err) + logs.Log.Errorf("failed to save chunk %d: %v", downloadResp.Cur, err) return } if checksum, _ := fileutils.CalculateSHA256Checksum(chunkFile); checksum != downloadResp.Checksum {