Skip to content

Commit 5448f06

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
fix: Require uri or staging bucket configuration for saving model to Vertex Experiment.
PiperOrigin-RevId: 853431474
1 parent 65717fa commit 5448f06

File tree

4 files changed

+25
-18
lines changed

4 files changed

+25
-18
lines changed

google/cloud/aiplatform/metadata/_models.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from typing import Any, Dict, Optional, Sequence, Union
2323

2424
from google.auth import credentials as auth_credentials
25-
from google.cloud import storage
2625
from google.cloud import aiplatform
2726
from google.cloud.aiplatform import base
2827
from google.cloud.aiplatform import explain
@@ -371,6 +370,7 @@ def save_model(
371370
project: Optional[str] = None,
372371
location: Optional[str] = None,
373372
credentials: Optional[auth_credentials.Credentials] = None,
373+
staging_bucket: Optional[str] = None,
374374
) -> google_artifact_schema.ExperimentModel:
375375
"""Saves a ML model into a MLMD artifact.
376376
@@ -418,12 +418,18 @@ def save_model(
418418
credentials (auth_credentials.Credentials):
419419
Optional. Custom credentials used to create this Artifact. Overrides
420420
credentials set in aiplatform.init.
421+
staging_bucket (str):
422+
Optional. The staging bucket used to save the model. If not provided,
423+
the staging bucket set in aiplatform.init will be used. A staging
424+
bucket or uri is required for saving a model.
421425
422426
Returns:
423427
An ExperimentModel instance.
424428
425429
Raises:
426430
ValueError: if model type is not supported.
431+
RuntimeError: If staging bucket was not set using aiplatform.init
432+
and a staging bucket or uri was not passed in.
427433
"""
428434
framework_name = framework_version = ""
429435
try:
@@ -476,24 +482,13 @@ def save_model(
476482
model_file = _FRAMEWORK_SPECS[framework_name]["model_file"]
477483

478484
if not uri:
479-
staging_bucket = initializer.global_config.staging_bucket
480-
# TODO(b/264196887)
485+
staging_bucket = staging_bucket or initializer.global_config.staging_bucket
486+
481487
if not staging_bucket:
482-
project = project or initializer.global_config.project
483-
location = location or initializer.global_config.location
484-
credentials = credentials or initializer.global_config.credentials
485-
486-
staging_bucket_name = project + "-vertex-staging-" + location
487-
client = storage.Client(project=project, credentials=credentials)
488-
staging_bucket = storage.Bucket(client=client, name=staging_bucket_name)
489-
if not staging_bucket.exists():
490-
_LOGGER.info(f'Creating staging bucket "{staging_bucket_name}"')
491-
staging_bucket = client.create_bucket(
492-
bucket_or_name=staging_bucket,
493-
project=project,
494-
location=location,
495-
)
496-
staging_bucket = f"gs://{staging_bucket_name}"
488+
raise RuntimeError(
489+
"staging_bucket should be passed to save_model constructor or "
490+
"should be set using aiplatform.init(staging_bucket='gs://my-bucket')"
491+
)
497492

498493
unique_name = utils.timestamped_unique_name()
499494
uri = f"{staging_bucket}/{unique_name}-{framework_name}-model"

google/cloud/aiplatform/metadata/experiment_run_resource.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,7 @@ def log_model(
11961196
project: Optional[str] = None,
11971197
location: Optional[str] = None,
11981198
credentials: Optional[auth_credentials.Credentials] = None,
1199+
staging_bucket: Optional[str] = None,
11991200
) -> google_artifact_schema.ExperimentModel:
12001201
"""Saves a ML model into a MLMD artifact and log it to this ExperimentRun.
12011202
@@ -1245,12 +1246,18 @@ def log_model(
12451246
credentials (auth_credentials.Credentials):
12461247
Optional. Custom credentials used to create this Artifact. Overrides
12471248
credentials set in aiplatform.init.
1249+
staging_bucket (str):
1250+
Optional. The staging bucket used to save the model. If not provided,
1251+
the staging bucket set in aiplatform.init will be used. A staging
1252+
bucket or uri is required for saving a model.
12481253
12491254
Returns:
12501255
An ExperimentModel instance.
12511256
12521257
Raises:
12531258
ValueError: if model type is not supported.
1259+
RuntimeError: If staging bucket was not set using aiplatform.init
1260+
and a staging bucket or uri was not passed in.
12541261
"""
12551262
experiment_model = _models.save_model(
12561263
model=model,
@@ -1262,6 +1269,7 @@ def log_model(
12621269
project=project,
12631270
location=location,
12641271
credentials=credentials,
1272+
staging_bucket=staging_bucket,
12651273
)
12661274

12671275
self._metadata_node.add_artifacts_and_executions(

samples/model-builder/experiment_tracking/save_model_sample.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def save_model_sample(
3030
Union[list, dict, "pd.DataFrame", "np.ndarray"] # noqa: F821
3131
] = None,
3232
display_name: Optional[str] = None,
33+
staging_bucket: Optional[str] = None,
3334
) -> None:
3435
aiplatform.init(project=project, location=location)
3536

@@ -39,6 +40,7 @@ def save_model_sample(
3940
uri=uri,
4041
input_example=input_example,
4142
display_name=display_name,
43+
staging_bucket=staging_bucket,
4244
)
4345

4446

samples/model-builder/experiment_tracking/save_model_sample_test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def test_save_model_sample(mock_save_model):
2929
uri=constants.MODEL_ARTIFACT_URI,
3030
input_example=constants.EXPERIMENT_MODEL_INPUT_EXAMPLE,
3131
display_name=constants.DISPLAY_NAME,
32+
staging_bucket=constants.STAGING_BUCKET,
3233
)
3334

3435
mock_save_model.assert_called_once_with(
@@ -37,4 +38,5 @@ def test_save_model_sample(mock_save_model):
3738
uri=constants.MODEL_ARTIFACT_URI,
3839
input_example=constants.EXPERIMENT_MODEL_INPUT_EXAMPLE,
3940
display_name=constants.DISPLAY_NAME,
41+
staging_bucket=constants.STAGING_BUCKET,
4042
)

0 commit comments

Comments
 (0)