Skip to content

Commit 484d8c8

Browse files
authored
feat: optimize modelfile command (#117)
Signed-off-by: Gaius <[email protected]>
1 parent a87c190 commit 484d8c8

File tree

12 files changed

+1596
-744
lines changed

12 files changed

+1596
-744
lines changed

cmd/generate.go

Lines changed: 0 additions & 72 deletions
This file was deleted.

cmd/modelfile/generate.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* Copyright 2025 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package modelfile
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"os"
23+
24+
configmodelfile "github.com/CloudNativeAI/modctl/pkg/config/modelfile"
25+
"github.com/CloudNativeAI/modctl/pkg/modelfile"
26+
27+
"github.com/spf13/cobra"
28+
"github.com/spf13/viper"
29+
)
30+
31+
var generateConfig = configmodelfile.NewGenerateConfig()
32+
33+
// generateCmd represents the modelfile tools command for generating modelfile.
34+
var generateCmd = &cobra.Command{
35+
Use: "generate [flags] <path>",
36+
Short: "A command line tool for generating modelfile in the workspace, the workspace must be a directory including model files and model configuration files",
37+
Args: cobra.ExactArgs(1),
38+
DisableAutoGenTag: true,
39+
SilenceUsage: true,
40+
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
41+
RunE: func(cmd *cobra.Command, args []string) error {
42+
if err := generateConfig.Convert(args[0]); err != nil {
43+
return err
44+
}
45+
46+
if err := generateConfig.Validate(); err != nil {
47+
return err
48+
}
49+
50+
return runGenerate(context.Background())
51+
},
52+
}
53+
54+
// init initializes generate command.
55+
func init() {
56+
flags := generateCmd.Flags()
57+
flags.StringVarP(&generateConfig.Name, "name", "n", "", "specify the model name, such as llama3-8b-instruct, gpt2-xl, qwen2-vl-72b-instruct, etc")
58+
flags.StringVar(&generateConfig.Arch, "arch", "", "specify the model architecture, such as transformer, cnn, rnn, etc")
59+
flags.StringVar(&generateConfig.Family, "family", "", "specify model family, such as llama3, gpt2, qwen2, etc")
60+
flags.StringVar(&generateConfig.Format, "format", "", "specify model format, such as safetensors, pytorch, onnx, etc")
61+
flags.StringVar(&generateConfig.ParamSize, "param-size", "", "specify number of model parameters, such as 8b, 16b, 32b, etc")
62+
flags.StringVar(&generateConfig.Precision, "precision", "", "specify model precision, such as bf16, fp16, int8, etc")
63+
flags.StringVar(&generateConfig.Quantization, "quantization", "", "specify model quantization, such as awq, gptq, etc")
64+
flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory")
65+
flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace")
66+
flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile")
67+
68+
if err := viper.BindPFlags(flags); err != nil {
69+
panic(fmt.Errorf("bind cache list flags to viper: %w", err))
70+
}
71+
}
72+
73+
// runGenerate runs the generate modelfile.
74+
func runGenerate(_ context.Context) error {
75+
fmt.Printf("Generating modelfile for %s\n", generateConfig.Workspace)
76+
modelfile, err := modelfile.NewModelfileByWorkspace(generateConfig.Workspace, generateConfig)
77+
if err != nil {
78+
return fmt.Errorf("failed to generate modelfile: %w", err)
79+
}
80+
81+
content := modelfile.Content()
82+
if err := os.WriteFile(generateConfig.Output, content, 0644); err != nil {
83+
return fmt.Errorf("failed to write modelfile: %w", err)
84+
}
85+
86+
fmt.Printf("Successfully generated modelfile:\n%s\n", string(content))
87+
return nil
88+
}

