Skip to content

Commit f439f83

Browse files
committed
feat: stable code with tests
Signed-off-by: Arsen Gumin <[email protected]>
1 parent 69246d0 commit f439f83

File tree

5 files changed

+499
-40
lines changed

5 files changed

+499
-40
lines changed

pkg/modelprovider/mlflow/downloader.go

Lines changed: 39 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -43,32 +43,38 @@ func NewMlFlowRegistry(mlflowClient *client.DatabricksClient) (MlFlowClient, err
4343
return MlFlowClient{registry: registry}, nil
4444
}
4545

46-
//TODO Support more auth methods?
46+
// TODO Support more auth methods?
4747
cfg := config.Config{
48-
//Credentials: config.BasicCredentials{},
48+
// Credentials: config.BasicCredentials{},
4949
}
5050
mlClient, err := client.New(&cfg)
51-
5251
if err != nil {
5352
return MlFlowClient{}, err
5453
}
5554
registry = ml.NewModelRegistry(mlClient)
5655
return MlFlowClient{registry: registry}, nil
5756
}
5857

59-
func (mlfr *MlFlowClient) PullModelByName(ctx context.Context, modelName string, modelVersion string, destSrc string) error {
60-
58+
func (mlfr *MlFlowClient) PullModelByName(
59+
ctx context.Context,
60+
modelName string,
61+
modelVersion string,
62+
destSrc string,
63+
) (string, error) {
6164
if mlfr == nil || mlfr.registry == nil {
62-
return errors.New("Mlflow client is not initialized: registry is nil")
65+
return "", errors.New("mlflow client is not initialized: registry is nil")
6366
}
6467

65-
versions, err := mlfr.registry.GetLatestVersionsAll(ctx, ml.GetLatestVersionsRequest{Name: modelName})
68+
versions, err := mlfr.registry.GetLatestVersionsAll(
69+
ctx,
70+
ml.GetLatestVersionsRequest{Name: modelName},
71+
)
6672
if err != nil {
67-
return errors.Join(errors.New(fmt.Sprintf("failed to get versions for model: %s", modelName)), err)
73+
return "", errors.Join(fmt.Errorf("failed to get versions for model: %s", modelName), err)
6874
}
6975

7076
if len(versions) == 0 {
71-
return errors.New(fmt.Sprintf("model %s has versions: %v", modelName, versions))
77+
return "", fmt.Errorf("model %s has versions: %v", modelName, versions)
7278
}
7379

7480
var rawVersion []string
@@ -78,9 +84,12 @@ func (mlfr *MlFlowClient) PullModelByName(ctx context.Context, modelName string,
7884
contains := slices.Contains(rawVersion, modelVersion)
7985
if !contains {
8086
msg := fmt.Sprintf(
81-
"model %s version %s not found, available version %v", modelName, modelVersion, rawVersion,
87+
"model %s version %s not found, available version %v",
88+
modelName,
89+
modelVersion,
90+
rawVersion,
8291
)
83-
return errors.New(msg)
92+
return "", errors.New(msg)
8493
}
8594

8695
fmt.Printf("Found versions: '%v' for model '%s'\n", rawVersion, modelName)
@@ -92,40 +101,43 @@ func (mlfr *MlFlowClient) PullModelByName(ctx context.Context, modelName string,
92101
Name: modelName,
93102
Version: modelVersion,
94103
})
95-
96104
if err != nil {
97-
return err
105+
return "", err
98106
}
99107
fmt.Printf("Try pull model from uri %s", uri.ArtifactUri)
100108
parsed, err := url.Parse(uri.ArtifactUri)
101109
if parsed == nil {
102-
return errors.New("failed to parse artifact uri")
110+
return "", errors.New("failed to parse artifact uri")
103111
}
104112

105113
switch parsed.Scheme {
106114
case "s3":
107115
s3storage := storageProvider[parsed.Scheme]
116+
destSrc = filepath.Join(destSrc, modelName)
108117
err = s3storage.DownloadModel(ctx, uri.ArtifactUri+"/", destSrc) // it's dir
109118
if err != nil {
110-
return err
119+
return "", err
111120
}
112121
default:
113122
msg := fmt.Sprintf("Unsupported artifact storage type: %s", parsed.Scheme)
114123
err = errors.New(msg)
115-
return err
124+
return "", err
116125
}
117126

118127
fmt.Printf("✅ Model downloaded")
119128

120-
return nil
129+
return destSrc, nil
121130
}
122131

123132
// S3StorageBackend downloads model artifacts that are stored in an S3 bucket.
type S3StorageBackend struct {
	// addressing is not read by any code visible in this file.
	// NOTE(review): presumably selects the S3 addressing style — confirm.
	addressing string
}
126135

127-
func (s3back *S3StorageBackend) DownloadModel(ctx context.Context, path string, destPath string) error {
128-
136+
func (s3back *S3StorageBackend) DownloadModel(
137+
ctx context.Context,
138+
path string,
139+
destPath string,
140+
) error {
129141
parsed, err := url.Parse(path)
130142
if err != nil {
131143
return err
@@ -134,12 +146,10 @@ func (s3back *S3StorageBackend) DownloadModel(ctx context.Context, path string,
134146
bucketName := parsed.Host
135147
s3FolderPrefix := parsed.Path[1:]
136148
fmt.Printf("Parsed s3 bucket %s, path %s from path", bucketName, s3FolderPrefix)
137-
if destPath == "" {
138-
destPath = "./downloads/"
139-
}
149+
140150
cfg, err := awsconfig.LoadDefaultConfig(ctx)
141151
if err != nil {
142-
wrap := errors.New(fmt.Sprintf("Error loading AWS config, try change envs or profile: %v\n", err))
152+
wrap := fmt.Errorf("Error loading AWS config, try change envs or profile: %v\n", err)
143153
return errors.Join(wrap, err)
144154
}
145155

@@ -179,9 +189,13 @@ func (s3back *S3StorageBackend) DownloadModel(ctx context.Context, path string,
179189
localFilePath := filepath.Join(destPath, relativePath)
180190

181191
// Create local directories if they don't exist
182-
err = os.MkdirAll(filepath.Dir(localFilePath), 0755)
192+
err = os.MkdirAll(filepath.Dir(localFilePath), 0o755)
183193
if err != nil {
184-
log.Printf("Error creating local directory %s: %v\n", filepath.Dir(localFilePath), err)
194+
log.Printf(
195+
"Error creating local directory %s: %v\n",
196+
filepath.Dir(localFilePath),
197+
err,
198+
)
185199
continue
186200
}
187201

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
package mlflow
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/databricks/databricks-sdk-go/client"
8+
"github.com/databricks/databricks-sdk-go/service/ml"
9+
)
10+
11+
func TestMlFlowClient_PullModelByName(t *testing.T) {
12+
type fields struct {
13+
registry *ml.ModelRegistryAPI
14+
}
15+
type args struct {
16+
ctx context.Context
17+
modelName string
18+
modelVersion string
19+
destSrc string
20+
}
21+
tests := []struct {
22+
name string
23+
fields fields
24+
args args
25+
want string
26+
wantErr bool
27+
}{
28+
{
29+
name: "nil receiver returns error",
30+
fields: fields{registry: nil},
31+
args: args{ctx: context.Background(), modelName: "model", modelVersion: "1", destSrc: "/tmp"},
32+
want: "",
33+
wantErr: true,
34+
},
35+
}
36+
for _, tt := range tests {
37+
t.Run(tt.name, func(t *testing.T) {
38+
mlfr := &MlFlowClient{
39+
registry: tt.fields.registry,
40+
}
41+
got, err := mlfr.PullModelByName(tt.args.ctx, tt.args.modelName, tt.args.modelVersion, tt.args.destSrc)
42+
if (err != nil) != tt.wantErr {
43+
t.Errorf("PullModelByName() error = %v, wantErr %v", err, tt.wantErr)
44+
return
45+
}
46+
if got != tt.want {
47+
t.Errorf("PullModelByName() got = %v, want %v", got, tt.want)
48+
}
49+
})
50+
}
51+
}
52+
53+
func TestNewMlFlowRegistry(t *testing.T) {
54+
type args struct {
55+
mlflowClient *client.DatabricksClient
56+
}
57+
tests := []struct {
58+
name string
59+
args args
60+
want MlFlowClient
61+
wantErr bool
62+
}{
63+
{
64+
name: "non-nil client returns registry",
65+
args: args{mlflowClient: &client.DatabricksClient{}},
66+
want: MlFlowClient{},
67+
wantErr: false,
68+
},
69+
}
70+
for _, tt := range tests {
71+
t.Run(tt.name, func(t *testing.T) {
72+
got, err := NewMlFlowRegistry(tt.args.mlflowClient)
73+
if (err != nil) != tt.wantErr {
74+
t.Errorf("NewMlFlowRegistry() error = %v, wantErr %v", err, tt.wantErr)
75+
return
76+
}
77+
if !tt.wantErr && got.registry == nil {
78+
t.Errorf("NewMlFlowRegistry() registry is nil")
79+
}
80+
})
81+
}
82+
}
83+
84+
func TestS3StorageBackend_DownloadModel(t *testing.T) {
85+
type fields struct {
86+
addressing string
87+
}
88+
type args struct {
89+
ctx context.Context
90+
path string
91+
destPath string
92+
}
93+
tests := []struct {
94+
name string
95+
fields fields
96+
args args
97+
wantErr bool
98+
}{
99+
{
100+
name: "invalid url returns error",
101+
fields: fields{addressing: ""},
102+
args: args{ctx: context.Background(), path: "http://[::1", destPath: "/tmp"},
103+
wantErr: true,
104+
},
105+
}
106+
for _, tt := range tests {
107+
t.Run(tt.name, func(t *testing.T) {
108+
s3back := &S3StorageBackend{
109+
addressing: tt.fields.addressing,
110+
}
111+
if err := s3back.DownloadModel(tt.args.ctx, tt.args.path, tt.args.destPath); (err != nil) != tt.wantErr {
112+
t.Errorf("DownloadModel() error = %v, wantErr %v", err, tt.wantErr)
113+
}
114+
})
115+
}
116+
}

pkg/modelprovider/mlflow/provider.go

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,14 @@ import (
2020
"context"
2121
"errors"
2222
"fmt"
23+
"net/url"
2324
"os"
24-
"os/exec"
25-
"path/filepath"
2625
"strings"
2726
)
2827

2928
// MlflowProvider implements the modelprovider.Provider interface for Mlflow
3029
type MlflowProvider struct {
31-
mlfClient MlFlowClient
30+
mflClient MlFlowClient
3231
}
3332

3433
// New creates a new Mlflow provider instance
@@ -47,16 +46,25 @@ func (p *MlflowProvider) Name() string {
4746
// SupportsURL reports whether this provider accepts url. Surrounding
// whitespace is trimmed first, then the decision is delegated to
// hasAnyPrefix with the single candidate "models".
//
// NOTE(review): the candidate has no trailing colon, so if hasAnyPrefix does a
// plain prefix match, any string beginning with "models" is accepted —
// confirm whether "models:" was intended.
func (p *MlflowProvider) SupportsURL(url string) bool {
	// TODO Mlflow API equals with Databricks Model Registry, support later
	possibleUrls := []string{"models"}

	return hasAnyPrefix(strings.TrimSpace(url), possibleUrls)
}
5453

5554
// DownloadModel downloads a model from ModelScope using the modelscope CLI
5655
func (p *MlflowProvider) DownloadModel(ctx context.Context, modelURL, destDir string) (string, error) {
5756
model, version, err := parseModelURL(modelURL)
58-
xclient := NewMlFlowRegistry(nil)
59-
xclient.PullModelByName(modelURL)
57+
if err != nil {
58+
return "", err
59+
}
60+
registryClient, err := NewMlFlowRegistry(nil)
61+
if err != nil {
62+
return "", err
63+
}
64+
downloadPath, err := registryClient.PullModelByName(ctx, model, version, destDir)
65+
if err != nil {
66+
return "", err
67+
}
6068
return downloadPath, nil
6169
}
6270

@@ -75,6 +83,7 @@ func hasAnyPrefix(s string, subs []string) bool {
7583
}
7684

7785
func checkMlflowAuth() error {
86+
7887
var err error
7988

8089
host := os.Getenv("DATABRICKS_HOST")
@@ -88,7 +97,10 @@ func checkMlflowAuth() error {
8897
fmt.Println("Please set DATABRICKS_HOST environment variable.")
8998
fmt.Println("Please set DATABRICKS_USERNAME environment variable.")
9099
fmt.Println("Please set DATABRICKS_PASSWORD environment variable.")
91-
} else if mlfhost != "" && mlfuser != "" && mlfpass != "" {
100+
} else {
101+
return nil
102+
}
103+
if mlfhost != "" && mlfuser != "" && mlfpass != "" {
92104
err = os.Setenv("DATABRICKS_HOST", mlfhost)
93105
if err != nil {
94106
return err
@@ -108,12 +120,33 @@ func checkMlflowAuth() error {
108120
return err
109121
}
110122

111-
// parseModelURL splits a model reference into a model name and a version.
//
// Accepted forms:
//   - "models://<name>/<version>": the name is the URL host and the version
//     is the path with leading slashes removed.
//   - "models:/<name>/<version>": no authority part; the first path segment
//     is the name and the remainder is the version.
//   - "<name>/<version>": exactly one slash separates name and version.
//   - "<name>": a bare name; the returned version is empty.
//
// It returns an error when modelURL is empty, when a "models:" reference
// cannot be parsed as a URL, when the slash form does not consist of exactly
// two segments, or when the resulting model name is empty.
func parseModelURL(modelURL string) (string, string, error) {
	if modelURL == "" {
		return "", "", errors.New("model url value is missing")
	}

	var name, version string
	switch {
	case strings.HasPrefix(modelURL, "models:"):
		parsed, err := url.Parse(modelURL)
		if err != nil {
			return "", "", fmt.Errorf("parsing model url %q: %w", modelURL, err)
		}
		name = parsed.Hostname()
		version = strings.TrimLeft(parsed.Path, "/")
		if name == "" {
			// Without "//" the URL has no host, so the name is the first
			// path segment and the version is whatever follows it.
			name, version, _ = strings.Cut(version, "/")
		}
	case strings.Contains(modelURL, "/"):
		parts := strings.Split(modelURL, "/")
		if len(parts) != 2 {
			return "", "", errors.New("model url is invalid, valid mask name/version")
		}
		name, version = parts[0], parts[1]
	default:
		name = modelURL
	}

	// An empty name can never identify a model, so reject it here instead of
	// handing it on to the caller.
	if name == "" {
		return "", "", fmt.Errorf("model url %q has no model name", modelURL)
	}
	return name, version, nil
}

0 commit comments

Comments
 (0)