Skip to content

Commit f455aed

Browse files
optimise code
Signed-off-by: Avinash Singh <[email protected]>
1 parent 6d7fb90 commit f455aed

File tree

10 files changed

+163
-172
lines changed

10 files changed

+163
-172
lines changed

cmd/modelfile/generate.go

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,18 +38,24 @@ var generateCmd = &cobra.Command{
3838
Long: `Generate a modelfile from either a local directory containing model files or by downloading a model from a supported provider.
3939
4040
The workspace must be a directory including model files and model configuration files.
41-
Alternatively, use --model_url to download a model from a supported provider (e.g., HuggingFace, ModelScope).`,
41+
Alternatively, use --model-url to download a model from a supported provider (e.g., HuggingFace, ModelScope).
42+
43+
For short-form URLs (owner/repo), you must explicitly specify the provider using --provider flag.
44+
Full URLs with domain names will auto-detect the provider.`,
4245
Example: ` # Generate from local directory
4346
modctl modelfile generate ./my-model-dir
4447
45-
# Generate from Hugging Face model URL
46-
modctl modelfile generate --model_url https://huggingface.co/meta-llama/Llama-2-7b-hf
48+
# Generate from Hugging Face using full URL (auto-detects provider)
49+
modctl modelfile generate --model-url https://huggingface.co/meta-llama/Llama-2-7b-hf
50+
51+
# Generate from Hugging Face using short form (requires --provider)
52+
modctl modelfile generate --model-url meta-llama/Llama-2-7b-hf --provider huggingface
4753
48-
# Generate from Hugging Face using short form
49-
modctl modelfile generate --model_url meta-llama/Llama-2-7b-hf
54+
# Generate from ModelScope using full URL (auto-detects provider)
55+
modctl modelfile generate --model-url https://modelscope.cn/models/qwen/Qwen-7B
5056
51-
# Generate from ModelScope
52-
modctl modelfile generate --model_url https://modelscope.cn/models/qwen/Qwen-7B
57+
# Generate from ModelScope using short form (requires --provider)
58+
modctl modelfile generate --model-url qwen/Qwen-7B --provider modelscope
5359
5460
# Generate with custom output path
5561
modctl modelfile generate ./my-model-dir --output ./output/modelfile.yaml
@@ -61,18 +67,18 @@ Alternatively, use --model_url to download a model from a supported provider (e.
6167
SilenceUsage: true,
6268
FParseErrWhitelist: cobra.FParseErrWhitelist{UnknownFlags: true},
6369
RunE: func(cmd *cobra.Command, args []string) error {
64-
// If model_url is provided, path is optional
70+
// If model-url is provided, path is optional
6571
workspace := "."
6672
if len(args) > 0 {
6773
workspace = args[0]
6874
}
6975

70-
// Validate that either path or model_url is provided
76+
// Validate that either path or model-url is provided
7177
if generateConfig.ModelURL != "" && len(args) > 0 {
72-
return fmt.Errorf("the <path> argument and the --model_url flag are mutually exclusive")
78+
return fmt.Errorf("the <path> argument and the --model-url flag are mutually exclusive")
7379
}
7480
if generateConfig.ModelURL == "" && len(args) == 0 {
75-
return fmt.Errorf("either a <path> argument or the --model_url flag must be provided")
81+
return fmt.Errorf("either a <path> argument or the --model-url flag must be provided")
7682
}
7783

7884
if err := generateConfig.Convert(workspace); err != nil {
@@ -100,7 +106,8 @@ func init() {
100106
flags.StringVarP(&generateConfig.Output, "output", "O", ".", "specify the output path of modelfilem, must be a directory")
101107
flags.BoolVar(&generateConfig.IgnoreUnrecognizedFileTypes, "ignore-unrecognized-file-types", false, "ignore the unrecognized file types in the workspace")
102108
flags.BoolVar(&generateConfig.Overwrite, "overwrite", false, "overwrite the existing modelfile")
103-
flags.StringVar(&generateConfig.ModelURL, "model_url", "", "download model from a supported provider (HuggingFace: owner/repo or full URL, ModelScope: full URL)")
109+
flags.StringVar(&generateConfig.ModelURL, "model-url", "", "download model from a supported provider (full URL or short-form with --provider)")
110+
flags.StringVarP(&generateConfig.Provider, "provider", "p", "", "explicitly specify the provider for short-form URLs (huggingface, modelscope)")
104111

105112
// Mark the ignore-unrecognized-file-types flag as deprecated and hidden
106113
flags.MarkDeprecated("ignore-unrecognized-file-types", "this flag will be removed in the next release")
@@ -118,10 +125,10 @@ func runGenerate(ctx context.Context) error {
118125
fmt.Printf("Model URL provided: %s\n", generateConfig.ModelURL)
119126

120127
// Get the appropriate provider for this URL
121-
registry := modelprovider.NewRegistry()
122-
provider, err := registry.GetProvider(generateConfig.ModelURL)
128+
registry := modelprovider.GetRegistry()
129+
provider, err := registry.SelectProvider(generateConfig.ModelURL, generateConfig.Provider)
123130
if err != nil {
124-
return fmt.Errorf("unsupported model URL: %w", err)
131+
return fmt.Errorf("failed to select provider: %w", err)
125132
}
126133

127134
fmt.Printf("Using provider: %s\n", provider.Name())

pkg/config/modelfile/modelfile.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ type GenerateConfig struct {
4040
Precision string
4141
Quantization string
4242
ModelURL string
43+
Provider string // Explicit provider for short-form URLs (e.g., "huggingface", "modelscope")
4344
}
4445

4546
func NewGenerateConfig() *GenerateConfig {
@@ -57,6 +58,7 @@ func NewGenerateConfig() *GenerateConfig {
5758
Precision: "",
5859
Quantization: "",
5960
ModelURL: "",
61+
Provider: "",
6062
}
6163
}
6264

pkg/modelprovider/huggingface/downloader.go

Lines changed: 0 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,8 @@
1717
package huggingface
1818

1919
import (
20-
"context"
2120
"fmt"
2221
"io"
23-
"net/http"
2422
"net/url"
2523
"os"
2624
"os/exec"
@@ -127,55 +125,3 @@ func getToken() (string, error) {
127125

128126
return strings.TrimSpace(string(data)), nil
129127
}
130-
131-
// downloadFile downloads a single file from HuggingFace
132-
func downloadFile(ctx context.Context, owner, repo, filename, destPath string) error {
133-
token, err := getToken()
134-
if err != nil {
135-
return fmt.Errorf("failed to get HuggingFace token: %w", err)
136-
}
137-
138-
// Construct the download URL
139-
// Format: https://huggingface.co/{owner}/{repo}/resolve/main/{filename}
140-
fileURL := fmt.Sprintf("%s/%s/%s/resolve/main/%s", huggingFaceBaseURL, owner, repo, filename)
141-
142-
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
143-
if err != nil {
144-
return fmt.Errorf("failed to create request: %w", err)
145-
}
146-
147-
// Add authorization header
148-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
149-
150-
client := &http.Client{}
151-
resp, err := client.Do(req)
152-
if err != nil {
153-
return fmt.Errorf("failed to download file: %w", err)
154-
}
155-
defer resp.Body.Close()
156-
157-
if resp.StatusCode != http.StatusOK {
158-
return fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
159-
}
160-
161-
// Create destination directory
162-
destDir := filepath.Dir(destPath)
163-
if err := os.MkdirAll(destDir, 0755); err != nil {
164-
return fmt.Errorf("failed to create destination directory: %w", err)
165-
}
166-
167-
// Create the destination file
168-
outFile, err := os.Create(destPath)
169-
if err != nil {
170-
return fmt.Errorf("failed to create destination file: %w", err)
171-
}
172-
defer outFile.Close()
173-
174-
// Copy the content
175-
_, err = io.Copy(outFile, resp.Body)
176-
if err != nil {
177-
return fmt.Errorf("failed to write file: %w", err)
178-
}
179-
180-
return nil
181-
}

pkg/modelprovider/huggingface/downloader_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,9 +130,9 @@ func TestProvider_SupportsURL(t *testing.T) {
130130
want: true,
131131
},
132132
{
133-
name: "short form repo",
133+
name: "short form repo (requires explicit --provider)",
134134
url: "meta-llama/Llama-2-7b-hf",
135-
want: true,
135+
want: false,
136136
},
137137
{
138138
name: "ModelScope URL",

pkg/modelprovider/huggingface/provider.go

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -39,22 +39,13 @@ func (p *Provider) Name() string {
3939
}
4040

4141
// SupportsURL checks if this provider can handle the given URL
42-
// It supports both full HuggingFace URLs and short-form repo identifiers
42+
// It only supports full HuggingFace URLs with the huggingface.co domain
43+
// For short-form repo identifiers (owner/repo), users must explicitly specify --provider huggingface
4344
func (p *Provider) SupportsURL(url string) bool {
4445
url = strings.TrimSpace(url)
4546

46-
// Check for full HuggingFace URLs
47-
if strings.Contains(url, "huggingface.co") {
48-
return true
49-
}
50-
51-
// Check for short-form repo identifiers (owner/repo)
52-
// Must have exactly one slash and no protocol
53-
if !strings.HasPrefix(url, "http://") && !strings.HasPrefix(url, "https://") {
54-
return strings.Count(url, "/") == 1
55-
}
56-
57-
return false
47+
// Only support full HuggingFace URLs
48+
return strings.Contains(url, "huggingface.co")
5849
}
5950

6051
// DownloadModel downloads a model from HuggingFace using the huggingface-cli

pkg/modelprovider/modelscope/downloader.go

Lines changed: 0 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,7 @@
1717
package modelscope
1818

1919
import (
20-
"context"
2120
"fmt"
22-
"io"
23-
"net/http"
2421
"net/url"
2522
"os"
2623
"os/exec"
@@ -138,58 +135,3 @@ func getToken() (string, error) {
138135

139136
return strings.TrimSpace(string(data)), nil
140137
}
141-
142-
// downloadFile downloads a single file from ModelScope
143-
func downloadFile(ctx context.Context, owner, repo, filename, destPath string) error {
144-
token, err := getToken()
145-
if err != nil {
146-
// Token is optional for public models, continue without it
147-
token = ""
148-
}
149-
150-
// Construct the download URL
151-
// Format: https://modelscope.cn/api/v1/models/{owner}/{repo}/repo?Revision=master&FilePath={filename}
152-
fileURL := fmt.Sprintf("%s/api/v1/models/%s/%s/repo?Revision=master&FilePath=%s", modelScopeBaseURL, owner, repo, filename)
153-
154-
req, err := http.NewRequestWithContext(ctx, "GET", fileURL, nil)
155-
if err != nil {
156-
return fmt.Errorf("failed to create request: %w", err)
157-
}
158-
159-
// Add authorization header if token is available
160-
if token != "" {
161-
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
162-
}
163-
164-
client := &http.Client{}
165-
resp, err := client.Do(req)
166-
if err != nil {
167-
return fmt.Errorf("failed to download file: %w", err)
168-
}
169-
defer resp.Body.Close()
170-
171-
if resp.StatusCode != http.StatusOK {
172-
return fmt.Errorf("failed to download file, status code: %d", resp.StatusCode)
173-
}
174-
175-
// Create destination directory
176-
destDir := filepath.Dir(destPath)
177-
if err := os.MkdirAll(destDir, 0755); err != nil {
178-
return fmt.Errorf("failed to create destination directory: %w", err)
179-
}
180-
181-
// Create the destination file
182-
outFile, err := os.Create(destPath)
183-
if err != nil {
184-
return fmt.Errorf("failed to create destination file: %w", err)
185-
}
186-
defer outFile.Close()
187-
188-
// Copy the content
189-
_, err = io.Copy(outFile, resp.Body)
190-
if err != nil {
191-
return fmt.Errorf("failed to write file: %w", err)
192-
}
193-
194-
return nil
195-
}

pkg/modelprovider/modelscope/provider.go

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,19 +39,13 @@ func (p *Provider) Name() string {
3939
}
4040

4141
// SupportsURL checks if this provider can handle the given URL
42-
// It supports both full ModelScope URLs and short-form repo identifiers
42+
// It only supports full ModelScope URLs with the modelscope.cn domain
43+
// For short-form repo identifiers (owner/repo), users must explicitly specify --provider modelscope
4344
func (p *Provider) SupportsURL(url string) bool {
4445
url = strings.TrimSpace(url)
4546

46-
// Check for full ModelScope URLs
47-
if strings.Contains(url, "modelscope.cn") {
48-
return true
49-
}
50-
51-
// Note: We don't auto-detect short-form for ModelScope to avoid conflicts with HuggingFace
52-
// Users should use full URLs or explicitly specify the provider
53-
54-
return false
47+
// Only support full ModelScope URLs
48+
return strings.Contains(url, "modelscope.cn")
5549
}
5650

5751
// DownloadModel downloads a model from ModelScope using the modelscope CLI

pkg/modelprovider/provider.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ type Provider interface {
2727
Name() string
2828

2929
// SupportsURL checks if this provider can handle the given model URL
30-
// This enables automatic provider detection based on URL patterns
30+
// This enables automatic provider detection based on full URL patterns (with domain)
31+
// Short-form URLs (owner/repo) require explicit provider specification via GetProviderByName
3132
SupportsURL(url string) bool
3233

3334
// DownloadModel downloads a model from the provider and returns the local path

pkg/modelprovider/registry.go

Lines changed: 42 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package modelprovider
1818

1919
import (
2020
"fmt"
21+
"sync"
2122

2223
"github.com/modelpack/modctl/pkg/modelprovider/huggingface"
2324
"github.com/modelpack/modctl/pkg/modelprovider/modelscope"
@@ -29,28 +30,57 @@ type Registry struct {
2930
providers []Provider
3031
}
3132

32-
// NewRegistry creates a new provider registry with all available providers
33-
func NewRegistry() *Registry {
34-
return &Registry{
35-
providers: []Provider{
36-
huggingface.New(),
37-
modelscope.New(),
38-
// Future providers can be added here:
39-
// civitai.New(),
40-
},
41-
}
33+
var (
34+
instance *Registry
35+
once sync.Once
36+
)
37+
38+
// GetRegistry returns the singleton instance of the registry
39+
// This is thread-safe and will only create the instance once
40+
func GetRegistry() *Registry {
41+
once.Do(func() {
42+
instance = &Registry{
43+
providers: []Provider{
44+
huggingface.New(),
45+
modelscope.New(),
46+
// Future providers can be added here:
47+
// civitai.New(),
48+
},
49+
}
50+
})
51+
return instance
52+
}
53+
54+
// ResetRegistry resets the singleton instance
55+
// This should only be used in tests to ensure isolation between test cases
56+
func ResetRegistry() {
57+
once = sync.Once{}
58+
instance = nil
4259
}
4360

4461
// GetProvider returns the appropriate provider for the given model URL
4562
// It iterates through all registered providers and returns the first one
46-
// that supports the URL
63+
// that supports the URL. This only works for full URLs with domain names.
64+
// For short-form URLs (owner/repo), use GetProviderByName with an explicit provider
4765
func (r *Registry) GetProvider(modelURL string) (Provider, error) {
4866
for _, p := range r.providers {
4967
if p.SupportsURL(modelURL) {
5068
return p, nil
5169
}
5270
}
53-
return nil, fmt.Errorf("no provider found for URL: %s", modelURL)
71+
return nil, fmt.Errorf("no provider found for URL: %s. For short-form URLs (owner/repo), use --provider flag to specify the provider explicitly", modelURL)
72+
}
73+
74+
// SelectProvider returns the appropriate provider based on the URL and explicit provider name
75+
// If providerName is specified, it uses GetProviderByName for short-form URLs
76+
// Otherwise, it uses GetProvider for auto-detection with full URLs
77+
func (r *Registry) SelectProvider(modelURL, providerName string) (Provider, error) {
78+
if providerName != "" {
79+
// Explicit provider specified, use it
80+
return r.GetProviderByName(providerName)
81+
}
82+
// No explicit provider, try auto-detection
83+
return r.GetProvider(modelURL)
5484
}
5585

5686
// GetProviderByName returns a specific provider by its name

0 commit comments

Comments
 (0)