Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func init() {
flags.StringVar(&genConfig.Arch, "arch", "", "Model architecture (string), such as transformer, cnn, rnn, etc.")
flags.StringVar(&genConfig.Family, "family", "", "Model family (string), such as llama3, gpt2, qwen2, etc.")
flags.StringVar(&genConfig.Format, "format", "", "Model format (string), such as safetensors, pytorch, onnx, etc.")
flags.StringVar(&genConfig.Paramsize, "paramsize", "", "Number of parameters in the model (string).")
flags.StringVar(&genConfig.Paramsize, "paramsize", "", "Number of parameters in the model (string), such as 7B, 13B, 72B, etc.")
flags.StringVar(&genConfig.Precision, "precision", "", "Model precision (string), such as bf16, fp16, int8, etc.")
flags.StringVar(&genConfig.Quantization, "quantization", "", "Model quantization (string), such as awq, gptq, etc.")

Expand All @@ -63,6 +63,7 @@ func init() {
}

func runGenModelfile(ctx context.Context, modelPath string) error {

if !strings.HasSuffix(modelPath, "/") {
modelPath += "/"
}
Expand Down
25 changes: 17 additions & 8 deletions pkg/modelfile/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,29 +69,38 @@ func RunGenModelfile(ctx context.Context, modelPath string, genConfig *Modelfile
if err := genConfig.Validate(); err != nil {
return fmt.Errorf("failed to validate modelfile gen config: %w", err)
}
genPath := filepath.Join(genConfig.OutputPath, "Modelfile")

// Check if file exists
// Convert modelPath to absolute path
realModelPath, err := filepath.Abs(modelPath)
if err != nil {
return fmt.Errorf("failed to get absolute path for model: %w", err)
}

// check if file exists
genPath := filepath.Join(genConfig.OutputPath, "Modelfile")
genPath, err = filepath.Abs(genPath)
if err != nil {
return fmt.Errorf("failed to get absolute path for modelfile: %w", err)
}
if _, err := os.Stat(genPath); err == nil {
if !genConfig.Overwrite {
absPath, _ := filepath.Abs(genPath)
return fmt.Errorf("Modelfile already exists at %s - use --overwrite to overwrite", absPath)
return fmt.Errorf("Modelfile already exists at %s - use --overwrite to overwrite", genPath)
}
}

fmt.Printf("Generating modelfile for %s\n", modelPath)
fmt.Printf("Generating modelfile for %s\n", realModelPath)

modelfile, err := AutoModelfile(modelPath, genConfig)
modelfile, err := AutoModelfile(realModelPath, genConfig)
if err != nil {
return fmt.Errorf("failed to generate modelfile: %w", err)
}

// Save the modelfile to the output path
// save the modelfile to the output path
if err := modelfile.SaveToFile(genPath); err != nil {
return fmt.Errorf("failed to save modelfile: %w", err)
}

// Read modelfile from disk and print it
// read modelfile from disk and print it
content, err := os.ReadFile(genPath)
if err != nil {
return fmt.Errorf("failed to read modelfile: %w", err)
Expand Down
63 changes: 52 additions & 11 deletions pkg/modelfile/modelfile.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,11 @@ func isFileType(filename string, patterns []string) bool {

// isSkippable checks if the filename matches any of the skip patterns
func isSkippable(filename string) bool {
// Special handling for current and parent directory
if filename == "." || filename == ".." {
return false
}

// Convert filename to lowercase for case-insensitive comparison
lowerFilename := strings.ToLower(filename)
for _, pattern := range skipPatterns {
Expand All @@ -195,6 +200,32 @@ func isSkippable(filename string) bool {
return false
}

// cleanModelName turns an arbitrary string (typically a directory name) into
// a model name made only of ASCII letters and digits separated by underscores.
//
// Trailing path separators ('/' and '\') are dropped, every rune outside
// [A-Za-z0-9] becomes '_', and leading/trailing underscores are stripped.
// When nothing is left, the placeholder "unnamed_model" is returned.
func cleanModelName(name string) string {
	// sanitize keeps ASCII alphanumerics and maps every other rune to '_'.
	sanitize := func(r rune) rune {
		switch {
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return r
		default:
			return '_'
		}
	}

	withoutSeparators := strings.TrimRight(name, "/\\")
	cleaned := strings.Trim(strings.Map(sanitize, withoutSeparators), "_")
	if cleaned != "" {
		return cleaned
	}

	// Nothing usable survived sanitization; fall back to a fixed placeholder.
	return "unnamed_model"
}

// NewModelfile creates a new modelfile by the path of the modelfile.
// It parses the modelfile and returns the modelfile interface.
func NewModelfile(path string) (Modelfile, error) {
Expand Down Expand Up @@ -284,16 +315,21 @@ func overwriteModelConfig(mf *modelfile, config *ModelfileGenConfig) {
// AutoModelfile creates a new modelfile by the path of the model directory.
// It walks the directory and returns the auto-generated modelfile interface.
func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
// check if the config is nil
if config == nil {
return nil, fmt.Errorf("config cannot be nil")
}

mf := &modelfile{
config: hashset.New(),
model: hashset.New(),
code: hashset.New(),
dataset: hashset.New(),
}

// Use directory name as model name if config.name is empty
// use directory name as model name if config.name is empty
if config.Name == "" {
mf.name = filepath.Base(path)
mf.name = cleanModelName(filepath.Base(path))
} else {
mf.name = config.Name
}
Expand All @@ -306,7 +342,7 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {

filename := info.Name()

// Skip hidden and skippable files/directories
// skip hidden and skippable files/directories
if isSkippable(filename) {
if info.IsDir() {
return filepath.SkipDir
Expand All @@ -318,7 +354,7 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
return nil
}

// Get relative path from the base directory
// get relative path from the base directory
relPath, err := filepath.Rel(path, fullPath)
if err != nil {
return err
Expand All @@ -332,11 +368,11 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
case isFileType(filename, codeFilePatterns):
mf.code.Add(relPath)
default:
// Skip unrecognized files if IgnoreUnrecognized is true
// skip unrecognized files if IgnoreUnrecognized is true
if config.IgnoreUnrecognized {
return nil
}
return fmt.Errorf("unknown file type: %s", filename)
return fmt.Errorf("unknown file type: %s - use --ignore-unrecognized to ignore, and edit the Modelfile manually", filename)
}

return nil
Expand All @@ -346,12 +382,17 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
return nil, err
}

// Get the model config from the config.json file
// check if model files are found
if mf.model.Size() == 0 {
return nil, fmt.Errorf("no recognized model files found in directory - you may need to edit the Modelfile manually")
}

// get the model config from the config.json file
if err := parseModelConfig(path, mf); err != nil {
return nil, err
}

// Overwrite the modelfile configurations with the provided config values
// overwrite the modelfile configurations with the provided config values
overwriteModelConfig(mf, config)

return mf, nil
Expand Down Expand Up @@ -533,7 +574,7 @@ func (mf *modelfile) SaveToFile(path string) error {
// generate time in the first line
content += fmt.Sprintf("# Generated at %s\n", time.Now().Format(time.RFC3339))

// Add single value commands
// add single value commands
if mf.name != "" {
content += "\n# Model name\n"
content += fmt.Sprintf("NAME %s\n", mf.name)
Expand Down Expand Up @@ -563,7 +604,7 @@ func (mf *modelfile) SaveToFile(path string) error {
content += fmt.Sprintf("QUANTIZATION %s\n", mf.quantization)
}

// Add multi-value commands
// add multi-value commands
content += "\n# Config files (Generated from the files in the model directory)\n"
content += "# Supported file types: " + strings.Join(configFilePatterns, ", ") + "\n"
configs := mf.GetConfigs()
Expand All @@ -588,6 +629,6 @@ func (mf *modelfile) SaveToFile(path string) error {
content += fmt.Sprintf("MODEL %s\n", model)
}

// Write to file
// write to file
return os.WriteFile(path, []byte(content), 0644)
}
169 changes: 169 additions & 0 deletions pkg/modelfile/modelfile_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package modelfile
import (
"errors"
"os"
"path/filepath"
"sort"
"testing"

Expand Down Expand Up @@ -247,3 +248,171 @@ name bar
os.Remove(tmpfile.Name())
}
}

// TestAutoModelfile runs AutoModelfile against temporary model directories.
//
// Each case describes a directory as a map of relative path to file content,
// the generation config to use, and either the exact error message expected
// or a validate callback that inspects the generated Modelfile.
func TestAutoModelfile(t *testing.T) {
	testCases := []struct {
		name      string
		files     map[string]string // relative path -> file content
		config    *ModelfileGenConfig
		expectErr error                       // compared by message when non-nil
		validate  func(*testing.T, Modelfile) // run only on success; may be nil
	}{
		{
			name: "basic model directory",
			files: map[string]string{
				"config.json":            `{"model_type": "llama", "transformers_version": "1.0", "torch_dtype": "float16"}`,
				"generation_config.json": `{}`,
				"tokenizer.model":        "dummy content",
				"pytorch_model.bin":      "dummy content",
				"model.safetensors":      "dummy content",
				"train.py":               "print('hello')",
				"README.md":              "# Model Documentation",
				".git/config":            "should be ignored",
				"__pycache__/cache.pyc":  "should be ignored",
			},
			config: &ModelfileGenConfig{
				Name:               "llama2-7b",
				Format:             "safetensors",
				Paramsize:          "7B",
				Quantization:       "q4_k_m",
				IgnoreUnrecognized: true,
			},
			expectErr: nil,
			validate: func(t *testing.T, mf Modelfile) {
				assert := assert.New(t)

				// Check configs (sorted)
				expectedConfigs := []string{
					"README.md",
					"config.json",
					"generation_config.json",
					"tokenizer.model",
				}
				configs := mf.GetConfigs()
				sort.Strings(configs)
				assert.Equal(expectedConfigs, configs)

				// Check models (sorted)
				expectedModels := []string{
					"model.safetensors",
					"pytorch_model.bin",
				}
				models := mf.GetModels()
				sort.Strings(models)
				assert.Equal(expectedModels, models)

				// Check codes (sorted)
				expectedCodes := []string{
					"train.py",
				}
				codes := mf.GetCodes()
				sort.Strings(codes)
				assert.Equal(expectedCodes, codes)

				// Check other fields
				assert.Equal("llama2-7b", mf.GetName())
				assert.Equal("transformer", mf.GetArch()) // from config.json
				assert.Equal("llama", mf.GetFamily())     // from config.json
				assert.Equal("safetensors", mf.GetFormat())
				assert.Equal("7B", mf.GetParamsize())
				assert.Equal("float16", mf.GetPrecision()) // from config.json
				assert.Equal("q4_k_m", mf.GetQuantization())
			},
		},
		{
			name: "unrecognized files without ignore flag",
			files: map[string]string{
				"unknown.xyz": "some content",
			},
			config: &ModelfileGenConfig{
				Name:               "test-model",
				IgnoreUnrecognized: false,
			},
			expectErr: errors.New("unknown file type: unknown.xyz - use --ignore-unrecognized to ignore, and edit the Modelfile manually"),
		},
		{
			name:      "empty directory",
			files:     map[string]string{},
			config:    &ModelfileGenConfig{Name: "empty-model"},
			expectErr: errors.New("no recognized model files found in directory - you may need to edit the Modelfile manually"),
		},
		{
			// The directory holds only a config file, so the expected error is
			// the "no model files" one rather than a JSON parsing error.
			name: "invalid config json",
			files: map[string]string{
				"config.json": `{"model_type": "llama", invalid json`,
			},
			config:    &ModelfileGenConfig{Name: "invalid-config"},
			expectErr: errors.New("no recognized model files found in directory - you may need to edit the Modelfile manually"),
		},
		{
			name: "nested directories",
			files: map[string]string{
				"config.json":               `{"model_type": "llama"}`,
				"models/shard1.safetensors": "dummy content",
				"models/shard2.safetensors": "dummy content",
				"configs/main.json":         "dummy content",
				"src/train.py":              "print('hello')",
			},
			config:    &ModelfileGenConfig{Name: "nested-model", IgnoreUnrecognized: true},
			expectErr: nil,
			validate: func(t *testing.T, mf Modelfile) {
				assert := assert.New(t)
				models := mf.GetModels()
				sort.Strings(models)
				assert.Equal([]string{"models/shard1.safetensors", "models/shard2.safetensors"}, models)

				codes := mf.GetCodes()
				sort.Strings(codes)
				assert.Equal([]string{"src/train.py"}, codes)
			},
		},
		{
			name: "special characters in paths",
			files: map[string]string{
				"config.json":                   `{"model_type": "llama"}`,
				"model with spaces.safetensors": "dummy content",
				"特殊字符.bin":                      "dummy content",
				"src/test-file.py":              "print('hello')",
			},
			config: &ModelfileGenConfig{
				Name:               "special-chars",
				IgnoreUnrecognized: true,
			},
			validate: func(t *testing.T, mf Modelfile) {
				assert := assert.New(t)
				models := mf.GetModels()
				sort.Strings(models)
				assert.Equal([]string{"model with spaces.safetensors", "特殊字符.bin"}, models)
			},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// t.TempDir is removed automatically when the subtest completes.
			tmpDir := t.TempDir()

			// Materialize the fixture. A setup failure is fatal because the
			// assertions below would be meaningless on a partial fixture.
			for path, content := range tc.files {
				fullPath := filepath.Join(tmpDir, path)

				if err := os.MkdirAll(filepath.Dir(fullPath), 0755); err != nil {
					t.Fatalf("failed to create parent directory for %s: %v", path, err)
				}
				if err := os.WriteFile(fullPath, []byte(content), 0644); err != nil {
					t.Fatalf("failed to write %s: %v", path, err)
				}
			}

			mf, err := AutoModelfile(tmpDir, tc.config)

			if tc.expectErr != nil {
				// Only read err.Error() once err is known to be non-nil.
				if assert.Error(t, err) {
					assert.Equal(t, tc.expectErr.Error(), err.Error())
				}
				assert.Nil(t, mf)
				return
			}

			// Stop before validation if generation failed: the validators
			// call methods on mf and would panic on a nil value.
			if !assert.NoError(t, err) || !assert.NotNil(t, mf) {
				return
			}

			if tc.validate != nil {
				tc.validate(t, mf)
			}
		})
	}
}