cmd/modelfile/modelfile.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
/*
2+
* Copyright 2025 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package modelfile
18+
19+
import (
20+
"github.com/sirupsen/logrus"
21+
22+
"github.com/spf13/cobra"
23+
"github.com/spf13/viper"
24+
)
25+
26+
// RootCmd represents the modelfile tools command for modelfile operation.
27+
var RootCmd = &cobra.Command{
28+
Use: "modelfile",
29+
Short: "A command line tool for modelfile operation",
30+
Args: cobra.ExactArgs(1),
31+
DisableAutoGenTag: true,
32+
SilenceUsage: true,
33+
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
34+
RunE: func(cmd *cobra.Command, args []string) error {
35+
logrus.Debug("modctl modelfile is running")
36+
37+
return nil
38+
},
39+
}
40+
41+
// init initializes modelfile command.
42+
func init() {
43+
flags := RootCmd.Flags()
44+
45+
if err := viper.BindPFlags(flags); err != nil {
46+
panic(err)
47+
}
48+
49+
// Add sub command.
50+
RootCmd.AddCommand(generateCmd)
51+
}

cmd/root.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"os/signal"
2222
"syscall"
2323

24+
"github.com/CloudNativeAI/modctl/cmd/modelfile"
2425
"github.com/CloudNativeAI/modctl/pkg/config"
2526

2627
"github.com/spf13/cobra"
@@ -86,6 +87,6 @@ func init() {
8687
rootCmd.AddCommand(pruneCmd)
8788
rootCmd.AddCommand(inspectCmd)
8889
rootCmd.AddCommand(extractCmd)
89-
rootCmd.AddCommand(modelfileGenCmd)
9090
rootCmd.AddCommand(tagCmd)
91+
rootCmd.AddCommand(modelfile.RootCmd)
9192
}

docs/getting-started.md

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,19 @@ $ ./output/modctl -h
1818

1919
## Usage
2020

21+
### Modelfile
22+
23+
#### Generate
24+
25+
Generate a Modelfile for the model artifact in the current directory(workspace),
26+
you need go to the directory where the model artifact is located and
27+
run the following command. Then the `Modelfile` will be generated in the current
28+
directory(workspace).
29+
30+
```shell
31+
$ modctl modelfile generate .
32+
```
33+
2134
### Build
2235

2336
Build the model artifact you need to prepare a Modelfile describe your expected layout of the model artifact in your model repo.
@@ -26,41 +39,43 @@ Example of Modelfile:
2639

2740
```shell
2841
# Model name (string), such as llama3-8b-instruct, gpt2-xl, qwen2-vl-72b-instruct, etc.
29-
NAME gemma-2b
42+
name gemma-2b
3043

3144
# Model architecture (string), such as transformer, cnn, rnn, etc.
32-
ARCH transformer
45+
arch transformer
3346

3447
# Model family (string), such as llama3, gpt2, qwen2, etc.
35-
FAMILY gemma
48+
family gemma
3649

3750
# Model format (string), such as onnx, tensorflow, pytorch, etc.
38-
FORMAT safetensors
51+
format safetensors
3952

4053
# Number of parameters in the model (integer).
41-
PARAMSIZE 16
54+
paramsize 16
4255

4356
# Model precision (string), such as bf16, fp16, int8, etc.
44-
PRECISION bf16
57+
precision bf16
4558

4659
# Model quantization (string), such as awq, gptq, etc.
47-
QUANTIZATION awq
60+
quantization awq
4861

4962
# Specify model configuration file, support glob path pattern.
50-
CONFIG config.json
63+
config config.json
5164

5265
# Specify model configuration file, support glob path pattern.
53-
CONFIG generation_config.json
66+
config generation_config.json
5467

5568
# Model weight, support glob path pattern.
56-
MODEL *.safetensors
69+
model *.safetensors
5770

5871
# Specify code, support glob path pattern.
59-
CODE *.py
72+
code *.py
6073

6174
# Specify documentation, support glob path pattern.
62-
DOC *.md
75+
doc *.md
6376

77+
# Specify dataset, support glob path pattern.
78+
dataset *.csv
6479
```
6580

6681
Then run the following command to build the model artifact:

pkg/config/modelfile/modelfile.go

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
/*
2+
* Copyright 2025 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package modelfile
18+
19+
import (
20+
"fmt"
21+
"os"
22+
"path/filepath"
23+
"strings"
24+
)
25+
26+
// DefaultModelfileName is the default name of the modelfile.
27+
const DefaultModelfileName = "Modelfile"
28+
29+
type GenerateConfig struct {
30+
Workspace string
31+
Name string
32+
Version string
33+
Output string
34+
IgnoreUnrecognizedFileTypes bool
35+
Overwrite bool
36+
Arch string
37+
Family string
38+
Format string
39+
ParamSize string
40+
Precision string
41+
Quantization string
42+
}
43+
44+
func NewGenerateConfig() *GenerateConfig {
45+
return &GenerateConfig{
46+
Workspace: ".",
47+
Name: "",
48+
Version: "",
49+
Output: "",
50+
IgnoreUnrecognizedFileTypes: false,
51+
Overwrite: false,
52+
Arch: "",
53+
Family: "",
54+
Format: "",
55+
ParamSize: "",
56+
Precision: "",
57+
Quantization: "",
58+
}
59+
}
60+
61+
func (g *GenerateConfig) Convert(workspace string) error {
62+
modelfilePath := filepath.Join(g.Output, DefaultModelfileName)
63+
absModelfilePath, err := filepath.Abs(modelfilePath)
64+
if err != nil {
65+
return err
66+
}
67+
g.Output = absModelfilePath
68+
69+
if !strings.HasSuffix(workspace, "/") {
70+
workspace += "/"
71+
}
72+
73+
absWorkspace, err := filepath.Abs(workspace)
74+
if err != nil {
75+
return err
76+
}
77+
g.Workspace = absWorkspace
78+
return nil
79+
}
80+
81+
func (g *GenerateConfig) Validate() error {
82+
// Check if the output path exists modelfile, if so, check if we can overwrite it.
83+
// If the output path does not exist, we can create the modelfile.
84+
if _, err := os.Stat(g.Output); err == nil {
85+
if !g.Overwrite {
86+
return fmt.Errorf("Modelfile already exists at %s - use --overwrite to overwrite", g.Output)
87+
}
88+
}
89+
90+
return nil
91+
}

0 commit comments

Comments
 (0)