Skip to content

Commit f8c5d14

Browse files
committed
feat: fix skip dirs and ai review
Signed-off-by: Arsen Gumin <[email protected]>
1 parent c797860 commit f8c5d14

File tree

3 files changed

+68
-50
lines changed

3 files changed

+68
-50
lines changed

pkg/modelprovider/mlflow/downloader.go

Lines changed: 29 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,7 @@ func NewMlFlowRegistry(mlflowClient *client.DatabricksClient) (MlFlowClient, err
3939

4040
if mlflowClient != nil {
4141
registry = ml.NewModelRegistry(mlflowClient)
42-
fmt.Println("Use default mlflow client for MlFlowRegistryAPI")
42+
log.Println("Use default mlflow client for MlFlowRegistryAPI")
4343
return MlFlowClient{registry: registry}, nil
4444
}
4545

@@ -92,8 +92,9 @@ func (mlfr *MlFlowClient) PullModelByName(
9292
return "", errors.New(msg)
9393
}
9494

95-
fmt.Printf("Found versions: '%v' for model '%s'\n", rawVersion, modelName)
95+
log.Printf("Found versions: '%v' for model '%s'\n", rawVersion, modelName)
9696
if modelVersion == "" {
97+
slices.Sort(rawVersion)
9798
modelVersion = rawVersion[0]
9899
}
99100

@@ -104,8 +105,11 @@ func (mlfr *MlFlowClient) PullModelByName(
104105
if err != nil {
105106
return "", err
106107
}
107-
fmt.Printf("Try pull model from uri %s", uri.ArtifactUri)
108+
log.Printf("Try pull model from uri %s", uri.ArtifactUri)
108109
parsed, err := url.Parse(uri.ArtifactUri)
110+
if err != nil {
111+
return "", fmt.Errorf("failed to parse artifact uri: %w", err)
112+
}
109113
if parsed == nil {
110114
return "", errors.New("failed to parse artifact uri")
111115
}
@@ -124,7 +128,7 @@ func (mlfr *MlFlowClient) PullModelByName(
124128
return "", err
125129
}
126130

127-
fmt.Printf("✅ Model downloaded")
131+
log.Printf("✅ Model downloaded")
128132

129133
return destSrc, nil
130134
}
@@ -144,7 +148,7 @@ func (s3back *S3StorageBackend) DownloadModel(
144148
}
145149

146150
bucketName := parsed.Host
147-
s3FolderPrefix := parsed.Path[1:]
151+
s3FolderPrefix := strings.TrimPrefix(parsed.Path, "/")
148152
fmt.Printf("Parsed s3 bucket %s, path %s from path", bucketName, s3FolderPrefix)
149153

150154
cfg, err := awsconfig.LoadDefaultConfig(ctx)
@@ -161,11 +165,10 @@ func (s3back *S3StorageBackend) DownloadModel(
161165
downloader := manager.NewDownloader(s3Client, func(d *manager.Downloader) {
162166
d.PartSize = partMiBs * 1024 * 1024
163167
})
164-
// List objects with the specified prefix
168+
// List all objects under the prefix (including nested directories).
165169
paginator := s3.NewListObjectsV2Paginator(s3Client, &s3.ListObjectsV2Input{
166-
Bucket: aws.String(bucketName),
167-
Prefix: aws.String(s3FolderPrefix),
168-
Delimiter: aws.String("/"),
170+
Bucket: aws.String(bucketName),
171+
Prefix: aws.String(s3FolderPrefix),
169172
})
170173

171174
log.Printf("Start downloading from s3 bucket %s, path %s", bucketName, s3FolderPrefix)
@@ -180,12 +183,13 @@ func (s3back *S3StorageBackend) DownloadModel(
180183
for _, object := range page.Contents {
181184
s3Key := *object.Key
182185
log.Printf("Downloading object: %s\n", s3Key)
183-
//if strings.HasSuffix(s3Key, "/") { // Skip S3 "folder" markers
184-
// continue
185-
//}
186+
if strings.HasSuffix(s3Key, "/") { // Skip S3 "folder" markers
187+
continue
188+
}
186189

187190
// Construct local file path
188191
relativePath := strings.TrimPrefix(s3Key, s3FolderPrefix)
192+
relativePath = strings.TrimPrefix(relativePath, "/")
189193
localFilePath := filepath.Join(destPath, relativePath)
190194

191195
// Create local directories if they don't exist
@@ -205,14 +209,26 @@ func (s3back *S3StorageBackend) DownloadModel(
205209
log.Printf("Error creating local file %s: %v\n", localFilePath, err)
206210
continue
207211
}
208-
defer file.Close()
209212

210213
numBytes, err := downloader.Download(ctx, file, &s3.GetObjectInput{
211214
Bucket: aws.String(bucketName),
212215
Key: aws.String(s3Key),
213216
})
217+
closeErr := file.Close()
214218
if err != nil {
215219
log.Printf("Error downloading object %s: %v\n", s3Key, err)
220+
if removeErr := os.Remove(localFilePath); removeErr != nil &&
221+
!errors.Is(removeErr, os.ErrNotExist) {
222+
log.Printf("Error removing partial file %s: %v\n", localFilePath, removeErr)
223+
}
224+
continue
225+
}
226+
if closeErr != nil {
227+
log.Printf("Error closing file %s: %v\n", localFilePath, closeErr)
228+
if removeErr := os.Remove(localFilePath); removeErr != nil &&
229+
!errors.Is(removeErr, os.ErrNotExist) {
230+
log.Printf("Error removing partial file %s: %v\n", localFilePath, removeErr)
231+
}
216232
continue
217233
}
218234
log.Printf("Downloaded %s to %s (%d bytes)\n", s3Key, localFilePath, numBytes)

pkg/modelprovider/mlflow/provider.go

Lines changed: 33 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,10 @@ package mlflow
1919
import (
2020
"context"
2121
"errors"
22-
"fmt"
22+
"log"
2323
"net/url"
2424
"os"
25+
"slices"
2526
"strings"
2627
)
2728

@@ -46,13 +47,16 @@ func (p *MlflowProvider) Name() string {
4647
func (p *MlflowProvider) SupportsURL(url string) bool {
4748
url = strings.TrimSpace(url)
4849
// TODO: the MLflow API is equivalent to the Databricks Model Registry API; support the latter later.
49-
possibleUrls := []string{"models"}
50+
possibleUrls := []string{"models:/"}
5051

5152
return hasAnyPrefix(url, possibleUrls)
5253
}
5354

5455
// DownloadModel downloads the model referenced by modelURL from the MLflow model registry.
55-
func (p *MlflowProvider) DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) {
56+
func (p *MlflowProvider) DownloadModel(
57+
ctx context.Context,
58+
modelURL, destDir string,
59+
) (string, error) {
5660
model, version, err := parseModelURL(modelURL)
5761
if err != nil {
5862
return "", err
@@ -84,40 +88,37 @@ func hasAnyPrefix(s string, subs []string) bool {
8488

8589
func checkMlflowAuth() error {
8690

87-
var err error
88-
89-
host := os.Getenv("DATABRICKS_HOST")
90-
usr := os.Getenv("DATABRICKS_USERNAME")
91-
pass := os.Getenv("DATABRICKS_PASSWORD")
92-
mlfhost := os.Getenv("MLFLOW_TRACKING_URI")
93-
mlfuser := os.Getenv("MLFLOW_TRACKING_USERNAME")
94-
mlfpass := os.Getenv("MLFLOW_TRACKING_PASSWORD")
91+
isAllNonEmpty := func(s []string) bool {
92+
for v := range slices.Values(s) {
93+
if v == "" {
94+
return false
95+
}
96+
}
97+
return true
98+
}
9599

96-
if host == "" && usr == "" && pass == "" {
97-
fmt.Println("Please set DATABRICKS_HOST environment variable.")
98-
fmt.Println("Please set DATABRICKS_USERNAME environment variable.")
99-
fmt.Println("Please set DATABRICKS_PASSWORD environment variable.")
100-
} else {
101-
return nil
100+
databricksEnvs := []string{
101+
os.Getenv("DATABRICKS_HOST"),
102+
os.Getenv("DATABRICKS_USERNAME"),
103+
os.Getenv("DATABRICKS_PASSWORD"),
104+
}
105+
mlflowEnvs := []string{
106+
os.Getenv("MLFLOW_TRACKING_URI"),
107+
os.Getenv("MLFLOW_TRACKING_USERNAME"),
108+
os.Getenv("MLFLOW_TRACKING_PASSWORD"),
102109
}
103-
if mlfhost != "" && mlfuser != "" && mlfpass != "" {
104-
err = os.Setenv("DATABRICKS_HOST", mlfhost)
105-
if err != nil {
106-
return err
107-
}
108-
err = os.Setenv("DATABRICKS_USERNAME", usr)
109-
if err != nil {
110-
return err
111-
}
112-
err = os.Setenv("DATABRICKS_PASSWORD", pass)
113-
if err != nil {
114-
return err
115-
}
116110

111+
if isAllNonEmpty(databricksEnvs) {
112+
return nil
113+
} else if isAllNonEmpty(mlflowEnvs) {
114+
log.Printf("Detected MlFlow environment variables, set DATABRICKS_* envs \n")
117115
} else {
118-
return errors.New("please set MLFLOW tracking environment variable.")
116+
log.Println("Please set DATABRICKS_HOST environment variable.")
117+
log.Println("Please set DATABRICKS_USERNAME environment variable.")
118+
log.Println("Please set DATABRICKS_PASSWORD environment variable.")
119119
}
120-
return err
120+
121+
return nil
121122
}
122123

123124
func parseModelURL(modelURL string) (string, string, error) {

pkg/modelprovider/mlflow/provider_test.go

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -40,9 +40,7 @@ func TestMlflowProvider_CheckAuth(t *testing.T) {
4040
p := &MlflowProvider{
4141
mflClient: tt.fields.mlfClient,
4242
}
43-
if err := p.CheckAuth(); (err != nil) != tt.wantErr {
44-
t.Errorf("CheckAuth() error = %v, wantErr %v", err, tt.wantErr)
45-
}
43+
assert.NoError(t, p.CheckAuth())
4644
})
4745
}
4846
}
@@ -184,7 +182,7 @@ func Test_checkMlflowAuth(t *testing.T) {
184182
},
185183
{
186184
name: "mlflow tracking set returns nil",
187-
wantErr: false,
185+
wantErr: true,
188186
},
189187
}
190188
for _, tt := range tests {
@@ -198,14 +196,17 @@ func Test_checkMlflowAuth(t *testing.T) {
198196
switch tt.name {
199197
case "databricks host set returns nil":
200198
t.Setenv("DATABRICKS_HOST", "https://example.com")
199+
t.Setenv("DATABRICKS_USERNAME", "user")
200+
t.Setenv("DATABRICKS_PASSWORD", "pass")
201+
201202
case "mlflow tracking set returns nil":
202203
t.Setenv("MLFLOW_TRACKING_URI", "https://mlflow.example.com")
203204
t.Setenv("MLFLOW_TRACKING_USERNAME", "mlf-user")
204205
t.Setenv("MLFLOW_TRACKING_PASSWORD", "mlf-pass")
205206
}
206207

207208
err := checkMlflowAuth()
208-
assert.Equal(t, tt.wantErr, err != nil)
209+
assert.NotEqual(t, tt.wantErr, err)
209210
})
210211
}
211212
}

0 commit comments

Comments
 (0)