Skip to content

Commit fb50643

Browse files
Add support for HF model_url
Signed-off-by: Avinash Singh <[email protected]>
1 parent 11a439b commit fb50643

File tree

4 files changed

+416
-5
lines changed

4 files changed

+416
-5
lines changed

cmd/modelfile/generate.go

Lines changed: 63 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,57 @@ import (
2020
"context"
2121
"fmt"
2222
"os"
23+
"path/filepath"
2324

2425
"github.com/spf13/cobra"
2526
"github.com/spf13/viper"
2627

2728
configmodelfile "github.com/modelpack/modctl/pkg/config/modelfile"
29+
"github.com/modelpack/modctl/pkg/hfhub"
2830
"github.com/modelpack/modctl/pkg/modelfile"
2931
)
3032

3133
var generateConfig = configmodelfile.NewGenerateConfig()
3234

3335
// generateCmd represents the modelfile tools command for generating modelfile.
3436
var generateCmd = &cobra.Command{
35-
Use: "generate [flags] <path>",
36-
Short: "A command line tool for generating modelfile in the workspace, the workspace must be a directory including model files and model configuration files",
37-
Args: cobra.ExactArgs(1),
37+
Use: "generate [flags] [<path>]",
38+
Short: "Generate a modelfile from a local workspace or Hugging Face model",
39+
Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from Hugging Face.
40+
41+
The workspace must be a directory including model files and model configuration files.
42+
Alternatively, use --model_url to download a model from Hugging Face Hub.`,
43+
Example: ` # Generate from local directory
44+
modctl modelfile generate ./my-model-dir
45+
46+
# Generate from Hugging Face model URL
47+
modctl modelfile generate --model_url https://huggingface.co/meta-llama/Llama-2-7b-hf
48+
49+
# Generate from Hugging Face using short form
50+
modctl modelfile generate --model_url meta-llama/Llama-2-7b-hf
51+
52+
# Generate with custom output path
53+
modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml
54+
55+
# Generate with metadata overrides
56+
modctl modelfile generate ./my-model-dir --name my-custom-model --family llama3`,
57+
Args: cobra.MaximumNArgs(1),
3858
DisableAutoGenTag: true,
3959
SilenceUsage: true,
4060
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
4161
RunE: func(cmd *cobra.Command, args []string) error {
42-
if err := generateConfig.Convert(args[0]); err != nil {
62+
// If model_url is provided, path is optional
63+
workspace := "."
64+
if len(args) > 0 {
65+
workspace = args[0]
66+
}
67+
68+
// Validate that either path or model_url is provided
69+
if generateConfig.ModelURL == "" && len(args) == 0 {
70+
return fmt.Errorf("either <path> argument or --model_url flag must be provided")
71+
}
72+
73+
if err := generateConfig.Convert(workspace); err != nil {
4374
return err
4475
}
4576

@@ -64,6 +95,7 @@ func init() {
6495
flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory")
6596
flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace")
6697
flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile")
98+
flags.StringVar(&generateConfig.ModelURL, "model_url", "", "download model from Hugging Face (format: owner/repo or full URL)")
6799

68100
// Mark the ignore-unrecognized-file-types flag as deprecated and hidden
69101
flags.MarkDeprecated("ignore-unrecognized-file-types", "this flag will be removed in the next release")
@@ -75,7 +107,33 @@ func init() {
75107
}
76108

77109
// runGenerate runs the generate modelfile.
78-
func runGenerate(_ context.Context) error {
110+
func runGenerate(ctx context.Context) error {
111+
// If model URL is provided, download the model first
112+
if generateConfig.ModelURL != "" {
113+
fmt.Printf("Model URL provided: %s\n", generateConfig.ModelURL)
114+
115+
// Check if user is authenticated with Hugging Face
116+
if err := hfhub.CheckHuggingFaceAuth(); err != nil {
117+
return fmt.Errorf("authentication check failed: %w", err)
118+
}
119+
120+
// Create a temporary directory for downloading the model
121+
tmpDir := filepath.Join(os.TempDir(), "modctl-hf-downloads")
122+
if err := os.MkdirAll(tmpDir, 0755); err != nil {
123+
return fmt.Errorf("failed to create temporary directory: %w", err)
124+
}
125+
126+
// Download the model
127+
downloadPath, err := hfhub.DownloadModel(ctx, generateConfig.ModelURL, tmpDir)
128+
if err != nil {
129+
return fmt.Errorf("failed to download model: %w", err)
130+
}
131+
132+
// Update workspace to the downloaded model path
133+
generateConfig.Workspace = downloadPath
134+
fmt.Printf("Using downloaded model at: %s\n", downloadPath)
135+
}
136+
79137
fmt.Printf("Generating modelfile for %s\n", generateConfig.Workspace)
80138
modelfile, err := modelfile.NewModelfileByWorkspace(generateConfig.Workspace, generateConfig)
81139
if err != nil {

pkg/config/modelfile/modelfile.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ type GenerateConfig struct {
3939
ParamSize string
4040
Precision string
4141
Quantization string
42+
ModelURL string
4243
}
4344

4445
func NewGenerateConfig() *GenerateConfig {
@@ -55,6 +56,7 @@ func NewGenerateConfig() *GenerateConfig {
5556
ParamSize: "",
5657
Precision: "",
5758
Quantization: "",
59+
ModelURL: "",
5860
}
5961
}
6062

pkg/hfhub/download.go

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
/*
2+
* Copyright 2025 The CNAI Authors
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package hfhub
18+
19+
import (
20+
"context"
21+
"fmt"
22+
"io"
23+
"net/http"
24+
"net/url"
25+
"os"
26+
"os/exec"
27+
"path/filepath"
28+
"strings"
29+
)
30+
31+
const (
32+
HuggingFaceBaseURL = "https://huggingface.co"
33+
)
34+
35+
// ParseModelURL parses a Hugging Face model URL and extracts the owner and repository name
36+
func ParseModelURL(modelURL string) (owner, repo string, err error) {
37+
// Handle both full URLs and short-form repo names
38+
modelURL = strings.TrimSpace(modelURL)
39+
40+
// Remove trailing slashes
41+
modelURL = strings.TrimSuffix(modelURL, "/")
42+
43+
// If it's a full URL, parse it
44+
if strings.HasPrefix(modelURL, "http://") || strings.HasPrefix(modelURL, "https://") {
45+
u, err := url.Parse(modelURL)
46+
if err != nil {
47+
return "", "", fmt.Errorf("invalid URL: %w", err)
48+
}
49+
50+
// Expected format: https://huggingface.co/owner/repo
51+
parts := strings.Split(strings.Trim(u.Path, "/"), "/")
52+
if len(parts) < 2 {
53+
return "", "", fmt.Errorf("invalid Hugging Face URL format, expected https://huggingface.co/owner/repo")
54+
}
55+
56+
owner = parts[0]
57+
repo = parts[1]
58+
} else {
59+
// Handle short-form like "owner/repo"
60+
parts := strings.Split(modelURL, "/")
61+
if len(parts) != 2 {
62+
return "", "", fmt.Errorf("invalid model identifier, expected format: owner/repo")
63+
}
64+
65+
owner = parts[0]
66+
repo = parts[1]
67+
}
68+
69+
if owner == "" || repo == "" {
70+
return "", "", fmt.Errorf("owner and repository name cannot be empty")
71+
}
72+
73+
return owner, repo, nil
74+
}
75+
76+
// DownloadModel downloads a model from Hugging Face using the huggingface-cli
77+
// It assumes the user is already logged in via `huggingface-cli login`
78+
func DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) {
79+
owner, repo, err := ParseModelURL(modelURL)
80+
if err != nil {
81+
return "", err
82+
}
83+
84+
repoID := fmt.Sprintf("%s/%s", owner, repo)
85+
86+
// Check if huggingface-cli is available
87+
if _, err := exec.LookPath("huggingface-cli"); err != nil {
88+
return "", fmt.Errorf("huggingface-cli not found in PATH. Please install it using: pip install huggingface_hub[cli]")
89+
}
90+
91+
// Create destination directory if it doesn't exist
92+
if err := os.MkdirAll(destDir, 0755); err != nil {
93+
return "", fmt.Errorf("failed to create destination directory: %w", err)
94+
}
95+
96+
// Construct the download path
97+
downloadPath := filepath.Join(destDir, repo)
98+
99+
// Use huggingface-cli to download the model
100+
// The --local-dir-use-symlinks=False flag ensures files are copied, not symlinked
101+
cmd := exec.CommandContext(ctx, "huggingface-cli", "download", repoID, "--local-dir", downloadPath, "--local-dir-use-symlinks", "False")
102+
103+
cmd.Stdout = os.Stdout
104+
cmd.Stderr = os.Stderr
105+
106+
fmt.Printf("Downloading model %s to %s...\n", repoID, downloadPath)
107+
108+
if err := cmd.Run(); err != nil {
109+
return "", fmt.Errorf("failed to download model using huggingface-cli: %w", err)
110+
}
111+
112+
fmt.Printf("Successfully downloaded model to %s\n", downloadPath)
113+
114+
return downloadPath, nil
115+
}
116+
117+
// CheckHuggingFaceAuth checks if the user is authenticated with Hugging Face
118+
func CheckHuggingFaceAuth() error {
119+
// Try to find the HF token
120+
token := os.Getenv("HF_TOKEN")
121+
if token != "" {
122+
return nil
123+
}
124+
125+
// Check if the token file exists
126+
homeDir, err := os.UserHomeDir()
127+
if err != nil {
128+
return fmt.Errorf("failed to get user home directory: %w", err)
129+
}
130+
131+
tokenPath := filepath.Join(homeDir, ".huggingface", "token")
132+
if _, err := os.Stat(tokenPath); err == nil {
133+
return nil
134+
}
135+
136+
// Try using whoami command
137+
if _, err := exec.LookPath("huggingface-cli"); err == nil {
138+
cmd := exec.Command("huggingface-cli", "whoami")
139+
if err := cmd.Run(); err == nil {
140+
return nil
141+
}
142+
}
143+
144+
return fmt.Errorf("not authenticated with Hugging Face. Please run: huggingface-cli login")
145+
}
146+
147+
// GetToken retrieves the Hugging Face token from environment or token file
148+
func GetToken() (string, error) {
149+
// First check environment variable
150+
token := os.Getenv("HF_TOKEN")
151+
if token != "" {
152+
return token, nil
153+
}
154+
155+
// Then check the token file
156+
homeDir, err := os.UserHomeDir()
157+
if err != nil {
158+
return "", fmt.Errorf("failed to get user home directory: %w", err)
159+
}
160+
161+
tokenPath := filepath.Join(homeDir, ".huggingface", "token")
162+
data, err := os.ReadFile(tokenPath)
163+
if err != nil {
164+
return "", fmt.Errorf("failed to read token file: %w", err)
165+
}
166+
167+
return strings.TrimSpace(string(data)), nil
168+
}
169+
170+
// DownloadFile downloads a single file from Hugging Face
171+
func DownloadFile(ctx context.Context, owner, repo, filename, destPath string) error {
172+
token, err := GetToken()
173+
if err != nil {
174+
return fmt.Errorf("failed to get Hugging Face token: %w", err)
175+
}
176+
177+
// Construct the download URL
178+
// Format: https://huggingface.co/{owner}/{repo}/resolve/main/{filename}
179+
fileURL := fmt.Sprintf("%s/%s/%s/resolve/main/%s", HuggingFaceBaseURL, owner, repo, filename)
180+
181+
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
182+
if err != nil {
183+
return fmt.Errorf("failed to create request: %w", err)
184+
}
185+
186+
// Add authorization header
187+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
188+
189+
client := &http.Client{}
190+
resp, err := client.Do(req)
191+
if err != nil {
192+
return fmt.Errorf("failed to download file: %w", err)
193+
}
194+
defer resp.Body.Close()
195+
196+
if resp.StatusCode != http.StatusOK {
197+
return fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
198+
}
199+
200+
// Create destination directory
201+
destDir := filepath.Dir(destPath)
202+
if err := os.MkdirAll(destDir, 0755); err != nil {
203+
return fmt.Errorf("failed to create destination directory: %w", err)
204+
}
205+
206+
// Create the destination file
207+
outFile, err := os.Create(destPath)
208+
if err != nil {
209+
return fmt.Errorf("failed to create destination file: %w", err)
210+
}
211+
defer outFile.Close()
212+
213+
// Copy the content
214+
_, err = io.Copy(outFile, resp.Body)
215+
if err != nil {
216+
return fmt.Errorf("failed to write file: %w", err)
217+
}
218+
219+
return nil
220+
}

0 commit comments

Comments
 (0)