Skip to content

Commit c5eac01

Browse files
authored
chore: fix experimental blob to create connections if not exist (#1334)
1 parent d8ab772 commit c5eac01

File tree

4 files changed

+44
-114
lines changed

4 files changed

+44
-114
lines changed

bigframes/ml/llm.py

Lines changed: 13 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -643,37 +643,16 @@ def __init__(
643643
):
644644
self.model_name = model_name
645645
self.session = session or global_session.get_global_session()
646-
self._bq_connection_manager = self.session.bqconnectionmanager
647-
648-
connection_name = connection_name or self.session._bq_connection
649-
self.connection_name = clients.resolve_full_bq_connection_name(
650-
connection_name,
651-
default_project=self.session._project,
652-
default_location=self.session._location,
653-
)
646+
self.connection_name = connection_name
654647

655648
self._bqml_model_factory = globals.bqml_model_factory()
656649
self._bqml_model: core.BqmlModel = self._create_bqml_model()
657650

658651
def _create_bqml_model(self):
659652
# Parse and create connection if needed.
660-
if not self.connection_name:
661-
raise ValueError(
662-
"Must provide connection_name, either in constructor or through session options."
663-
)
664-
665-
if self._bq_connection_manager:
666-
connection_name_parts = self.connection_name.split(".")
667-
if len(connection_name_parts) != 3:
668-
raise ValueError(
669-
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
670-
)
671-
self._bq_connection_manager.create_bq_connection(
672-
project_id=connection_name_parts[0],
673-
location=connection_name_parts[1],
674-
connection_id=connection_name_parts[2],
675-
iam_role="aiplatform.user",
676-
)
653+
self.connection_name = self.session._create_bq_connection(
654+
connection=self.connection_name, iam_role="aiplatform.user"
655+
)
677656

678657
if self.model_name not in _TEXT_EMBEDDING_ENDPOINTS:
679658
msg = _MODEL_NOT_SUPPORTED_WARNING.format(
@@ -828,37 +807,16 @@ def __init__(
828807
self.model_name = model_name
829808
self.session = session or global_session.get_global_session()
830809
self.max_iterations = max_iterations
831-
self._bq_connection_manager = self.session.bqconnectionmanager
832-
833-
connection_name = connection_name or self.session._bq_connection
834-
self.connection_name = clients.resolve_full_bq_connection_name(
835-
connection_name,
836-
default_project=self.session._project,
837-
default_location=self.session._location,
838-
)
810+
self.connection_name = connection_name
839811

840812
self._bqml_model_factory = globals.bqml_model_factory()
841813
self._bqml_model: core.BqmlModel = self._create_bqml_model()
842814

843815
def _create_bqml_model(self):
844816
# Parse and create connection if needed.
845-
if not self.connection_name:
846-
raise ValueError(
847-
"Must provide connection_name, either in constructor or through session options."
848-
)
849-
850-
if self._bq_connection_manager:
851-
connection_name_parts = self.connection_name.split(".")
852-
if len(connection_name_parts) != 3:
853-
raise ValueError(
854-
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
855-
)
856-
self._bq_connection_manager.create_bq_connection(
857-
project_id=connection_name_parts[0],
858-
location=connection_name_parts[1],
859-
connection_id=connection_name_parts[2],
860-
iam_role="aiplatform.user",
861-
)
817+
self.connection_name = self.session._create_bq_connection(
818+
connection=self.connection_name, iam_role="aiplatform.user"
819+
)
862820

863821
if self.model_name not in _GEMINI_ENDPOINTS:
864822
msg = _MODEL_NOT_SUPPORTED_WARNING.format(
@@ -953,10 +911,7 @@ def fit(
953911
options["prompt_col"] = X.columns.tolist()[0]
954912

955913
self._bqml_model = self._bqml_model_factory.create_llm_remote_model(
956-
X,
957-
y,
958-
options=options,
959-
connection_name=self.connection_name,
914+
X, y, options=options, connection_name=cast(str, self.connection_name)
960915
)
961916
return self
962917

@@ -1179,37 +1134,16 @@ def __init__(
11791134
):
11801135
self.model_name = model_name
11811136
self.session = session or global_session.get_global_session()
1182-
self._bq_connection_manager = self.session.bqconnectionmanager
1183-
1184-
connection_name = connection_name or self.session._bq_connection
1185-
self.connection_name = clients.resolve_full_bq_connection_name(
1186-
connection_name,
1187-
default_project=self.session._project,
1188-
default_location=self.session._location,
1189-
)
1137+
self.connection_name = connection_name
11901138

11911139
self._bqml_model_factory = globals.bqml_model_factory()
11921140
self._bqml_model: core.BqmlModel = self._create_bqml_model()
11931141

11941142
def _create_bqml_model(self):
11951143
# Parse and create connection if needed.
1196-
if not self.connection_name:
1197-
raise ValueError(
1198-
"Must provide connection_name, either in constructor or through session options."
1199-
)
1200-
1201-
if self._bq_connection_manager:
1202-
connection_name_parts = self.connection_name.split(".")
1203-
if len(connection_name_parts) != 3:
1204-
raise ValueError(
1205-
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
1206-
)
1207-
self._bq_connection_manager.create_bq_connection(
1208-
project_id=connection_name_parts[0],
1209-
location=connection_name_parts[1],
1210-
connection_id=connection_name_parts[2],
1211-
iam_role="aiplatform.user",
1212-
)
1144+
self.connection_name = self.session._create_bq_connection(
1145+
connection=self.connection_name, iam_role="aiplatform.user"
1146+
)
12131147

12141148
if self.model_name not in _CLAUDE_3_ENDPOINTS:
12151149
msg = _MODEL_NOT_SUPPORTED_WARNING.format(

bigframes/ml/remote.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
from typing import Mapping, Optional
2020
import warnings
2121

22-
from bigframes import clients
2322
from bigframes.core import global_session, log_adapter
2423
import bigframes.dataframe
2524
from bigframes.ml import base, core, globals, utils
@@ -63,35 +62,16 @@ def __init__(
6362
self.session = session or global_session.get_global_session()
6463

6564
self._bq_connection_manager = self.session.bqconnectionmanager
66-
connection_name = connection_name or self.session._bq_connection
67-
self.connection_name = clients.resolve_full_bq_connection_name(
68-
connection_name,
69-
default_project=self.session._project,
70-
default_location=self.session._location,
71-
)
65+
self.connection_name = connection_name
7266

7367
self._bqml_model_factory = globals.bqml_model_factory()
7468
self._bqml_model: core.BqmlModel = self._create_bqml_model()
7569

7670
def _create_bqml_model(self):
7771
# Parse and create connection if needed.
78-
if not self.connection_name:
79-
raise ValueError(
80-
"Must provide connection_name, either in constructor or through session options."
81-
)
82-
83-
if self._bq_connection_manager:
84-
connection_name_parts = self.connection_name.split(".")
85-
if len(connection_name_parts) != 3:
86-
raise ValueError(
87-
f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got {self.connection_name}."
88-
)
89-
self._bq_connection_manager.create_bq_connection(
90-
project_id=connection_name_parts[0],
91-
location=connection_name_parts[1],
92-
connection_id=connection_name_parts[2],
93-
iam_role="aiplatform.user",
94-
)
72+
self.connection_name = self.session._create_bq_connection(
73+
connection=self.connection_name, iam_role="aiplatform.user"
74+
)
9575

9676
options = {
9777
"endpoint": self.endpoint,

bigframes/operations/strings.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import bigframes_vendored.constants as constants
2121
import bigframes_vendored.pandas.core.strings.accessor as vendorstr
2222

23-
from bigframes import clients
2423
from bigframes.core import log_adapter
2524
import bigframes.dataframe as df
2625
import bigframes.operations as ops
@@ -306,11 +305,8 @@ def to_blob(self, connection: Optional[str] = None) -> series.Series:
306305
raise NotImplementedError()
307306

308307
session = self._block.session
309-
connection = connection or session._bq_connection
310-
connection = clients.resolve_full_bq_connection_name(
311-
connection,
312-
default_project=session._project,
313-
default_location=session._location,
308+
connection = session._create_bq_connection(
309+
connection=connection, iam_role="storage.objectUser"
314310
)
315311
return self._apply_binary_op(connection, ops.obj_make_ref_op)
316312

bigframes/session/__init__.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,18 +1647,38 @@ def from_glob_path(
16471647
raise NotImplementedError()
16481648

16491649
# TODO(garrettwu): switch to pseudocolumn when b/374988109 is done.
1650-
connection = connection or self._bq_connection
1651-
connection = bigframes.clients.resolve_full_bq_connection_name(
1652-
connection,
1653-
default_project=self._project,
1654-
default_location=self._location,
1650+
connection = self._create_bq_connection(
1651+
connection=connection, iam_role="storage.objectUser"
16551652
)
16561653

16571654
table = self._create_object_table(path, connection)
16581655

16591656
s = self.read_gbq(table)["uri"].str.to_blob(connection)
16601657
return s.rename(name).to_frame()
16611658

1659+
def _create_bq_connection(
1660+
self, iam_role: str, *, connection: Optional[str] = None
1661+
) -> str:
1662+
"""Create the connection with the session settings and try to attach iam role to the connection SA.
1663+
If any of project, location or connection isn't specified, use the session defaults. Returns fully-qualified connection name."""
1664+
connection = self._bq_connection if not connection else connection
1665+
connection = bigframes.clients.resolve_full_bq_connection_name(
1666+
connection_name=connection,
1667+
default_project=self._project,
1668+
default_location=self._location,
1669+
)
1670+
connection_parts = connection.split(".")
1671+
assert len(connection_parts) == 3
1672+
1673+
self.bqconnectionmanager.create_bq_connection(
1674+
project_id=connection_parts[0],
1675+
location=connection_parts[1],
1676+
connection_id=connection_parts[2],
1677+
iam_role=iam_role,
1678+
)
1679+
1680+
return connection
1681+
16621682
def read_gbq_object_table(
16631683
self, object_table: str, *, name: Optional[str] = None
16641684
) -> dataframe.DataFrame:

0 commit comments

Comments
 (0)