Skip to content

Commit 2a39894

Browse files
authored
feat: add support for customizable prompt templates from directories (#224)
- Add `promptFolder` variable and setting it via command line flag - Create and validate `promptFolder` during configuration initialization - Implement loading of custom prompt templates from `promptFolder` - Implement the `prompt` command to load default prompt data - Add function to get raw data of the template - Initialize `templates` as an empty map to avoid redundant checks - Add support to load templates from a specified directory - Add tests for loading templates from a directory Signed-off-by: appleboy <[email protected]>
1 parent 19b7803 commit 2a39894

File tree

6 files changed

+178
-6
lines changed

6 files changed

+178
-6
lines changed

cmd/cmd.go

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ var rootCmd = &cobra.Command{
2121

2222
// Used for flags.
2323
var (
24-
cfgFile string
25-
replacer = strings.NewReplacer("-", "_", ".", "_")
24+
cfgFile string
25+
promptFolder string
26+
replacer = strings.NewReplacer("-", "_", ".", "_")
2627
)
2728

2829
const (
@@ -34,12 +35,14 @@ func init() {
3435
cobra.OnInitialize(initConfig)
3536

3637
rootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.config/codegpt/.codegpt.yaml)")
38+
rootCmd.PersistentFlags().StringVar(&promptFolder, "prompt_folder", "", "prompt folder (default is $HOME/.config/codegpt/prompt)")
3739
rootCmd.AddCommand(versionCmd)
3840
rootCmd.AddCommand(configCmd)
3941
rootCmd.AddCommand(commitCmd)
4042
rootCmd.AddCommand(hookCmd)
4143
rootCmd.AddCommand(reviewCmd)
4244
rootCmd.AddCommand(CompletionCmd)
45+
rootCmd.AddCommand(promptCmd)
4346

4447
// hide completion command
4548
rootCmd.CompletionOptions.HiddenDefaultCmd = true
@@ -75,6 +78,30 @@ func initConfig() {
7578
}
7679
}
7780

81+
if promptFolder != "" {
82+
viper.Set("prompt_folder", promptFolder)
83+
if file.IsFile(promptFolder) {
84+
log.Fatalf("prompt folder %s is a file", promptFolder)
85+
}
86+
// create the prompt folder if it doesn't exist
87+
if !file.IsDir(promptFolder) {
88+
if err := os.MkdirAll(promptFolder, os.ModePerm); err != nil {
89+
log.Fatal(err)
90+
}
91+
}
92+
} else {
93+
// Find home directory.
94+
home, err := os.UserHomeDir()
95+
cobra.CheckErr(err)
96+
targetFolder := path.Join(home, ".config", "codegpt", "prompt")
97+
if !file.IsDir(targetFolder) {
98+
if err := os.MkdirAll(targetFolder, os.ModePerm); err != nil {
99+
log.Fatal(err)
100+
}
101+
}
102+
viper.Set("prompt_folder", targetFolder)
103+
}
104+
78105
viper.AutomaticEnv()
79106
viper.SetEnvKeyReplacer(replacer)
80107

cmd/hepler.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,13 @@ func check() error {
6565
return fmt.Errorf("template variables file not found: %s", templateVarsFile)
6666
}
6767

68+
// load custom prompt
69+
promptFolder := viper.GetString("prompt_folder")
70+
if promptFolder != "" {
71+
if err := util.LoadTemplatesFromDir(promptFolder); err != nil {
72+
return fmt.Errorf("failed to load custom prompt templates: %s", err)
73+
}
74+
}
75+
6876
return nil
6977
}

cmd/prompt.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package cmd
2+
3+
import (
4+
"os"
5+
"path"
6+
7+
"github.com/appleboy/CodeGPT/prompt"
8+
"github.com/erikgeiser/promptkit/confirmation"
9+
"github.com/fatih/color"
10+
"github.com/spf13/cobra"
11+
"github.com/spf13/viper"
12+
)
13+
14+
var loadPromptData bool
15+
16+
func init() {
17+
promptCmd.PersistentFlags().BoolVar(&loadPromptData, "load", false,
18+
"load default prompt data")
19+
}
20+
21+
var defaultPromptDataKeys = []string{
22+
prompt.CodeReviewTemplate,
23+
prompt.SummarizeFileDiffTemplate,
24+
prompt.SummarizeTitleTemplate,
25+
prompt.ConventionalCommitTemplate,
26+
}
27+
28+
// promptCmd represents the command to load default prompt data.
29+
// It uses the "prompt" keyword and provides a short description: "load default prompt data".
30+
// The command executes the RunE function which checks if the loadPromptData flag is set.
31+
// If set, it prompts the user for confirmation to load the default prompt data, which will overwrite existing data.
32+
// Upon confirmation, it retrieves the prompt folder path from the configuration and saves the default prompt data keys to the specified folder.
33+
// If any error occurs during the process, it returns the error.
34+
var promptCmd = &cobra.Command{
35+
Use: "prompt",
36+
Short: "load default prompt data",
37+
RunE: func(cmd *cobra.Command, args []string) error {
38+
if !loadPromptData {
39+
return nil
40+
}
41+
42+
confirm, err := confirmation.New("Do you want to load default prompt data, will overwrite your data", confirmation.No).RunPrompt()
43+
if err != nil || !confirm {
44+
return err
45+
}
46+
47+
folder := viper.GetString("prompt_folder")
48+
for _, key := range defaultPromptDataKeys {
49+
if err := savePromptData(folder, key); err != nil {
50+
return err
51+
}
52+
}
53+
return nil
54+
},
55+
}
56+
57+
func savePromptData(folder, key string) error {
58+
// load default prompt data
59+
out, err := prompt.GetRawData(key)
60+
if err != nil {
61+
return err
62+
}
63+
64+
// save out to file
65+
target := path.Join(folder, key)
66+
if err := os.WriteFile(target, out, 0o600); err != nil {
67+
return err
68+
}
69+
color.Cyan("save %s to %s", key, target)
70+
return nil
71+
}

prompt/prompt.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,9 @@ func init() { //nolint:gochecknoinits
2828
log.Fatal(err)
2929
}
3030
}
31+
32+
// GetRawData returns the raw data of the template with the given name.
33+
func GetRawData(name string) ([]byte, error) {
34+
key := "templates/" + name
35+
return templatesFS.ReadFile(key)
36+
}

