Skip to content

Commit ee3dedd

Browse files
optimise as per gemini's review
Signed-off-by: Avinash Singh <[email protected]>
1 parent fb50643 commit ee3dedd

File tree

3 files changed

+12
-24
lines changed

3 files changed

+12
-24
lines changed

cmd/modelfile/generate.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"context"
2121
"fmt"
2222
"os"
23-
"path/filepath"
2423

2524
"github.com/spf13/cobra"
2625
"github.com/spf13/viper"
@@ -66,8 +65,11 @@ Alternatively, use --model_url to download a model from Hugging Face Hub.`,
6665
}
6766

6867
// Validate that either path or model_url is provided
68+
if generateConfig.ModelURL != "" && len(args) > 0 {
69+
return fmt.Errorf("the <path> argument and the --model_url flag are mutually exclusive")
70+
}
6971
if generateConfig.ModelURL == "" && len(args) == 0 {
70-
return fmt.Errorf("either <path> argument or --model_url flag must be provided")
72+
return fmt.Errorf("either a <path> argument or the --model_url flag must be provided")
7173
}
7274

7375
if err := generateConfig.Convert(workspace); err != nil {
@@ -118,10 +120,12 @@ func runGenerate(ctx context.Context) error {
118120
}
119121

120122
// Create a temporary directory for downloading the model
121-
tmpDir := filepath.Join(os.TempDir(), "modctl-hf-downloads")
122-
if err := os.MkdirAll(tmpDir, 0755); err != nil {
123+
// Clean up the temporary directory after the function returns
124+
tmpDir, err := os.MkdirTemp("", "modctl-hf-downloads-*")
125+
if err != nil {
123126
return fmt.Errorf("failed to create temporary directory: %w", err)
124127
}
128+
defer os.RemoveAll(tmpDir)
125129

126130
// Download the model
127131
downloadPath, err := hfhub.DownloadModel(ctx, generateConfig.ModelURL, tmpDir)

pkg/hfhub/download.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,14 +103,10 @@ func DownloadModel(ctx context.Context, modelURL, destDir string) (string, error
103103
cmd.Stdout = os.Stdout
104104
cmd.Stderr = os.Stderr
105105

106-
fmt.Printf("Downloading model %s to %s...\n", repoID, downloadPath)
107-
108106
if err := cmd.Run(); err != nil {
109107
return "", fmt.Errorf("failed to download model using huggingface-cli: %w", err)
110108
}
111109

112-
fmt.Printf("Successfully downloaded model to %s\n", downloadPath)
113-
114110
return downloadPath, nil
115111
}
116112

@@ -136,6 +132,8 @@ func CheckHuggingFaceAuth() error {
136132
// Try using whoami command
137133
if _, err := exec.LookPath("huggingface-cli"); err == nil {
138134
cmd := exec.Command("huggingface-cli", "whoami")
135+
cmd.Stdout = io.Discard
136+
cmd.Stderr = io.Discard
139137
if err := cmd.Run(); err == nil {
140138
return nil
141139
}

pkg/hfhub/download_test.go

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package hfhub
1818

1919
import (
20+
"strings"
2021
"testing"
2122
)
2223

@@ -93,7 +94,7 @@ func TestParseModelURL(t *testing.T) {
9394
t.Errorf("ParseModelURL() expected error but got nil")
9495
return
9596
}
96-
if tt.errContains != "" && err.Error() != tt.errContains && !contains(err.Error(), tt.errContains) {
97+
if tt.errContains != "" && !strings.Contains(err.Error(), tt.errContains) {
9798
t.Errorf("ParseModelURL() error = %v, want error containing %v", err, tt.errContains)
9899
}
99100
return
@@ -114,18 +115,3 @@ func TestParseModelURL(t *testing.T) {
114115
})
115116
}
116117
}
117-
118-
func contains(s, substr string) bool {
119-
return len(s) >= len(substr) && (s == substr || len(s) > len(substr) &&
120-
(s[:len(substr)] == substr || s[len(s)-len(substr):] == substr ||
121-
findInString(s, substr)))
122-
}
123-
124-
func findInString(s, substr string) bool {
125-
for i := 0; i <= len(s)-len(substr); i++ {
126-
if s[i:i+len(substr)] == substr {
127-
return true
128-
}
129-
}
130-
return false
131-
}

0 commit comments

Comments
 (0)