diff --git a/CLAUDE.md b/CLAUDE.md index a6528b14..01c3eb24 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -10,6 +10,8 @@ This is the Inference Gateway CLI (`infer`), a Go-based command-line tool for ma **Note: All commands should be run with `flox activate -- ` to ensure the proper development environment is activated.** +**IMPORTANT: Always run `task setup` first when working with a fresh checkout of the repository to ensure all dependencies are properly installed.** + ### Setup Development Environment ```bash flox activate -- task setup diff --git a/internal/container/container.go b/internal/container/container.go index 9154c4db..371d8c15 100644 --- a/internal/container/container.go +++ b/internal/container/container.go @@ -48,43 +48,35 @@ func NewServiceContainer(cfg *config.Config) *ServiceContainer { // initializeDomainServices creates and wires domain service implementations func (c *ServiceContainer) initializeDomainServices() { - // Create conversation repository c.conversationRepo = services.NewInMemoryConversationRepository() - // Create model service c.modelService = services.NewHTTPModelService( c.config.Gateway.URL, c.config.Gateway.APIKey, ) - // Create tool service first (needed by chat service) + c.fileService = services.NewLocalFileService() + if c.config.Tools.Enabled { - c.toolService = services.NewLLMToolService(c.config) + c.toolService = services.NewLLMToolService(c.config, c.fileService) } else { c.toolService = services.NewNoOpToolService() } - // Create chat service with tool service c.chatService = services.NewStreamingChatService( c.config.Gateway.URL, c.config.Gateway.APIKey, c.config.Gateway.Timeout, c.toolService, ) - - // Create file service - c.fileService = services.NewLocalFileService() } // initializeUIComponents creates UI components and theme func (c *ServiceContainer) initializeUIComponents() { - // Create theme based on configuration c.theme = ui.NewDefaultTheme() - // Create layout manager c.layout = ui.NewDefaultLayout() - // Create component factory c.componentFactory = ui.NewComponentFactory(c.theme, c.layout, c.modelService) } diff --git a/internal/domain/interfaces.go b/internal/domain/interfaces.go index 562eafc6..895f9c9c 100644 --- a/internal/domain/interfaces.go +++ b/internal/domain/interfaces.go @@ -103,6 +103,7 @@ type ToolService interface { type FileService interface { ListProjectFiles() ([]string, error) ReadFile(path string) (string, error) + ReadFileLines(path string, startLine, endLine int) (string, error) ValidateFile(path string) error GetFileInfo(path string) (FileInfo, error) } diff --git a/internal/services/file.go b/internal/services/file.go index f948ee02..67d0843b 100644 --- a/internal/services/file.go +++ b/internal/services/file.go @@ -1,6 +1,7 @@ package services import ( + "bufio" "fmt" "io/fs" "os" @@ -161,6 +162,46 @@ func (s *LocalFileService) ReadFile(path string) (string, error) { return string(content), nil } +func (s *LocalFileService) ReadFileLines(path string, startLine, endLine int) (string, error) { + if err := s.ValidateFile(path); err != nil { + return "", err + } + + file, err := os.Open(path) + if err != nil { + return "", fmt.Errorf("failed to open file: %w", err) + } + defer func() { + _ = file.Close() + }() + + scanner := bufio.NewScanner(file) + var result strings.Builder + currentLine := 1 + + for scanner.Scan() { + if startLine > 0 && currentLine < startLine { + currentLine++ + continue + } + if endLine > 0 && currentLine > endLine { + break + } + + if result.Len() > 0 { + result.WriteString("\n") + } + result.WriteString(scanner.Text()) + currentLine++ + } + + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("failed to read file lines: %w", err) + } + + return result.String(), nil +} + func (s *LocalFileService) ValidateFile(path string) error { if path == "" { return fmt.Errorf("file path cannot be empty") diff --git a/internal/services/tool.go b/internal/services/tool.go index b08457fc..29ae753f 100644 --- a/internal/services/tool.go +++ b/internal/services/tool.go @@ -22,17 +22,29 @@ type ToolResult struct { Duration string `json:"duration"` } +// FileReadResult represents the result of a file read operation +type FileReadResult struct { + FilePath string `json:"file_path"` + Content string `json:"content"` + Size int64 `json:"size"` + StartLine int `json:"start_line,omitempty"` + EndLine int `json:"end_line,omitempty"` + Error string `json:"error,omitempty"` +} + // LLMToolService implements ToolService with direct tool execution type LLMToolService struct { - config *config.Config - enabled bool + config *config.Config + fileService domain.FileService + enabled bool } // NewLLMToolService creates a new LLM tool service -func NewLLMToolService(cfg *config.Config) *LLMToolService { +func NewLLMToolService(cfg *config.Config, fileService domain.FileService) *LLMToolService { return &LLMToolService{ - config: cfg, - enabled: cfg.Tools.Enabled, + config: cfg, + fileService: fileService, + enabled: cfg.Tools.Enabled, } } @@ -62,6 +74,36 @@ func (s *LLMToolService) ListTools() []domain.ToolDefinition { "required": []string{"command"}, }, }, + { + Name: "Read", + Description: "Read file content from the filesystem with optional line range", + Parameters: map[string]interface{}{ + "type": "object", + "properties": map[string]interface{}{ + "file_path": map[string]interface{}{ + "type": "string", + "description": "The path to the file to read", + }, + "start_line": map[string]interface{}{ + "type": "integer", + "description": "Starting line number (1-indexed, optional)", + "minimum": 1, + }, + "end_line": map[string]interface{}{ + "type": "integer", + "description": "Ending line number (1-indexed, optional)", + "minimum": 1, + }, + "format": map[string]interface{}{ + "type": "string", + "description": "Output format (text or json)", + "enum": []string{"text", "json"}, + "default": "text", + }, + }, + "required": []string{"file_path"}, + }, + }, } } @@ -72,40 +114,9 @@ func (s *LLMToolService) ExecuteTool(ctx context.Context, name string, args map[ switch name { case "Bash": - command, ok := args["command"].(string) - if !ok { - return "", fmt.Errorf("command parameter is required and must be a string") - } - - format, ok := args["format"].(string) - if !ok { - format = "text" - } - - result, err := s.executeBash(ctx, command) - if err != nil { - return "", fmt.Errorf("bash execution failed: %w", err) - } - - if format == "json" { - jsonOutput, err := json.MarshalIndent(result, "", " ") - if err != nil { - return "", fmt.Errorf("failed to marshal result: %w", err) - } - return string(jsonOutput), nil - } - - output := fmt.Sprintf("Command: %s\n", result.Command) - output += fmt.Sprintf("Exit Code: %d\n", result.ExitCode) - output += fmt.Sprintf("Duration: %s\n", result.Duration) - - if result.Error != "" { - output += fmt.Sprintf("Error: %s\n", result.Error) - } - - output += fmt.Sprintf("Output:\n%s", result.Output) - return output, nil - + return s.executeBashTool(ctx, args) + case "Read": + return s.executeReadTool(args) default: return "", fmt.Errorf("unknown tool: %s", name) } @@ -134,18 +145,14 @@ func (s *LLMToolService) ValidateTool(name string, args map[string]interface{}) return fmt.Errorf("tool '%s' is not available", name) } - if name == "Bash" { - command, ok := args["command"].(string) - if !ok { - return fmt.Errorf("command parameter is required and must be a string") - } - - if !s.isCommandAllowed(command) { - return fmt.Errorf("command not whitelisted: %s", command) - } + switch name { + case "Bash": + return s.validateBashTool(args) + case "Read": + return s.validateReadTool(args) + default: + return nil } - - return nil } // executeBash executes a bash command with security validation @@ -179,6 +186,34 @@ func (s *LLMToolService) executeBash(ctx context.Context, command string) (*Tool return result, nil } +// executeRead reads a file with optional line range +func (s *LLMToolService) executeRead(filePath string, startLine, endLine int) (*FileReadResult, error) { + result := &FileReadResult{ + FilePath: filePath, + StartLine: startLine, + EndLine: endLine, + } + + var content string + var err error + + if startLine > 0 || endLine > 0 { + content, err = s.fileService.ReadFileLines(filePath, startLine, endLine) + } else { + content, err = s.fileService.ReadFile(filePath) + } + + if err != nil { + result.Error = err.Error() + return result, nil + } + + result.Content = content + result.Size = int64(len(content)) + + return result, nil +} + // isCommandAllowed checks if a command is whitelisted func (s *LLMToolService) isCommandAllowed(command string) bool { command = strings.TrimSpace(command) @@ -222,3 +257,152 @@ func (s *NoOpToolService) IsToolEnabled(name string) bool { func (s *NoOpToolService) ValidateTool(name string, args map[string]interface{}) error { return fmt.Errorf("tools are not enabled") } + +// executeBashTool handles Bash tool execution +func (s *LLMToolService) executeBashTool(ctx context.Context, args map[string]interface{}) (string, error) { + command, ok := args["command"].(string) + if !ok { + return "", fmt.Errorf("command parameter is required and must be a string") + } + + format, ok := args["format"].(string) + if !ok { + format = "text" + } + + result, err := s.executeBash(ctx, command) + if err != nil { + return "", fmt.Errorf("bash execution failed: %w", err) + } + + if format == "json" { + jsonOutput, err := json.MarshalIndent(result, "", " ") + if err != nil { + return "", fmt.Errorf("failed to marshal result: %w", err) + } + return string(jsonOutput), nil + } + + return s.formatBashResult(result), nil +} + +// executeReadTool handles Read tool execution +func (s *LLMToolService) executeReadTool(args map[string]interface{}) (string, error) { + filePath, ok := args["file_path"].(string) + if !ok { + return "", fmt.Errorf("file_path parameter is required and must be a string") + } + + format, ok := args["format"].(string) + if !ok { + format = "text" + } + + var startLine, endLine int + if startLineFloat, ok := args["start_line"].(float64); ok { + startLine = int(startLineFloat) + } + if endLineFloat, ok := args["end_line"].(float64); ok { + endLine = int(endLineFloat) + } + + result, err := s.executeRead(filePath, startLine, endLine) + if err != nil { + return "", fmt.Errorf("file read failed: %w", err) + } + + if format == "json" { + jsonOutput, err := json.MarshalIndent(result, "", " ") + if err != nil { + return "", fmt.Errorf("failed to marshal result: %w", err) + } + return string(jsonOutput), nil + } + + return s.formatReadResult(result), nil +} + +// validateBashTool validates Bash tool arguments +func (s *LLMToolService) validateBashTool(args map[string]interface{}) error { + command, ok := args["command"].(string) + if !ok { + return fmt.Errorf("command parameter is required and must be a string") + } + + if !s.isCommandAllowed(command) { + return fmt.Errorf("command not whitelisted: %s", command) + } + + return nil +} + +// validateReadTool validates Read tool arguments +func (s *LLMToolService) validateReadTool(args map[string]interface{}) error { + filePath, ok := args["file_path"].(string) + if !ok { + return fmt.Errorf("file_path parameter is required and must be a string") + } + + if err := s.fileService.ValidateFile(filePath); err != nil { + return fmt.Errorf("file validation failed: %w", err) + } + + return s.validateLineNumbers(args) +} + +// validateLineNumbers validates start_line and end_line parameters +func (s *LLMToolService) validateLineNumbers(args map[string]interface{}) error { + var startLine float64 + var hasStartLine bool + + if startLineFloat, ok := args["start_line"].(float64); ok { + if startLineFloat < 1 { + return fmt.Errorf("start_line must be >= 1") + } + startLine = startLineFloat + hasStartLine = true + } + + if endLineFloat, ok := args["end_line"].(float64); ok { + if endLineFloat < 1 { + return fmt.Errorf("end_line must be >= 1") + } + if hasStartLine && endLineFloat < startLine { + return fmt.Errorf("end_line must be >= start_line") + } + } + + return nil +} + +// formatBashResult formats bash execution result for text output +func (s *LLMToolService) formatBashResult(result *ToolResult) string { + output := fmt.Sprintf("Command: %s\n", result.Command) + output += fmt.Sprintf("Exit Code: %d\n", result.ExitCode) + output += fmt.Sprintf("Duration: %s\n", result.Duration) + + if result.Error != "" { + output += fmt.Sprintf("Error: %s\n", result.Error) + } + + output += fmt.Sprintf("Output:\n%s", result.Output) + return output +} + +// formatReadResult formats read result for text output +func (s *LLMToolService) formatReadResult(result *FileReadResult) string { + output := fmt.Sprintf("File: %s\n", result.FilePath) + if result.StartLine > 0 { + output += fmt.Sprintf("Lines: %d", result.StartLine) + if result.EndLine > 0 && result.EndLine != result.StartLine { + output += fmt.Sprintf("-%d", result.EndLine) + } + output += "\n" + } + output += fmt.Sprintf("Size: %d bytes\n", result.Size) + if result.Error != "" { + output += fmt.Sprintf("Error: %s\n", result.Error) + } + output += fmt.Sprintf("Content:\n%s", result.Content) + return output +}