@@ -643,37 +643,16 @@ def __init__(
643
643
):
644
644
self .model_name = model_name
645
645
self .session = session or global_session .get_global_session ()
646
- self ._bq_connection_manager = self .session .bqconnectionmanager
647
-
648
- connection_name = connection_name or self .session ._bq_connection
649
- self .connection_name = clients .resolve_full_bq_connection_name (
650
- connection_name ,
651
- default_project = self .session ._project ,
652
- default_location = self .session ._location ,
653
- )
646
+ self .connection_name = connection_name
654
647
655
648
self ._bqml_model_factory = globals .bqml_model_factory ()
656
649
self ._bqml_model : core .BqmlModel = self ._create_bqml_model ()
657
650
658
651
def _create_bqml_model (self ):
659
652
# Parse and create connection if needed.
660
- if not self .connection_name :
661
- raise ValueError (
662
- "Must provide connection_name, either in constructor or through session options."
663
- )
664
-
665
- if self ._bq_connection_manager :
666
- connection_name_parts = self .connection_name .split ("." )
667
- if len (connection_name_parts ) != 3 :
668
- raise ValueError (
669
- f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got { self .connection_name } ."
670
- )
671
- self ._bq_connection_manager .create_bq_connection (
672
- project_id = connection_name_parts [0 ],
673
- location = connection_name_parts [1 ],
674
- connection_id = connection_name_parts [2 ],
675
- iam_role = "aiplatform.user" ,
676
- )
653
+ self .connection_name = self .session ._create_bq_connection (
654
+ connection = self .connection_name , iam_role = "aiplatform.user"
655
+ )
677
656
678
657
if self .model_name not in _TEXT_EMBEDDING_ENDPOINTS :
679
658
msg = _MODEL_NOT_SUPPORTED_WARNING .format (
@@ -828,37 +807,16 @@ def __init__(
828
807
self .model_name = model_name
829
808
self .session = session or global_session .get_global_session ()
830
809
self .max_iterations = max_iterations
831
- self ._bq_connection_manager = self .session .bqconnectionmanager
832
-
833
- connection_name = connection_name or self .session ._bq_connection
834
- self .connection_name = clients .resolve_full_bq_connection_name (
835
- connection_name ,
836
- default_project = self .session ._project ,
837
- default_location = self .session ._location ,
838
- )
810
+ self .connection_name = connection_name
839
811
840
812
self ._bqml_model_factory = globals .bqml_model_factory ()
841
813
self ._bqml_model : core .BqmlModel = self ._create_bqml_model ()
842
814
843
815
def _create_bqml_model (self ):
844
816
# Parse and create connection if needed.
845
- if not self .connection_name :
846
- raise ValueError (
847
- "Must provide connection_name, either in constructor or through session options."
848
- )
849
-
850
- if self ._bq_connection_manager :
851
- connection_name_parts = self .connection_name .split ("." )
852
- if len (connection_name_parts ) != 3 :
853
- raise ValueError (
854
- f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got { self .connection_name } ."
855
- )
856
- self ._bq_connection_manager .create_bq_connection (
857
- project_id = connection_name_parts [0 ],
858
- location = connection_name_parts [1 ],
859
- connection_id = connection_name_parts [2 ],
860
- iam_role = "aiplatform.user" ,
861
- )
817
+ self .connection_name = self .session ._create_bq_connection (
818
+ connection = self .connection_name , iam_role = "aiplatform.user"
819
+ )
862
820
863
821
if self .model_name not in _GEMINI_ENDPOINTS :
864
822
msg = _MODEL_NOT_SUPPORTED_WARNING .format (
@@ -953,10 +911,7 @@ def fit(
953
911
options ["prompt_col" ] = X .columns .tolist ()[0 ]
954
912
955
913
self ._bqml_model = self ._bqml_model_factory .create_llm_remote_model (
956
- X ,
957
- y ,
958
- options = options ,
959
- connection_name = self .connection_name ,
914
+ X , y , options = options , connection_name = cast (str , self .connection_name )
960
915
)
961
916
return self
962
917
@@ -1179,37 +1134,16 @@ def __init__(
1179
1134
):
1180
1135
self .model_name = model_name
1181
1136
self .session = session or global_session .get_global_session ()
1182
- self ._bq_connection_manager = self .session .bqconnectionmanager
1183
-
1184
- connection_name = connection_name or self .session ._bq_connection
1185
- self .connection_name = clients .resolve_full_bq_connection_name (
1186
- connection_name ,
1187
- default_project = self .session ._project ,
1188
- default_location = self .session ._location ,
1189
- )
1137
+ self .connection_name = connection_name
1190
1138
1191
1139
self ._bqml_model_factory = globals .bqml_model_factory ()
1192
1140
self ._bqml_model : core .BqmlModel = self ._create_bqml_model ()
1193
1141
1194
1142
def _create_bqml_model (self ):
1195
1143
# Parse and create connection if needed.
1196
- if not self .connection_name :
1197
- raise ValueError (
1198
- "Must provide connection_name, either in constructor or through session options."
1199
- )
1200
-
1201
- if self ._bq_connection_manager :
1202
- connection_name_parts = self .connection_name .split ("." )
1203
- if len (connection_name_parts ) != 3 :
1204
- raise ValueError (
1205
- f"connection_name must be of the format <PROJECT_NUMBER/PROJECT_ID>.<LOCATION>.<CONNECTION_ID>, got { self .connection_name } ."
1206
- )
1207
- self ._bq_connection_manager .create_bq_connection (
1208
- project_id = connection_name_parts [0 ],
1209
- location = connection_name_parts [1 ],
1210
- connection_id = connection_name_parts [2 ],
1211
- iam_role = "aiplatform.user" ,
1212
- )
1144
+ self .connection_name = self .session ._create_bq_connection (
1145
+ connection = self .connection_name , iam_role = "aiplatform.user"
1146
+ )
1213
1147
1214
1148
if self .model_name not in _CLAUDE_3_ENDPOINTS :
1215
1149
msg = _MODEL_NOT_SUPPORTED_WARNING .format (
0 commit comments