Skip to content

Commit b5bcb07

Browse files
authored
test: add TestAutoModelfile (#84)
* test: add TestAutoModelfile Signed-off-by: Zhao Chen <[email protected]> * feat: add cleanModelName function for model name sanitization Adds a new cleanModelName function to sanitize model names by: - Removing trailing slashes - Replacing invalid characters with underscores - Trimming leading/trailing underscores - Providing a default name for empty inputs Updates AutoModelfile to use the new cleanModelName function when generating model names Signed-off-by: Zhao Chen <[email protected]> * refactor: change Paramsize type from string to uint64 This change modifies the Paramsize field in ModelfileGenConfig from a string to a uint64, updating related code in generate.go, modelfile.go, and modelfile_test.go to support numeric parameter size representation. The changes include: - Updated CLI flag type from string to uint64 - Modified ModelfileGenConfig struct - Updated test cases to use numeric parameter sizes - Converted Paramsize to string representation when needed Signed-off-by: Zhao Chen <[email protected]> * refactor: improve modelfile generation error handling and validation - Add more descriptive error messages for modelfile generation - Enhance validation checks in AutoModelfile method - Improve handling of unrecognized files and empty directories - Add nil config check and more informative error messages - Update comments for better code readability Signed-off-by: Zhao Chen <[email protected]> * refactor: change Paramsize back to string type - Reverted Paramsize type from uint64 to string - Updated CLI flag to accept string parameter sizes like "7B", "13B" - Modified ModelfileGenConfig and related code to support string-based parameter sizes - Updated test cases to use string parameter size representation Signed-off-by: Zhao Chen <[email protected]> * refactor: improve modelfile generation path handling and skippable file detection - Convert model and modelfile paths to absolute paths - Remove redundant absolute path conversion - Add special handling for current and parent directory in skippable file detection - Simplify error handling for path-related operations Signed-off-by: Zhao Chen <[email protected]> --------- Signed-off-by: Zhao Chen <[email protected]>
1 parent 315dec6 commit b5bcb07

File tree

4 files changed

+240
-20
lines changed

4 files changed

+240
-20
lines changed

cmd/generate.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ func init() {
5353
flags.StringVar(&genConfig.Arch, "arch", "", "Model architecture (string), such as transformer, cnn, rnn, etc.")
5454
flags.StringVar(&genConfig.Family, "family", "", "Model family (string), such as llama3, gpt2, qwen2, etc.")
5555
flags.StringVar(&genConfig.Format, "format", "", "Model format (string), such as safetensors, pytorch, onnx, etc.")
56-
flags.StringVar(&genConfig.Paramsize, "paramsize", "", "Number of parameters in the model (string).")
56+
flags.StringVar(&genConfig.Paramsize, "paramsize", "", "Number of parameters in the model (string), such as 7B, 13B, 72B, etc.")
5757
flags.StringVar(&genConfig.Precision, "precision", "", "Model precision (string), such as bf16, fp16, int8, etc.")
5858
flags.StringVar(&genConfig.Quantization, "quantization", "", "Model quantization (string), such as awq, gptq, etc.")
5959

@@ -63,6 +63,7 @@ func init() {
6363
}
6464

6565
func runGenModelfile(ctx context.Context, modelPath string) error {
66+
6667
if !strings.HasSuffix(modelPath, "/") {
6768
modelPath += "/"
6869
}

pkg/modelfile/generate.go

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -69,29 +69,38 @@ func RunGenModelfile(ctx context.Context, modelPath string, genConfig *Modelfile
6969
if err := genConfig.Validate(); err != nil {
7070
return fmt.Errorf("failed to validate modelfile gen config: %w", err)
7171
}
72-
genPath := filepath.Join(genConfig.OutputPath, "Modelfile")
7372

74-
// Check if file exists
73+
// Convert modelPath to absolute path
74+
realModelPath, err := filepath.Abs(modelPath)
75+
if err != nil {
76+
return fmt.Errorf("failed to get absolute path for model: %w", err)
77+
}
78+
79+
// check if file exists
80+
genPath := filepath.Join(genConfig.OutputPath, "Modelfile")
81+
genPath, err = filepath.Abs(genPath)
82+
if err != nil {
83+
return fmt.Errorf("failed to get absolute path for modelfile: %w", err)
84+
}
7585
if _, err := os.Stat(genPath); err == nil {
7686
if !genConfig.Overwrite {
77-
absPath, _ := filepath.Abs(genPath)
78-
return fmt.Errorf("Modelfile already exists at %s - use --overwrite to overwrite", absPath)
87+
return fmt.Errorf("Modelfile already exists at %s - use --overwrite to overwrite", genPath)
7988
}
8089
}
8190

82-
fmt.Printf("Generating modelfile for %s\n", modelPath)
91+
fmt.Printf("Generating modelfile for %s\n", realModelPath)
8392

84-
modelfile, err := AutoModelfile(modelPath, genConfig)
93+
modelfile, err := AutoModelfile(realModelPath, genConfig)
8594
if err != nil {
8695
return fmt.Errorf("failed to generate modelfile: %w", err)
8796
}
8897

89-
// Save the modelfile to the output path
98+
// save the modelfile to the output path
9099
if err := modelfile.SaveToFile(genPath); err != nil {
91100
return fmt.Errorf("failed to save modelfile: %w", err)
92101
}
93102

94-
// Read modelfile from disk and print it
103+
// read modelfile from disk and print it
95104
content, err := os.ReadFile(genPath)
96105
if err != nil {
97106
return fmt.Errorf("failed to read modelfile: %w", err)

pkg/modelfile/modelfile.go

Lines changed: 52 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,11 @@ func isFileType(filename string, patterns []string) bool {
183183

184184
// isSkippable checks if the filename matches any of the skip patterns
185185
func isSkippable(filename string) bool {
186+
// Special handling for current and parent directory
187+
if filename == "." || filename == ".." {
188+
return false
189+
}
190+
186191
// Convert filename to lowercase for case-insensitive comparison
187192
lowerFilename := strings.ToLower(filename)
188193
for _, pattern := range skipPatterns {
@@ -195,6 +200,32 @@ func isSkippable(filename string) bool {
195200
return false
196201
}
197202

203+
// cleanModelName sanitizes a string to create a valid model name
204+
func cleanModelName(name string) string {
205+
// Remove any trailing slashes first
206+
name = strings.TrimRight(name, "/\\")
207+
208+
// Replace invalid characters with underscores
209+
name = strings.Map(func(r rune) rune {
210+
// Allow alphanumeric characters
211+
if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') {
212+
return r
213+
}
214+
// Replace everything else with underscore
215+
return '_'
216+
}, name)
217+
218+
// Remove leading/trailing underscores
219+
name = strings.Trim(name, "_")
220+
221+
// If name is empty after cleaning, return a default
222+
if name == "" {
223+
return "unnamed_model"
224+
}
225+
226+
return name
227+
}
228+
198229
// NewModelfile creates a new modelfile by the path of the modelfile.
199230
// It parses the modelfile and returns the modelfile interface.
200231
func NewModelfile(path string) (Modelfile, error) {
@@ -284,16 +315,21 @@ func overwriteModelConfig(mf *modelfile, config *ModelfileGenConfig) {
284315
// AutoModelfile creates a new modelfile by the path of the model directory.
285316
// It walks the directory and returns the auto-generated modelfile interface.
286317
func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
318+
// check if the config is nil
319+
if config == nil {
320+
return nil, fmt.Errorf("config cannot be nil")
321+
}
322+
287323
mf := &modelfile{
288324
config: hashset.New(),
289325
model: hashset.New(),
290326
code: hashset.New(),
291327
dataset: hashset.New(),
292328
}
293329

294-
// Use directory name as model name if config.name is empty
330+
// use directory name as model name if config.name is empty
295331
if config.Name == "" {
296-
mf.name = filepath.Base(path)
332+
mf.name = cleanModelName(filepath.Base(path))
297333
} else {
298334
mf.name = config.Name
299335
}
@@ -306,7 +342,7 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
306342

307343
filename := info.Name()
308344

309-
// Skip hidden and skippable files/directories
345+
// skip hidden and skippable files/directories
310346
if isSkippable(filename) {
311347
if info.IsDir() {
312348
return filepath.SkipDir
@@ -318,7 +354,7 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
318354
return nil
319355
}
320356

321-
// Get relative path from the base directory
357+
// get relative path from the base directory
322358
relPath, err := filepath.Rel(path, fullPath)
323359
if err != nil {
324360
return err
@@ -332,11 +368,11 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
332368
case isFileType(filename, codeFilePatterns):
333369
mf.code.Add(relPath)
334370
default:
335-
// Skip unrecognized files if IgnoreUnrecognized is true
371+
// skip unrecognized files if IgnoreUnrecognized is true
336372
if config.IgnoreUnrecognized {
337373
return nil
338374
}
339-
return fmt.Errorf("unknown file type: %s", filename)
375+
return fmt.Errorf("unknown file type: %s - use --ignore-unrecognized to ignore, and edit the Modelfile manually", filename)
340376
}
341377

342378
return nil
@@ -346,12 +382,17 @@ func AutoModelfile(path string, config *ModelfileGenConfig) (Modelfile, error) {
346382
return nil, err
347383
}
348384

349-
// Get the model config from the config.json file
385+
// check if model files are found
386+
if mf.model.Size() == 0 {
387+
return nil, fmt.Errorf("no recognized model files found in directory - you may need to edit the Modelfile manually")
388+
}
389+
390+
// get the model config from the config.json file
350391
if err := parseModelConfig(path, mf); err != nil {
351392
return nil, err
352393
}
353394

354-
// Overwrite the modelfile configurations with the provided config values
395+
// overwrite the modelfile configurations with the provided config values
355396
overwriteModelConfig(mf, config)
356397

357398
return mf, nil
@@ -533,7 +574,7 @@ func (mf *modelfile) SaveToFile(path string) error {
533574
// generate time in the first line
534575
content += fmt.Sprintf("# Generated at %s\n", time.Now().Format(time.RFC3339))
535576

536-
// Add single value commands
577+
// add single value commands
537578
if mf.name != "" {
538579
content += "\n# Model name\n"
539580
content += fmt.Sprintf("NAME %s\n", mf.name)
@@ -563,7 +604,7 @@ func (mf *modelfile) SaveToFile(path string) error {
563604
content += fmt.Sprintf("QUANTIZATION %s\n", mf.quantization)
564605
}
565606

566-
// Add multi-value commands
607+
// add multi-value commands
567608
content += "\n# Config files (Generated from the files in the model directory)\n"
568609
content += "# Supported file types: " + strings.Join(configFilePatterns, ", ") + "\n"
569610
configs := mf.GetConfigs()
@@ -588,6 +629,6 @@ func (mf *modelfile) SaveToFile(path string) error {
588629
content += fmt.Sprintf("MODEL %s\n", model)
589630
}
590631

591-
// Write to file
632+
// write to file
592633
return os.WriteFile(path, []byte(content), 0644)
593634
}

pkg/modelfile/modelfile_test.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package modelfile
1919
import (
2020
"errors"
2121
"os"
22+
"path/filepath"
2223
"sort"
2324
"testing"
2425

@@ -247,3 +248,171 @@ name bar
247248
os.Remove(tmpfile.Name())
248249
}
249250
}
251+
252+
func TestAutoModelfile(t *testing.T) {
253+
testCases := []struct {
254+
name string
255+
files map[string]string
256+
config *ModelfileGenConfig
257+
expectErr error
258+
validate func(*testing.T, Modelfile)
259+
}{
260+
{
261+
name: "basic model directory",
262+
files: map[string]string{
263+
"config.json": `{"model_type": "llama", "transformers_version": "1.0", "torch_dtype": "float16"}`,
264+
"generation_config.json": `{}`,
265+
"tokenizer.model": "dummy content",
266+
"pytorch_model.bin": "dummy content",
267+
"model.safetensors": "dummy content",
268+
"train.py": "print('hello')",
269+
"README.md": "# Model Documentation",
270+
".git/config": "should be ignored",
271+
"__pycache__/cache.pyc": "should be ignored",
272+
},
273+
config: &ModelfileGenConfig{
274+
Name: "llama2-7b",
275+
Format: "safetensors",
276+
Paramsize: "7B",
277+
Quantization: "q4_k_m",
278+
IgnoreUnrecognized: true,
279+
},
280+
expectErr: nil,
281+
validate: func(t *testing.T, mf Modelfile) {
282+
assert := assert.New(t)
283+
284+
// Check configs (sorted)
285+
expectedConfigs := []string{
286+
"README.md",
287+
"config.json",
288+
"generation_config.json",
289+
"tokenizer.model",
290+
}
291+
configs := mf.GetConfigs()
292+
sort.Strings(configs)
293+
assert.Equal(expectedConfigs, configs)
294+
295+
// Check models (sorted)
296+
expectedModels := []string{
297+
"model.safetensors",
298+
"pytorch_model.bin",
299+
}
300+
models := mf.GetModels()
301+
sort.Strings(models)
302+
assert.Equal(expectedModels, models)
303+
304+
// Check codes (sorted)
305+
expectedCodes := []string{
306+
"train.py",
307+
}
308+
codes := mf.GetCodes()
309+
sort.Strings(codes)
310+
assert.Equal(expectedCodes, codes)
311+
312+
// Check other fields
313+
assert.Equal("llama2-7b", mf.GetName())
314+
assert.Equal("transformer", mf.GetArch()) // from config.json
315+
assert.Equal("llama", mf.GetFamily()) // from config.json
316+
assert.Equal("safetensors", mf.GetFormat())
317+
assert.Equal("7B", mf.GetParamsize())
318+
assert.Equal("float16", mf.GetPrecision()) // from config.json
319+
assert.Equal("q4_k_m", mf.GetQuantization())
320+
},
321+
},
322+
{
323+
name: "unrecognized files without ignore flag",
324+
files: map[string]string{
325+
"unknown.xyz": "some content",
326+
},
327+
config: &ModelfileGenConfig{
328+
Name: "test-model",
329+
IgnoreUnrecognized: false,
330+
},
331+
expectErr: errors.New("unknown file type: unknown.xyz - use --ignore-unrecognized to ignore, and edit the Modelfile manually"),
332+
},
333+
{
334+
name: "empty directory",
335+
files: map[string]string{},
336+
config: &ModelfileGenConfig{Name: "empty-model"},
337+
expectErr: errors.New("no recognized model files found in directory - you may need to edit the Modelfile manually"),
338+
},
339+
{
340+
name: "invalid config json",
341+
files: map[string]string{"config.json": `{"model_type": "llama", invalid json`},
342+
config: &ModelfileGenConfig{Name: "invalid-config"},
343+
expectErr: errors.New("no recognized model files found in directory - you may need to edit the Modelfile manually"),
344+
},
345+
{
346+
name: "nested directories",
347+
files: map[string]string{"config.json": `{"model_type": "llama"}`, "models/shard1.safetensors": "dummy content", "models/shard2.safetensors": "dummy content", "configs/main.json": "dummy content", "src/train.py": "print('hello')"},
348+
config: &ModelfileGenConfig{Name: "nested-model", IgnoreUnrecognized: true},
349+
expectErr: nil,
350+
validate: func(t *testing.T, mf Modelfile) {
351+
assert := assert.New(t)
352+
models := mf.GetModels()
353+
sort.Strings(models)
354+
assert.Equal([]string{"models/shard1.safetensors", "models/shard2.safetensors"}, models)
355+
356+
codes := mf.GetCodes()
357+
sort.Strings(codes)
358+
assert.Equal([]string{"src/train.py"}, codes)
359+
},
360+
},
361+
{
362+
name: "special characters in paths",
363+
files: map[string]string{
364+
"config.json": `{"model_type": "llama"}`,
365+
"model with spaces.safetensors": "dummy content",
366+
"特殊字符.bin": "dummy content",
367+
"src/test-file.py": "print('hello')",
368+
},
369+
config: &ModelfileGenConfig{
370+
Name: "special-chars",
371+
IgnoreUnrecognized: true,
372+
},
373+
validate: func(t *testing.T, mf Modelfile) {
374+
assert := assert.New(t)
375+
models := mf.GetModels()
376+
sort.Strings(models)
377+
assert.Equal([]string{"model with spaces.safetensors", "特殊字符.bin"}, models)
378+
},
379+
},
380+
}
381+
382+
for _, tc := range testCases {
383+
t.Run(tc.name, func(t *testing.T) {
384+
// Create temporary directory
385+
tmpDir, err := os.MkdirTemp("", "modelfile_test_*")
386+
assert.NoError(t, err)
387+
defer os.RemoveAll(tmpDir)
388+
389+
// Create test files
390+
for path, content := range tc.files {
391+
fullPath := filepath.Join(tmpDir, path)
392+
393+
// Create parent directories if needed
394+
err := os.MkdirAll(filepath.Dir(fullPath), 0755)
395+
assert.NoError(t, err)
396+
397+
err = os.WriteFile(fullPath, []byte(content), 0644)
398+
assert.NoError(t, err)
399+
}
400+
401+
// Run AutoModelfile
402+
mf, err := AutoModelfile(tmpDir, tc.config)
403+
404+
if tc.expectErr != nil {
405+
assert.Error(t, err)
406+
assert.Equal(t, tc.expectErr.Error(), err.Error())
407+
assert.Nil(t, mf)
408+
return
409+
}
410+
411+
assert.NoError(t, err)
412+
assert.NotNil(t, mf)
413+
414+
// Run validation
415+
tc.validate(t, mf)
416+
})
417+
}
418+
}

0 commit comments

Comments
 (0)