util/template.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@ import (
66
"fmt"
77
"html/template"
88
"io/fs"
9+
"os"
910
)
1011

1112
// Data defines a custom type for the template data.
1213
type Data map[string]interface{}
1314

1415
var (
15-
templates map[string]*template.Template
16+
templates = make(map[string]*template.Template)
1617
templatesDir = "templates"
1718
)
1819

@@ -67,9 +68,6 @@ func GetTemplateByBytes(name string, data map[string]interface{}) ([]byte, error
6768
// LoadTemplates loads all the templates found in the templates directory from the embedded filesystem.
6869
// It returns an error if reading the directory or parsing any template fails.
6970
func LoadTemplates(files embed.FS) error {
70-
if templates == nil {
71-
templates = make(map[string]*template.Template)
72-
}
7371
tmplFiles, err := fs.ReadDir(files, templatesDir)
7472
if err != nil {
7573
return err
@@ -89,3 +87,24 @@ func LoadTemplates(files embed.FS) error {
8987
}
9088
return nil
9189
}
90+
91+
func LoadTemplatesFromDir(dir string) error {
92+
tmplFiles, err := fs.ReadDir(os.DirFS(dir), ".")
93+
if err != nil {
94+
return err
95+
}
96+
97+
for _, tmpl := range tmplFiles {
98+
if tmpl.IsDir() {
99+
continue
100+
}
101+
102+
pt, err := template.ParseFS(os.DirFS(dir), tmpl.Name())
103+
if err != nil {
104+
return err
105+
}
106+
107+
templates[tmpl.Name()] = pt
108+
}
109+
return nil
110+
}

util/template_test.go

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package util
22

33
import (
4+
"bytes"
45
"html/template"
6+
"os"
57
"testing"
68
)
79

@@ -69,3 +71,42 @@ func TestProcessTemplate(t *testing.T) {
6971
t.Errorf("Unexpected output. Got: %v, Want: %v", buf.String(), expected)
7072
}
7173
}
74+
75+
func TestLoadTemplatesFromDir(t *testing.T) {
76+
// Create a temporary directory for testing
77+
tempDir := t.TempDir()
78+
79+
// Create a sample template file in the temporary directory
80+
templateContent := "Hello, {{.Name}}!"
81+
templateFile := "test.tmpl"
82+
err := os.WriteFile(tempDir+"/"+templateFile, []byte(templateContent), 0o600)
83+
if err != nil {
84+
t.Fatalf("Failed to create test template file: %v", err)
85+
}
86+
87+
// Load templates from the temporary directory
88+
err = LoadTemplatesFromDir(tempDir)
89+
if err != nil {
90+
t.Fatalf("Failed to load templates from directory: %v", err)
91+
}
92+
93+
// Check if the template was loaded correctly
94+
tmpl, ok := templates[templateFile]
95+
if !ok {
96+
t.Fatalf("Template %s not found in loaded templates", templateFile)
97+
}
98+
99+
// Process the loaded template
100+
data := Data{"Name": "World"}
101+
var buf bytes.Buffer
102+
err = tmpl.Execute(&buf, data)
103+
if err != nil {
104+
t.Fatalf("Failed to execute loaded template: %v", err)
105+
}
106+
107+
// Check the output
108+
expected := "Hello, World!"
109+
if buf.String() != expected {
110+
t.Errorf("Unexpected output. Got: %v, Want: %v", buf.String(), expected)
111+
}
112+
}

0 commit comments

Comments
 (0)