@@ -39,7 +39,7 @@ func NewMlFlowRegistry(mlflowClient *client.DatabricksClient) (MlFlowClient, err
3939
4040 if mlflowClient != nil {
4141 registry = ml .NewModelRegistry (mlflowClient )
42- fmt .Println ("Use default mlflow client for MlFlowRegistryAPI" )
42+ log .Println ("Use default mlflow client for MlFlowRegistryAPI" )
4343 return MlFlowClient {registry : registry }, nil
4444 }
4545
@@ -92,8 +92,9 @@ func (mlfr *MlFlowClient) PullModelByName(
9292 return "" , errors .New (msg )
9393 }
9494
95- fmt .Printf ("Found versions: '%v' for model '%s'\n " , rawVersion , modelName )
95+ log .Printf ("Found versions: '%v' for model '%s'\n " , rawVersion , modelName )
9696 if modelVersion == "" {
97+ slices .Sort (rawVersion )
9798 modelVersion = rawVersion [0 ]
9899 }
99100
@@ -104,8 +105,11 @@ func (mlfr *MlFlowClient) PullModelByName(
104105 if err != nil {
105106 return "" , err
106107 }
107- fmt .Printf ("Try pull model from uri %s" , uri .ArtifactUri )
108+ log .Printf ("Try pull model from uri %s" , uri .ArtifactUri )
108109 parsed , err := url .Parse (uri .ArtifactUri )
110+ if err != nil {
111+ return "" , fmt .Errorf ("failed to parse artifact uri: %w" , err )
112+ }
109113 if parsed == nil {
110114 return "" , errors .New ("failed to parse artifact uri" )
111115 }
@@ -124,7 +128,7 @@ func (mlfr *MlFlowClient) PullModelByName(
124128 return "" , err
125129 }
126130
127- fmt .Printf ("✅ Model downloaded" )
131+ log .Printf ("✅ Model downloaded" )
128132
129133 return destSrc , nil
130134}
@@ -144,7 +148,7 @@ func (s3back *S3StorageBackend) DownloadModel(
144148 }
145149
146150 bucketName := parsed .Host
147- s3FolderPrefix := parsed .Path [ 1 :]
151+ s3FolderPrefix := strings . TrimPrefix ( parsed .Path , "/" )
148152 fmt .Printf ("Parsed s3 bucket %s, path %s from path" , bucketName , s3FolderPrefix )
149153
150154 cfg , err := awsconfig .LoadDefaultConfig (ctx )
@@ -161,11 +165,10 @@ func (s3back *S3StorageBackend) DownloadModel(
161165 downloader := manager .NewDownloader (s3Client , func (d * manager.Downloader ) {
162166 d .PartSize = partMiBs * 1024 * 1024
163167 })
164- // List objects with the specified prefix
168+ // List all objects under the prefix (including nested directories).
165169 paginator := s3 .NewListObjectsV2Paginator (s3Client , & s3.ListObjectsV2Input {
166- Bucket : aws .String (bucketName ),
167- Prefix : aws .String (s3FolderPrefix ),
168- Delimiter : aws .String ("/" ),
170+ Bucket : aws .String (bucketName ),
171+ Prefix : aws .String (s3FolderPrefix ),
169172 })
170173
171174 log .Printf ("Start downloading from s3 bucket %s, path %s" , bucketName , s3FolderPrefix )
@@ -180,12 +183,13 @@ func (s3back *S3StorageBackend) DownloadModel(
180183 for _ , object := range page .Contents {
181184 s3Key := * object .Key
182185 log .Printf ("Downloading object: %s\n " , s3Key )
183- // if strings.HasSuffix(s3Key, "/") { // Skip S3 "folder" markers
184- // continue
185- // }
186+ if strings .HasSuffix (s3Key , "/" ) { // Skip S3 "folder" markers
187+ continue
188+ }
186189
187190 // Construct local file path
188191 relativePath := strings .TrimPrefix (s3Key , s3FolderPrefix )
192+ relativePath = strings .TrimPrefix (relativePath , "/" )
189193 localFilePath := filepath .Join (destPath , relativePath )
190194
191195 // Create local directories if they don't exist
@@ -205,14 +209,26 @@ func (s3back *S3StorageBackend) DownloadModel(
205209 log .Printf ("Error creating local file %s: %v\n " , localFilePath , err )
206210 continue
207211 }
208- defer file .Close ()
209212
210213 numBytes , err := downloader .Download (ctx , file , & s3.GetObjectInput {
211214 Bucket : aws .String (bucketName ),
212215 Key : aws .String (s3Key ),
213216 })
217+ closeErr := file .Close ()
214218 if err != nil {
215219 log .Printf ("Error downloading object %s: %v\n " , s3Key , err )
220+ if removeErr := os .Remove (localFilePath ); removeErr != nil &&
221+ ! errors .Is (removeErr , os .ErrNotExist ) {
222+ log .Printf ("Error removing partial file %s: %v\n " , localFilePath , removeErr )
223+ }
224+ continue
225+ }
226+ if closeErr != nil {
227+ log .Printf ("Error closing file %s: %v\n " , localFilePath , closeErr )
228+ if removeErr := os .Remove (localFilePath ); removeErr != nil &&
229+ ! errors .Is (removeErr , os .ErrNotExist ) {
230+ log .Printf ("Error removing partial file %s: %v\n " , localFilePath , removeErr )
231+ }
216232 continue
217233 }
218234 log .Printf ("Downloaded %s to %s (%d bytes)\n " , s3Key , localFilePath , numBytes )
0 commit comments