Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 45 additions & 16 deletions pkg/session/prompt_file.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,61 @@ import (
"path/filepath"
)

func readPromptFile(workDir, filename string) (string, error) {
current, err := filepath.Abs(workDir)
// readPromptFiles looks for a prompt file in the working directory hierarchy
// and in the user's home folder. If found in both locations, both contents are returned.
// The working directory content is returned first, followed by the home folder content.
func readPromptFiles(workDir, filename string) ([]string, error) {
var results []string

// Look in the working directory hierarchy
workDirPath := findFileInHierarchy(workDir, filename)
if workDirPath != "" {
content, err := os.ReadFile(workDirPath)
if err != nil {
return nil, err
}
results = append(results, string(content))
}

// Look in the home folder (skip if already found there)
if homeDir, err := os.UserHomeDir(); err == nil {
homePath := filepath.Join(homeDir, filename)
if homePath != workDirPath && isFile(homePath) {
content, err := os.ReadFile(homePath)
if err != nil {
return nil, err
}
results = append(results, string(content))
}
}

return results, nil
}

// findFileInHierarchy searches for a file starting from the given directory
// and traversing up the directory tree. Returns the path if found, empty string otherwise.
func findFileInHierarchy(startDir, filename string) string {
current, err := filepath.Abs(startDir)
if err != nil {
return "", err
return ""
}

for {
path := filepath.Join(current, filename)

info, err := os.Stat(path)
if err != nil {
if !os.IsNotExist(err) {
return "", err
}
} else if !info.IsDir() {
data, err := os.ReadFile(path)
if err != nil {
return "", err
}
return string(data), nil
if isFile(path) {
return path
}

parent := filepath.Dir(current)
if parent == current {
return "", nil
return ""
}
current = parent
}
}

// isFile returns true if path exists and is a regular file.
func isFile(path string) bool {
info, err := os.Stat(path)
return err == nil && !info.IsDir()
}
151 changes: 135 additions & 16 deletions pkg/session/prompt_file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,59 +9,178 @@ import (
"github.com/stretchr/testify/require"
)

func TestReadPromptFile(t *testing.T) {
func TestReadPromptFiles(t *testing.T) {
t.Parallel()

dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, "agents.md"), []byte("content"), 0o644)
// Use a unique filename to avoid conflicts with files in home directory
filename := "test_prompt_unique_12345.md"
err := os.WriteFile(filepath.Join(dir, filename), []byte("content"), 0o644)
require.NoError(t, err)

additionalPrompt, err := readPromptFile(dir, "agents.md")
additionalPrompts, err := readPromptFiles(dir, filename)
require.NoError(t, err)
assert.Equal(t, "content", additionalPrompt)
require.Len(t, additionalPrompts, 1)
assert.Equal(t, "content", additionalPrompts[0])
}

func TestReadPromptFileParent(t *testing.T) {
func TestReadPromptFilesParent(t *testing.T) {
t.Parallel()

dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, "agents.md"), []byte("content"), 0o644)
// Use a unique filename to avoid conflicts with files in home directory
filename := "test_prompt_parent_12345.md"
err := os.WriteFile(filepath.Join(dir, filename), []byte("content"), 0o644)
require.NoError(t, err)

child := filepath.Join(dir, "child")
err = os.Mkdir(child, 0o755)
require.NoError(t, err)

additionalPrompt, err := readPromptFile(child, "agents.md")
additionalPrompts, err := readPromptFiles(child, filename)
require.NoError(t, err)
assert.Equal(t, "content", additionalPrompt)
require.Len(t, additionalPrompts, 1)
assert.Equal(t, "content", additionalPrompts[0])
}

