Skip to content

Commit bd2dd97

Browse files
qiumuyangqmy
andauthored
feat: support modelfile generate with exclude patterns (modelpack#332)
* feat: enhance modelfile generate with exclude patterns Signed-off-by: qmy <qiumuyang.qmy@antgroup.com> * docs(getting-started): update modelfile generate usage Signed-off-by: qmy <qiumuyang.qmy@antgroup.com> --------- Signed-off-by: qmy <qiumuyang.qmy@antgroup.com> Co-authored-by: qmy <qiumuyang.qmy@antgroup.com>
1 parent e1d2c95 commit bd2dd97

File tree

6 files changed

+233
-10
lines changed

6 files changed

+233
-10
lines changed

cmd/modelfile/generate.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ func init() {
6464
flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory")
6565
flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace")
6666
flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile")
67+
flags.StringArrayVar(&generateConfig.ExcludePatterns, "exclude", []string{}, "specify glob patterns to exclude files/directories (e.g. *.log, checkpoints/*)")
6768

6869
// Mark the ignore-unrecognized-file-types flag as deprecated and hidden
6970
flags.MarkDeprecated("ignore-unrecognized-file-types", "this flag will be removed in the next release")

docs/getting-started.md

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,14 @@ directory(workspace).
3131
$ modctl modelfile generate .
3232
```
3333

34+
If you want to exclude specific files from the model artifact (such as checkpoint directories),
35+
you can use the `--exclude` option to specify the file path glob pattern.
36+
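Note that only basic glob syntax (`*`, `?`, `[]`) is supported; advanced features like `**` for recursive matching are not. Each pattern is matched against the path relative to the workspace, and `*` does not cross directory separators, so `*.log` excludes only `.log` files at the top level of the workspace.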
37+
38+
```shell
39+
$ modctl modelfile generate . --exclude 'checkpoint-*'
40+
```
41+
3442
### Build
3543

3644
To build the model artifact, you need to prepare a Modelfile that describes the expected layout of the model artifact in your model repo.

pkg/config/modelfile/modelfile.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ type GenerateConfig struct {
3939
ParamSize string
4040
Precision string
4141
Quantization string
42+
ExcludePatterns []string
4243
}
4344

4445
func NewGenerateConfig() *GenerateConfig {
@@ -55,6 +56,7 @@ func NewGenerateConfig() *GenerateConfig {
5556
ParamSize: "",
5657
Precision: "",
5758
Quantization: "",
59+
ExcludePatterns: []string{},
5860
}
5961
}
6062

pkg/modelfile/modelfile.go

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ func NewModelfileByWorkspace(workspace string, config *configmodelfile.GenerateC
209209
return nil, err
210210
}
211211

212-
if err := mf.generateByWorkspace(); err != nil {
212+
if err := mf.generateByWorkspace(config); err != nil {
213213
return nil, err
214214
}
215215

@@ -252,11 +252,17 @@ func (mf *modelfile) validateWorkspace() error {
252252
}
253253

254254
// generateByWorkspace generates the modelfile by the workspace's files.
255-
func (mf *modelfile) generateByWorkspace() error {
255+
func (mf *modelfile) generateByWorkspace(config *configmodelfile.GenerateConfig) error {
256256
// Initialize counters for workspace limits validation
257257
var fileCount int
258258
var totalSize int64
259259

260+
// Initialize exclude patterns
261+
filter, err := NewPathFilter(config.ExcludePatterns...)
262+
if err != nil {
263+
return err
264+
}
265+
260266
// Walk the path and get the files.
261267
if err := filepath.Walk(mf.workspace, func(path string, info os.FileInfo, err error) error {
262268
if err != nil {
@@ -265,8 +271,14 @@ func (mf *modelfile) generateByWorkspace() error {
265271

266272
filename := info.Name()
267273

268-
// Skip hidden and skippable files/directories.
269-
if isSkippable(filename) {
274+
// Get relative path from the base directory.
275+
relPath, err := filepath.Rel(mf.workspace, path)
276+
if err != nil {
277+
return err
278+
}
279+
280+
// Skip hidden, skippable, and excluded files/directories.
281+
if isSkippable(filename) || filter.Match(relPath) {
270282
if info.IsDir() {
271283
return filepath.SkipDir
272284
}
@@ -298,12 +310,6 @@ func (mf *modelfile) generateByWorkspace() error {
298310
return fmt.Errorf("workspace exceeds maximum total size limit of %d bytes (%s)", MaxTotalWorkspaceSize, formatBytes(MaxTotalWorkspaceSize))
299311
}
300312

301-
// Get relative path from the base directory.
302-
relPath, err := filepath.Rel(mf.workspace, path)
303-
if err != nil {
304-
return err
305-
}
306-
307313
switch {
308314
case IsFileType(filename, ConfigFilePatterns):
309315
mf.config.Add(relPath)

pkg/modelfile/path_filter.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package modelfile
2+
3+
import (
4+
"fmt"
5+
"path/filepath"
6+
"strings"
7+
)
8+
9+
// PathFilter reports whether a path is excluded by any of a set of glob
// patterns. Patterns follow filepath.Match syntax, so "*" and "?" never cross
// a path separator and "**" carries no recursive meaning.
type PathFilter struct {
	// patterns holds the validated patterns with trailing separators removed.
	patterns []string
}

// NewPathFilter builds a PathFilter from the given glob patterns.
//
// Trailing path separators are stripped from every pattern, so "checkpoint/"
// is stored as "checkpoint". It returns an error naming the offending pattern
// (and wrapping the underlying filepath error) if any pattern is malformed.
func NewPathFilter(patterns ...string) (*PathFilter, error) {
	cleaned := make([]string, 0, len(patterns))
	for _, p := range patterns {
		// since filepath.Walk never returns a path with trailing separator,
		// we need to remove the separator from patterns
		trimmed := strings.TrimRight(p, string(filepath.Separator))

		// Validate the pattern in the form it will actually be matched with.
		// Trimming can turn a well-formed pattern into a malformed one: on
		// Unix `foo\/` becomes `foo\`, a dangling escape. Validating only the
		// untrimmed input would let such a pattern through and it would then
		// silently never match.
		if _, err := filepath.Match(trimmed, ""); err != nil {
			return nil, fmt.Errorf("invalid exclude pattern: %q: %w", p, err)
		}
		cleaned = append(cleaned, trimmed)
	}
	return &PathFilter{patterns: cleaned}, nil
}

// Match reports whether path matches at least one of the filter's patterns.
// A filter without patterns matches nothing.
func (pf *PathFilter) Match(path string) bool {
	for _, pattern := range pf.patterns {
		matched, err := filepath.Match(pattern, path)
		if err != nil {
			// The only possible returned error is ErrBadPattern, which
			// NewPathFilter rejects. Skip the entry instead of returning so a
			// single bad pattern can never hide the patterns that follow it.
			continue
		}
		if matched {
			return true
		}
	}
	return false
}

pkg/modelfile/path_filter_test.go

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
package modelfile
2+
3+
import (
4+
"fmt"
5+
"testing"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestNewPathFilter(t *testing.T) {
12+
testcases := []struct {
13+
name string
14+
input []string
15+
expected []string
16+
expectError bool
17+
errorMsg string
18+
}{
19+
{
20+
name: "normal patterns",
21+
input: []string{"*.log", "checkpoint*/"},
22+
expected: []string{"*.log", "checkpoint*"},
23+
},
24+
{
25+
name: "invalid pattern",
26+
input: []string{"*.log", "[invalid"},
27+
expectError: true,
28+
errorMsg: `invalid exclude pattern: "[invalid"`,
29+
},
30+
}
31+
32+
for _, tc := range testcases {
33+
t.Run(tc.name, func(t *testing.T) {
34+
filter, err := NewPathFilter(tc.input...)
35+
36+
if tc.expectError {
37+
require.Error(t, err, "Expected an error for input: %q", tc.input)
38+
assert.Contains(t, err.Error(), tc.errorMsg)
39+
assert.Nil(t, filter)
40+
return
41+
}
42+
43+
require.NoError(t, err, "Did not expect an error for input: %q", tc.input)
44+
require.NotNil(t, filter)
45+
assert.Equal(t, tc.expected, filter.patterns)
46+
})
47+
}
48+
}
49+
50+
func TestPathFilter_Matches(t *testing.T) {
51+
testcases := []struct {
52+
filterName string
53+
patterns []string
54+
tests []struct {
55+
desc string
56+
path string
57+
expected bool
58+
}
59+
}{
60+
{
61+
filterName: "Empty_Filter",
62+
patterns: []string{},
63+
tests: []struct {
64+
desc string
65+
path string
66+
expected bool
67+
}{
68+
{"any file", "any/path/file.txt", false},
69+
{"root file", "main.go", false},
70+
{"empty path", "", false},
71+
},
72+
},
73+
{
74+
filterName: "Single_Asterisk_Filter",
75+
patterns: []string{"*.log"},
76+
tests: []struct {
77+
desc string
78+
path string
79+
expected bool
80+
}{
81+
{"matches a simple log file", "debug.log", true},
82+
{"matches a hidden log file", ".config.log", true},
83+
{"does not match if not at end", "debug.log.old", false},
84+
{"does not match different extension", "main.go", false},
85+
{"does not cross path separator", "logs/debug.log", false},
86+
},
87+
},
88+
{
89+
filterName: "Directory_Wildcard_Filter",
90+
patterns: []string{"build/*"},
91+
tests: []struct {
92+
desc string
93+
path string
94+
expected bool
95+
}{
96+
{"matches file directly inside", "build/app", true},
97+
{"matches hidden file inside", "build/.config", true},
98+
{"does not match the directory itself", "build", false},
99+
{"does not match nested files", "build/assets/icon.png", false},
100+
},
101+
},
102+
{
103+
// Since filepath.Match does not support **, the behavior is the same as Directory_Wildcard_Filter
104+
filterName: "Directory_Double_Asterisk_Filter",
105+
patterns: []string{"build/**"},
106+
tests: []struct {
107+
desc string
108+
path string
109+
expected bool
110+
}{
111+
{"matches file directly inside", "build/app", true},
112+
{"matches hidden file inside", "build/.config", true},
113+
{"does not match the directory itself", "build", false},
114+
{"does not match nested files", "build/assets/icon.png", false},
115+
},
116+
},
117+
{
118+
filterName: "Directory_Filter",
119+
patterns: []string{"checkpoint/"},
120+
tests: []struct {
121+
desc string
122+
path string
123+
expected bool
124+
}{
125+
{"matches the directory itself", "checkpoint", true},
126+
{"match file inside", "checkpoint/file.py", false},
127+
},
128+
},
129+
{
130+
filterName: "Complex_Filter_With_Multiple_Patterns",
131+
patterns: []string{"*.tmp", ".git*", "build/"},
132+
tests: []struct {
133+
desc string
134+
path string
135+
expected bool
136+
}{
137+
{"matches a .tmp file", "temp.tmp", true},
138+
{"matches .git directory", ".git", true},
139+
{"matches .gitignore file", ".gitignore", true},
140+
{"matches build directory exactly", "build", true},
141+
{"does not cross separator", "data/cache.tmp", false},
142+
{"does not match file inside build/", "build/app.js", false},
143+
{"does not match src file", "src/main.go", false},
144+
},
145+
},
146+
}
147+
148+
for _, tc := range testcases {
149+
t.Run(tc.filterName, func(t *testing.T) {
150+
filter, err := NewPathFilter(tc.patterns...)
151+
require.NoError(t, err, "Filter creation with patterns %q failed", tc.patterns)
152+
require.NotNil(t, filter)
153+
154+
for _, test := range tc.tests {
155+
t.Run(test.desc, func(t *testing.T) {
156+
result := filter.Match(test.path)
157+
assert.Equal(t, test.expected, result, fmt.Sprintf("Path: %q", test.path))
158+
})
159+
}
160+
})
161+
}
162+
}

0 commit comments

Comments
 (0)