Skip to content

Commit c6d639f

Browse files
aftersnow寂染
andauthored
feat: add modelfile auto generator (#82)
* add new package tools to add modctl toolkits Signed-off-by: 寂染 <[email protected]> * feat: support modelfile auto generation Signed-off-by: Zhao Chen <[email protected]> --------- Signed-off-by: 寂染 <[email protected]> Signed-off-by: Zhao Chen <[email protected]> Co-authored-by: 寂染 <[email protected]>
1 parent 82d5fb5 commit c6d639f

File tree

5 files changed

+513
-0
lines changed

5 files changed

+513
-0
lines changed

cmd/generate.go

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright 2024 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package cmd
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"strings"
23+
24+
"github.com/CloudNativeAI/modctl/pkg/modelfile"
25+
26+
"github.com/spf13/cobra"
27+
"github.com/spf13/viper"
28+
)
29+
30+
var genConfig = modelfile.NewModelfileGenConfig()
31+
32+
// modelfileGenCmd represents the modctl tools command for generate modelfile.
33+
var modelfileGenCmd = &cobra.Command{
34+
Use: "genmodelfile [flags] <path>",
35+
Short: "A command line tool for generating modelfile.",
36+
Args: cobra.ExactArgs(1),
37+
DisableAutoGenTag: true,
38+
SilenceUsage: true,
39+
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
40+
RunE: func(cmd *cobra.Command, args []string) error {
41+
return runGenModelfile(context.Background(), args[0])
42+
},
43+
}
44+
45+
// init initializes build command.
46+
func init() {
47+
flags := modelfileGenCmd.Flags()
48+
flags.StringVarP(&genConfig.Name, "name", "n", "", "Model name (string), such as llama3-8b-instruct, gpt2-xl, qwen2-vl-72b-instruct, etc.")
49+
flags.StringVarP(&genConfig.Version, "version", "v", "", "Model version (string), such as v1, v2, etc.")
50+
flags.StringVarP(&genConfig.OutputPath, "output", "o", "./", "Output path (string), such as /path/to/output.")
51+
flags.BoolVar(&genConfig.IgnoreUnrecognized, "ignore_unrecognized", false, "Ignore the unrecognized file types in directory.")
52+
flags.BoolVar(&genConfig.Overwrite, "overwrite", false, "Overwrite the existing modelfile.")
53+
flags.StringVar(&genConfig.Arch, "arch", "", "Model architecture (string), such as transformer, cnn, rnn, etc.")
54+
flags.StringVar(&genConfig.Family, "family", "", "Model family (string), such as llama3, gpt2, qwen2, etc.")
55+
flags.StringVar(&genConfig.Format, "format", "", "Model format (string), such as safetensors, pytorch, onnx, etc.")
56+
flags.StringVar(&genConfig.Paramsize, "paramsize", "", "Number of parameters in the model (string).")
57+
flags.StringVar(&genConfig.Precision, "precision", "", "Model precision (string), such as bf16, fp16, int8, etc.")
58+
flags.StringVar(&genConfig.Quantization, "quantization", "", "Model quantization (string), such as awq, gptq, etc.")
59+
60+
if err := viper.BindPFlags(flags); err != nil {
61+
panic(fmt.Errorf("bind cache list flags to viper: %w", err))
62+
}
63+
}
64+
65+
func runGenModelfile(ctx context.Context, modelPath string) error {
66+
if !strings.HasSuffix(modelPath, "/") {
67+
modelPath += "/"
68+
}
69+
70+
return modelfile.RunGenModelfile(ctx, modelPath, genConfig)
71+
}

cmd/root.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,4 +78,5 @@ func init() {
7878
rootCmd.AddCommand(pruneCmd)
7979
rootCmd.AddCommand(inspectCmd)
8080
rootCmd.AddCommand(extractCmd)
81+
rootCmd.AddCommand(modelfileGenCmd)
8182
}

pkg/modelfile/generate.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
/*
2+
* Copyright 2024 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package modelfile
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"os"
23+
"path/filepath"
24+
)
25+
26+
type ModelfileGenConfig struct {
27+
Name string
28+
Version string
29+
OutputPath string
30+
IgnoreUnrecognized bool
31+
Overwrite bool
32+
Arch string
33+
Family string
34+
Format string
35+
Paramsize string
36+
Precision string
37+
Quantization string
38+
}
39+
40+
func NewModelfileGenConfig() *ModelfileGenConfig {
41+
return &ModelfileGenConfig{
42+
Name: "",
43+
Version: "",
44+
OutputPath: "",
45+
IgnoreUnrecognized: false,
46+
Overwrite: false,
47+
Arch: "",
48+
Family: "",
49+
Format: "",
50+
Paramsize: "",
51+
Precision: "",
52+
Quantization: "",
53+
}
54+
}
55+
56+
func (c *ModelfileGenConfig) Validate() error {
57+
// if len(c.Name) == 0 {
58+
// return fmt.Errorf("model name is required")
59+
// }
60+
61+
if len(c.OutputPath) == 0 {
62+
return fmt.Errorf("output path is required")
63+
}
64+
65+
return nil
66+
}
67+
68+
func RunGenModelfile(ctx context.Context, modelPath string, genConfig *ModelfileGenConfig) error {
69+
if err := genConfig.Validate(); err != nil {
70+
return fmt.Errorf("failed to validate modelfile gen config: %w", err)
71+
}
72+
genPath := filepath.Join(genConfig.OutputPath, "Modelfile")
73+
74+
// Check if file exists
75+
if _, err := os.Stat(genPath); err == nil {
76+
if !genConfig.Overwrite {
77+
absPath, _ := filepath.Abs(genPath)
78+
return fmt.Errorf("Modelfile already exists at %s - use --overwrite to overwrite", absPath)
79+
}
80+
}
81+
82+
fmt.Printf("Generating modelfile for %s\n", modelPath)
83+
84+
modelfile, err := AutoModelfile(modelPath, genConfig)
85+
if err != nil {
86+
return fmt.Errorf("failed to generate modelfile: %w", err)
87+
}
88+
89+
// Save the modelfile to the output path
90+
if err := modelfile.SaveToFile(genPath); err != nil {
91+
return fmt.Errorf("failed to save modelfile: %w", err)
92+
}
93+
94+
// Read modelfile from disk and print it
95+
content, err := os.ReadFile(genPath)
96+
if err != nil {
97+
return fmt.Errorf("failed to read modelfile: %w", err)
98+
}
99+
fmt.Printf("Successfully generated modelfile:\n%s\n", string(content))
100+
101+
return nil
102+
}

0 commit comments

Comments
 (0)