func TestReadPromptFileReadFirst(t *testing.T) {
func TestReadPromptFilesReadFirst(t *testing.T) {
t.Parallel()

dir := t.TempDir()
err := os.WriteFile(filepath.Join(dir, "agents.md"), []byte("parent"), 0o644)
// Use a unique filename to avoid conflicts with files in home directory
filename := "test_prompt_readfirst_12345.md"
err := os.WriteFile(filepath.Join(dir, filename), []byte("parent"), 0o644)
require.NoError(t, err)

child := filepath.Join(dir, "child")
err = os.Mkdir(child, 0o755)
require.NoError(t, err)

err = os.WriteFile(filepath.Join(child, "agents.md"), []byte("child"), 0o644)
err = os.WriteFile(filepath.Join(child, filename), []byte("child"), 0o644)
require.NoError(t, err)

additionalPrompt, err := readPromptFile(child, "agents.md")
additionalPrompts, err := readPromptFiles(child, filename)
require.NoError(t, err)
assert.Equal(t, "child", additionalPrompt)
require.Len(t, additionalPrompts, 1)
assert.Equal(t, "child", additionalPrompts[0])
}

func TestReadNoPromptFile(t *testing.T) {
func TestReadNoPromptFiles(t *testing.T) {
t.Parallel()

dir := t.TempDir()
// Use a unique filename that won't exist anywhere
filename := "test_prompt_nonexistent_12345.md"

additionalPrompt, err := readPromptFile(dir, "agents.md")
additionalPrompts, err := readPromptFiles(dir, filename)
require.NoError(t, err)
assert.Empty(t, additionalPrompt)
assert.Empty(t, additionalPrompts)
}

func TestReadPromptFilesFromWorkDirAndHome(t *testing.T) {
t.Parallel()

// Use a unique filename for this test
filename := "test_prompt_workdir_and_home_12345.md"

// Create a temp dir to simulate working directory
workDir := t.TempDir()
err := os.WriteFile(filepath.Join(workDir, filename), []byte("workdir content"), 0o644)
require.NoError(t, err)

// Get the actual home directory and check if we can write to it
homeDir, err := os.UserHomeDir()
require.NoError(t, err)

homePath := filepath.Join(homeDir, filename)
// Check if file already exists in home
_, existsErr := os.Stat(homePath)
fileExisted := existsErr == nil

// Create file in home directory
err = os.WriteFile(homePath, []byte("home content"), 0o644)
if err != nil {
t.Skip("Cannot write to home directory")
}
// Clean up only if we created it
if !fileExisted {
t.Cleanup(func() {
os.Remove(homePath)
})
}

additionalPrompts, err := readPromptFiles(workDir, filename)
require.NoError(t, err)
require.Len(t, additionalPrompts, 2)
assert.Equal(t, "workdir content", additionalPrompts[0])
assert.Equal(t, "home content", additionalPrompts[1])
}

func TestReadPromptFilesFromHomeOnly(t *testing.T) {
t.Parallel()

// Use a unique filename for this test
filename := "test_prompt_home_only_12345.md"

// Create a temp dir without the prompt file
workDir := t.TempDir()

// Get the actual home directory and check if we can write to it
homeDir, err := os.UserHomeDir()
require.NoError(t, err)

homePath := filepath.Join(homeDir, filename)
// Check if file already exists in home
_, existsErr := os.Stat(homePath)
fileExisted := existsErr == nil

// Create file in home directory
err = os.WriteFile(homePath, []byte("home content"), 0o644)
if err != nil {
t.Skip("Cannot write to home directory")
}
// Clean up only if we created it
if !fileExisted {
t.Cleanup(func() {
os.Remove(homePath)
})
}

additionalPrompts, err := readPromptFiles(workDir, filename)
require.NoError(t, err)
require.Len(t, additionalPrompts, 1)
assert.Equal(t, "home content", additionalPrompts[0])
}

func TestReadPromptFilesDeduplication(t *testing.T) {
t.Parallel()

// Test that if working directory is under home, we don't duplicate
homeDir, err := os.UserHomeDir()
require.NoError(t, err)

// Use a unique filename for this test
filename := "test_prompt_dedup_12345.md"
homePath := filepath.Join(homeDir, filename)

// Check if file already exists in home
_, existsErr := os.Stat(homePath)
fileExisted := existsErr == nil

// Create file only in home directory
err = os.WriteFile(homePath, []byte("home content"), 0o644)
if err != nil {
t.Skip("Cannot write to home directory")
}
if !fileExisted {
t.Cleanup(func() {
os.Remove(homePath)
})
}

// When working directory is home, should only return one result
additionalPrompts, err := readPromptFiles(homeDir, filename)
require.NoError(t, err)
require.Len(t, additionalPrompts, 1)
assert.Equal(t, "home content", additionalPrompts[0])
}
4 changes: 2 additions & 2 deletions pkg/session/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -570,13 +570,13 @@ func buildContextSpecificSystemMessages(a *agent.Agent, s *Session) []chat.Messa
}

for _, prompt := range a.AddPromptFiles() {
additionalPrompt, err := readPromptFile(wd, prompt)
additionalPrompts, err := readPromptFiles(wd, prompt)
if err != nil {
slog.Error("reading prompt file", "file", prompt, "error", err)
continue
}

if additionalPrompt != "" {
for _, additionalPrompt := range additionalPrompts {
messages = append(messages, chat.Message{
Role: chat.MessageRoleSystem,
Content: additionalPrompt,
Expand Down