Skip to content

Commit 6292a5b

Browse files
committed
refactor: improve modelfile generation error handling and validation
- Add more descriptive error messages for modelfile generation
- Enhance validation checks in AutoModelfile method
- Improve handling of unrecognized files and empty directories
- Add nil config check and more informative error messages
- Update comments for better code readability

Signed-off-by: Zhao Chen <zhaochen.zju@gmail.com>
1 parent c4f1730 commit 6292a5b

File tree

3 files changed

+72
-15
lines changed

3 files changed

+72
-15
lines changed

pkg/modelfile/generate.go

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,7 @@ func RunGenModelfile(ctx context.Context, modelPath string, genConfig *Modelfile
7171
}
7272
genPath := filepath.Join(genConfig.OutputPath, "Modelfile")
7373

74-
// Check if file exists
74+
// check if file exists
7575
if _, err := os.Stat(genPath); err == nil {
7676
if !genConfig.Overwrite {
7777
absPath, _ := filepath.Abs(genPath)
@@ -86,12 +86,12 @@ func RunGenModelfile(ctx context.Context, modelPath string, genConfig *Modelfile
8686
return fmt.Errorf("failed to generate modelfile: %w", err)
8787
}
8888

89-
// Save the modelfile to the output path
89+
// save the modelfile to the output path
9090
if err := modelfile.SaveToFile(genPath); err != nil {
9191
return fmt.Errorf("failed to save modelfile: %w", err)
9292
}
9393

94-
// Read modelfile from disk and print it
94+
// read modelfile from disk and print it
9595
content, err := os.ReadFile(genPath)
9696
if err != nil {
9797
return fmt.Errorf("failed to read modelfile: %w", err)

pkg/modelfile/modelfile.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -310,14 +310,19 @@ func overwriteModelConfig(mf *modelfile, config *ModelfileGenConfig) {
310310
// AutoModelfile creates a new modelfile by the path of the model directory.
311311
// It walks the directory and returns the auto-generated modelfile interface.
312312
func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
313+
// check if the config is nil
314+
if config == nil {
315+
return nil, fmt.Errorf("config cannot be nil")
316+
}
317+
313318
mf := &modelfile{
314319
config: hashset.New(),
315320
model: hashset.New(),
316321
code: hashset.New(),
317322
dataset: hashset.New(),
318323
}
319324

320-
// Use directory name as model name if config.name is empty
325+
// use directory name as model name if config.name is empty
321326
if config.Name == "" {
322327
mf.name = cleanModelName(filepath.Base(path))
323328
} else {
@@ -332,7 +337,7 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
332337

333338
filename := info.Name()
334339

335-
// Skip hidden and skippable files/directories
340+
// skip hidden and skippable files/directories
336341
if isSkippable(filename) {
337342
if info.IsDir() {
338343
return filepath.SkipDir
@@ -344,7 +349,7 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
344349
return nil
345350
}
346351

347-
// Get relative path from the base directory
352+
// get relative path from the base directory
348353
relPath, err := filepath.Rel(path, fullPath)
349354
if err != nil {
350355
return err
@@ -358,11 +363,11 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
358363
case isFileType(filename, codeFilePatterns):
359364
mf.code.Add(relPath)
360365
default:
361-
// Skip unrecognized files if IgnoreUnrecognized is true
366+
// skip unrecognized files if IgnoreUnrecognized is true
362367
if config.IgnoreUnrecognized {
363368
return nil
364369
}
365-
return fmt.Errorf("unknown file type: %s", filename)
370+
return fmt.Errorf("unknown file type: %s - use --ignore-unrecognized to ignore, and edit the Modelfile manually", filename)
366371
}
367372

368373
return nil
@@ -372,12 +377,17 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
372377
return nil, err
373378
}
374379

375-
// Get the model config from the config.json file
380+
// check if model files are found
381+
if mf.model.Size() == 0 {
382+
return nil, fmt.Errorf("no recognized model files found in directory - you may need to edit the Modelfile manually")
383+
}
384+
385+
// get the model config from the config.json file
376386
if err := parseModelConfig(path, mf); err != nil {
377387
return nil, err
378388
}
379389

380-
// Overwrite the modelfile configurations with the provided config values
390+
// overwrite the modelfile configurations with the provided config values
381391
overwriteModelConfig(mf, config)
382392

383393
return mf, nil
@@ -559,7 +569,7 @@ func (mf *modelfile) SaveToFile(path string) error {
559569
// generate time in the first line
560570
content += fmt.Sprintf("# Generated at %s\n", time.Now().Format(time.RFC3339))
561571

562-
// Add single value commands
572+
// add single value commands
563573
if mf.name != "" {
564574
content += "\n# Model name\n"
565575
content += fmt.Sprintf("NAME %s\n", mf.name)
@@ -589,7 +599,7 @@ func (mf *modelfile) SaveToFile(path string) error {
589599
content += fmt.Sprintf("QUANTIZATION %s\n", mf.quantization)
590600
}
591601

592-
// Add multi-value commands
602+
// add multi-value commands
593603
content += "\n# Config files (Generated from the files in the model directory)\n"
594604
content += "# Supported file types: " + strings.Join(configFilePatterns, ", ") + "\n"
595605
configs := mf.GetConfigs()
@@ -614,6 +624,6 @@ func (mf *modelfile) SaveToFile(path string) error {
614624
content += fmt.Sprintf("MODEL %s\n", model)
615625
}
616626

617-
// Write to file
627+
// write to file
618628
return os.WriteFile(path, []byte(content), 0644)
619629
}

pkg/modelfile/modelfile_test.go

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ name bar
252252
func TestAutoModelfile(t *testing.T) {
253253
testCases := []struct {
254254
name string
255-
files map[string]string // map of relative path to file content
255+
files map[string]string
256256
config *ModelfileGenConfig
257257
expectErr error
258258
validate func(*testing.T, Modelfile)
@@ -328,7 +328,54 @@ func TestAutoModelfile(t *testing.T) {
328328
Name: "test-model",
329329
IgnoreUnrecognized: false,
330330
},
331-
expectErr: errors.New("unknown file type: unknown.xyz"),
331+
expectErr: errors.New("unknown file type: unknown.xyz - use --ignore-unrecognized to ignore, and edit the Modelfile manually"),
332+
},
333+
{
334+
name: "empty directory",
335+
files: map[string]string{},
336+
config: &ModelfileGenConfig{Name: "empty-model"},
337+
expectErr: errors.New("no recognized model files found in directory - you may need to edit the Modelfile manually"),
338+
},
339+
{
340+
name: "invalid config json",
341+
files: map[string]string{"config.json": `{"model_type": "llama", invalid json`},
342+
config: &ModelfileGenConfig{Name: "invalid-config"},
343+
expectErr: errors.New("no recognized model files found in directory - you may need to edit the Modelfile manually"),
344+
},
345+
{
346+
name: "nested directories",
347+
files: map[string]string{"config.json": `{"model_type": "llama"}`, "models/shard1.safetensors": "dummy content", "models/shard2.safetensors": "dummy content", "configs/main.json": "dummy content", "src/train.py": "print('hello')"},
348+
config: &ModelfileGenConfig{Name: "nested-model", IgnoreUnrecognized: true},
349+
expectErr: nil,
350+
validate: func(t *testing.T, mf Modelfile) {
351+
assert := assert.New(t)
352+
models := mf.GetModels()
353+
sort.Strings(models)
354+
assert.Equal([]string{"models/shard1.safetensors", "models/shard2.safetensors"}, models)
355+
356+
codes := mf.GetCodes()
357+
sort.Strings(codes)
358+
assert.Equal([]string{"src/train.py"}, codes)
359+
},
360+
},
361+
{
362+
name: "special characters in paths",
363+
files: map[string]string{
364+
"config.json": `{"model_type": "llama"}`,
365+
"model with spaces.safetensors": "dummy content",
366+
"特殊字符.bin": "dummy content",
367+
"src/test-file.py": "print('hello')",
368+
},
369+
config: &ModelfileGenConfig{
370+
Name: "special-chars",
371+
IgnoreUnrecognized: true,
372+
},
373+
validate: func(t *testing.T, mf Modelfile) {
374+
assert := assert.New(t)
375+
models := mf.GetModels()
376+
sort.Strings(models)
377+
assert.Equal([]string{"model with spaces.safetensors", "特殊字符.bin"}, models)
378+
},
332379
},
333380
}
334381

0 commit comments

Comments
 (0)