Skip to content

Commit eac4ac9

Browse files
Add support for modelFile generate via model providers model_url (#329)
* Add support for HF model_url Signed-off-by: Avinash Singh <[email protected]> * optimise as per gemini's review Signed-off-by: Avinash Singh <[email protected]> * add modelprovider interface and providers Signed-off-by: Avinash Singh <[email protected]> * optimise code Signed-off-by: Avinash Singh <[email protected]> * add optional param for download-dir Signed-off-by: Avinash Singh <[email protected]> --------- Signed-off-by: Avinash Singh <[email protected]>
1 parent a44dae3 commit eac4ac9

File tree

11 files changed

+1302
-5
lines changed

11 files changed

+1302
-5
lines changed

cmd/modelfile/generate.go

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,65 @@ import (
2626

2727
configmodelfile "github.com/modelpack/modctl/pkg/config/modelfile"
2828
"github.com/modelpack/modctl/pkg/modelfile"
29+
"github.com/modelpack/modctl/pkg/modelprovider"
2930
)
3031

3132
var generateConfig = configmodelfile.NewGenerateConfig()
3233

3334
// generateCmd represents the modelfile tools command for generating modelfile.
3435
var generateCmd = &cobra.Command{
35-
Use: "generate [flags] <path>",
36-
Short: "A command line tool for generating modelfile in the workspace, the workspace must be a directory including model files and model configuration files",
37-
Args: cobra.ExactArgs(1),
36+
Use: "generate [flags] [<path>]",
37+
Short: "Generate a modelfile from a local workspace or remote model provider",
38+
Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from a supported provider.
39+
40+
The workspace must be a directory including model files and model configuration files.
41+
Alternatively, use --model-url to download a model from a supported provider (e.g., HuggingFace, ModelScope).
42+
43+
For short-form URLs (owner/repo), you must explicitly specify the provider using --provider flag.
44+
Full URLs with domain names will auto-detect the provider.`,
45+
Example: ` # Generate from local directory
46+
modctl modelfile generate ./my-model-dir
47+
48+
# Generate from Hugging Face using full URL (auto-detects provider)
49+
modctl modelfile generate --model-url https://huggingface.co/meta-llama/Llama-2-7b-hf
50+
51+
# Generate from Hugging Face using short form (requires --provider)
52+
modctl modelfile generate --model-url meta-llama/Llama-2-7b-hf --provider huggingface
53+
54+
# Generate from ModelScope using full URL (auto-detects provider)
55+
modctl modelfile generate --model-url https://modelscope.cn/models/qwen/Qwen-7B
56+
57+
# Generate from ModelScope using short form (requires --provider)
58+
modctl modelfile generate --model-url qwen/Qwen-7B --provider modelscope
59+
60+
# Generate with custom download directory
61+
modctl modelfile generate --model-url meta-llama/Llama-2-7b-hf --provider huggingface --download-dir $HOME/models
62+
63+
# Generate with custom output path
64+
modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml
65+
66+
# Generate with metadata overrides
67+
modctl modelfile generate ./my-model-dir --name my-custom-model --family llama3`,
68+
Args: cobra.MaximumNArgs(1),
3869
DisableAutoGenTag: true,
3970
SilenceUsage: true,
4071
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
4172
RunE: func(cmd *cobra.Command, args []string) error {
42-
if err := generateConfig.Convert(args[0]); err != nil {
73+
// If model-url is provided, path is optional
74+
workspace := "."
75+
if len(args) > 0 {
76+
workspace = args[0]
77+
}
78+
79+
// Validate that either path or model-url is provided
80+
if generateConfig.ModelURL != "" && len(args) > 0 {
81+
return fmt.Errorf("the <path> argument and the --model-url flag are mutually exclusive")
82+
}
83+
if generateConfig.ModelURL == "" && len(args) == 0 {
84+
return fmt.Errorf("either a <path> argument or the --model-url flag must be provided")
85+
}
86+
87+
if err := generateConfig.Convert(workspace); err != nil {
4388
return err
4489
}
4590

@@ -64,6 +109,9 @@ func init() {
64109
flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory")
65110
flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace")
66111
flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile")
112+
flags.StringVar(&generateConfig.ModelURL, "model-url", "", "download model from a supported provider (full URL or short-form with --provider)")
113+
flags.StringVarP(&generateConfig.Provider, "provider", "p", "", "explicitly specify the provider for short-form URLs (huggingface, modelscope)")
114+
flags.StringVar(&generateConfig.DownloadDir, "download-dir", "", "custom directory for downloading models (default: system temp directory)")
67115
flags.StringArrayVar(&generateConfig.ExcludePatterns, "exclude", []string{}, "specify glob patterns to exclude files/directories (e.g. *.log, checkpoints/*)")
68116

69117
// Mark the ignore-unrecognized-file-types flag as deprecated and hidden
@@ -76,7 +124,65 @@ func init() {
76124
}
77125

78126
// runGenerate runs the generate modelfile.
79-
func runGenerate(_ context.Context) error {
127+
func runGenerate(ctx context.Context) error {
128+
// If model URL is provided, download the model first
129+
if generateConfig.ModelURL != "" {
130+
fmt.Printf("Model URL provided: %s\n", generateConfig.ModelURL)
131+
132+
// Get the appropriate provider for this URL
133+
registry := modelprovider.GetRegistry()
134+
provider, err := registry.SelectProvider(generateConfig.ModelURL, generateConfig.Provider)
135+
if err != nil {
136+
return fmt.Errorf("failed to select provider: %w", err)
137+
}
138+
139+
fmt.Printf("Using provider: %s\n", provider.Name())
140+
141+
// Check if user is authenticated with the provider
142+
if err := provider.CheckAuth(); err != nil {
143+
return fmt.Errorf("%s authentication check failed: %w", provider.Name(), err)
144+
}
145+
146+
// Determine the download directory
147+
var downloadDir string
148+
var cleanupDir bool
149+
150+
if generateConfig.DownloadDir != "" {
151+
// Use user-specified directory
152+
downloadDir = generateConfig.DownloadDir
153+
cleanupDir = false
154+
155+
// Create the directory if it doesn't exist
156+
if err := os.MkdirAll(downloadDir, 0755); err != nil {
157+
return fmt.Errorf("failed to create download directory: %w", err)
158+
}
159+
fmt.Printf("Using custom download directory: %s\n", downloadDir)
160+
} else {
161+
// Create a temporary directory for downloading the model
162+
tmpDir, err := os.MkdirTemp("", "modctl-model-downloads-*")
163+
if err != nil {
164+
return fmt.Errorf("failed to create temporary directory: %w", err)
165+
}
166+
downloadDir = tmpDir
167+
cleanupDir = true
168+
}
169+
170+
// Clean up the directory only if it was a temporary directory
171+
if cleanupDir {
172+
defer os.RemoveAll(downloadDir)
173+
}
174+
175+
// Download the model
176+
downloadPath, err := provider.DownloadModel(ctx, generateConfig.ModelURL, downloadDir)
177+
if err != nil {
178+
return fmt.Errorf("failed to download model from %s: %w", provider.Name(), err)
179+
}
180+
181+
// Update workspace to the downloaded model path
182+
generateConfig.Workspace = downloadPath
183+
fmt.Printf("Using downloaded model at: %s\n", downloadPath)
184+
}
185+
80186
fmt.Printf("Generating modelfile for %s\n", generateConfig.Workspace)
81187
modelfile, err := modelfile.NewModelfileByWorkspace(generateConfig.Workspace, generateConfig)
82188
if err != nil {

pkg/config/modelfile/modelfile.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ type GenerateConfig struct {
3939
ParamSize string
4040
Precision string
4141
Quantization string
42+
ModelURL string
43+
Provider string // Explicit provider for short-form URLs (e.g., "huggingface", "modelscope")
44+
DownloadDir string // Custom directory for downloading models (optional)
4245
ExcludePatterns []string
4346
}
4447

@@ -56,6 +59,9 @@ func NewGenerateConfig() *GenerateConfig {
5659
ParamSize: "",
5760
Precision: "",
5861
Quantization: "",
62+
ModelURL: "",
63+
Provider: "",
64+
DownloadDir: "",
5965
ExcludePatterns: []string{},
6066
}
6167
}
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
/*
2+
* Copyright 2025 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package huggingface
18+
19+
import (
20+
"fmt"
21+
"io"
22+
"net/url"
23+
"os"
24+
"os/exec"
25+
"path/filepath"
26+
"strings"
27+
)
28+
29+
const (
30+
huggingFaceBaseURL = "https://huggingface.co"
31+
)
32+
33+
// parseModelURL extracts the owner and repository name from a HuggingFace
// model reference.
//
// Two forms are accepted:
//   - a full URL beginning with http:// or https://, e.g.
//     https://huggingface.co/owner/repo. Only the first two path segments are
//     used, so extra segments such as /tree/main are ignored. The host is not
//     validated here.
//   - a short-form identifier of exactly the shape owner/repo.
//
// Surrounding whitespace and any number of trailing slashes are ignored. An
// error is returned when the URL cannot be parsed, when the reference does not
// have the expected shape, or when the owner or repository name is empty.
func parseModelURL(modelURL string) (owner, repo string, err error) {
	// Strip every trailing slash (not just one) so that inputs such as
	// "owner/repo//" are treated the same in short form as in full-URL form.
	modelURL = strings.TrimRight(strings.TrimSpace(modelURL), "/")

	var parts []string
	if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {
		u, parseErr := url.Parse(modelURL)
		if parseErr != nil {
			return "", "", fmt.Errorf("invalid URL: %w", parseErr)
		}

		// Expected format: https://huggingface.co/owner/repo
		parts = strings.Split(strings.Trim(u.Path, "/"), "/")
		if len(parts) < 2 {
			return "", "", fmt.Errorf("invalid HuggingFace URL format, expected https://huggingface.co/owner/repo")
		}
	} else {
		// Short form must be exactly "owner/repo".
		parts = strings.Split(modelURL, "/")
		if len(parts) != 2 {
			return "", "", fmt.Errorf("invalid model identifier, expected format: owner/repo")
		}
	}

	owner, repo = parts[0], parts[1]
	if owner == "" || repo == "" {
		return "", "", fmt.Errorf("owner and repository name cannot be empty")
	}

	return owner, repo, nil
}
73+
74+
// checkHuggingFaceAuth reports whether HuggingFace credentials appear to be
// available on this machine.
//
// It returns nil when getToken can resolve a token (HF_TOKEN or one of the
// known token files), or, failing that, when the huggingface-cli executable is
// on PATH and `huggingface-cli whoami` exits successfully. Otherwise it
// returns an error instructing the user to log in.
func checkHuggingFaceAuth() error {
	// Reuse the same lookup getToken performs so the two can never disagree
	// about where a token may live.
	if _, err := getToken(); err == nil {
		return nil
	}

	// Fall back to asking the CLI; its output is irrelevant, only the exit
	// status matters.
	if _, err := exec.LookPath("huggingface-cli"); err == nil {
		cmd := exec.Command("huggingface-cli", "whoami")
		cmd.Stdout = io.Discard
		cmd.Stderr = io.Discard
		if err := cmd.Run(); err == nil {
			return nil
		}
	}

	return fmt.Errorf("not authenticated with HuggingFace. Please run: huggingface-cli login")
}

// getToken retrieves the HuggingFace token from the environment or from a
// token file.
//
// Lookup order:
//  1. the HF_TOKEN environment variable (returned as-is when non-empty)
//  2. each path from tokenFileCandidates, in order; the first file that can be
//     read and is non-empty after trimming whitespace wins
//
// An error is returned when no candidate yields a token. When at least one
// read failed, the most recent read error is wrapped so callers can still
// inspect it (e.g. with errors.Is(err, os.ErrNotExist)).
func getToken() (string, error) {
	if token := os.Getenv("HF_TOKEN"); token != "" {
		return token, nil
	}

	candidates, err := tokenFileCandidates()
	if err != nil {
		return "", err
	}

	var lastErr error
	for _, tokenPath := range candidates {
		data, readErr := os.ReadFile(tokenPath)
		if readErr != nil {
			lastErr = readErr
			continue
		}
		// An empty file carries no credential; keep looking.
		if token := strings.TrimSpace(string(data)); token != "" {
			return token, nil
		}
	}

	if lastErr != nil {
		return "", fmt.Errorf("failed to read token file: %w", lastErr)
	}
	return "", fmt.Errorf("failed to read token file: token file is empty")
}

// tokenFileCandidates returns the token file locations to probe, most
// specific first:
//
//   - $HF_TOKEN_PATH, when set
//   - $HF_HOME/token, when HF_HOME is set
//   - $XDG_CACHE_HOME/huggingface/token, when XDG_CACHE_HOME is set
//   - ~/.cache/huggingface/token (default location used by huggingface_hub)
//   - ~/.huggingface/token (legacy location)
//
// The home directory is only required for the last three entries; if it
// cannot be determined and no environment override supplied a path, an error
// is returned.
func tokenFileCandidates() ([]string, error) {
	var candidates []string
	if tokenPath := os.Getenv("HF_TOKEN_PATH"); tokenPath != "" {
		candidates = append(candidates, tokenPath)
	}
	if hfHome := os.Getenv("HF_HOME"); hfHome != "" {
		candidates = append(candidates, filepath.Join(hfHome, "token"))
	}

	homeDir, err := os.UserHomeDir()
	if err != nil {
		// Explicit overrides are still usable without a home directory.
		if len(candidates) > 0 {
			return candidates, nil
		}
		return nil, fmt.Errorf("failed to get user home directory: %w", err)
	}

	if xdgCache := os.Getenv("XDG_CACHE_HOME"); xdgCache != "" {
		candidates = append(candidates, filepath.Join(xdgCache, "huggingface", "token"))
	}
	candidates = append(candidates,
		filepath.Join(homeDir, ".cache", "huggingface", "token"),
		filepath.Join(homeDir, ".huggingface", "token"),
	)
	return candidates, nil
}

0 commit comments

Comments
 (0)