@@ -178,125 +178,6 @@ def test_text_generator_predict_with_params_success(
178
178
)
179
179
180
180
181
- def test_create_embedding_generator_model (
182
- palm2_embedding_generator_model , dataset_id , bq_connection
183
- ):
184
- # Model creation doesn't return error
185
- assert palm2_embedding_generator_model is not None
186
- assert palm2_embedding_generator_model ._bqml_model is not None
187
-
188
- # save, load to ensure configuration was kept
189
- reloaded_model = palm2_embedding_generator_model .to_gbq (
190
- f"{ dataset_id } .temp_embedding_model" , replace = True
191
- )
192
- assert f"{ dataset_id } .temp_embedding_model" == reloaded_model ._bqml_model .model_name
193
- assert reloaded_model .model_name == "textembedding-gecko"
194
- assert reloaded_model .connection_name == bq_connection
195
-
196
-
197
- def test_create_embedding_generator_model_002 (
198
- palm2_embedding_generator_model_002 , dataset_id , bq_connection
199
- ):
200
- # Model creation doesn't return error
201
- assert palm2_embedding_generator_model_002 is not None
202
- assert palm2_embedding_generator_model_002 ._bqml_model is not None
203
-
204
- # save, load to ensure configuration was kept
205
- reloaded_model = palm2_embedding_generator_model_002 .to_gbq (
206
- f"{ dataset_id } .temp_embedding_model" , replace = True
207
- )
208
- assert f"{ dataset_id } .temp_embedding_model" == reloaded_model ._bqml_model .model_name
209
- assert reloaded_model .model_name == "textembedding-gecko"
210
- assert reloaded_model .version == "002"
211
- assert reloaded_model .connection_name == bq_connection
212
-
213
-
214
- def test_create_embedding_generator_multilingual_model (
215
- palm2_embedding_generator_multilingual_model ,
216
- dataset_id ,
217
- bq_connection ,
218
- ):
219
- # Model creation doesn't return error
220
- assert palm2_embedding_generator_multilingual_model is not None
221
- assert palm2_embedding_generator_multilingual_model ._bqml_model is not None
222
-
223
- # save, load to ensure configuration was kept
224
- reloaded_model = palm2_embedding_generator_multilingual_model .to_gbq (
225
- f"{ dataset_id } .temp_embedding_model" , replace = True
226
- )
227
- assert f"{ dataset_id } .temp_embedding_model" == reloaded_model ._bqml_model .model_name
228
- assert reloaded_model .model_name == "textembedding-gecko-multilingual"
229
- assert reloaded_model .connection_name == bq_connection
230
-
231
-
232
- def test_create_text_embedding_generator_model_defaults (bq_connection ):
233
- import bigframes .pandas as bpd
234
-
235
- # Note: This starts a thread-local session.
236
- with bpd .option_context (
237
- "bigquery.bq_connection" ,
238
- bq_connection ,
239
- "bigquery.location" ,
240
- "US" ,
241
- ):
242
- model = llm .PaLM2TextEmbeddingGenerator ()
243
- assert model is not None
244
- assert model ._bqml_model is not None
245
-
246
-
247
- def test_create_text_embedding_generator_multilingual_model_defaults (bq_connection ):
248
- import bigframes .pandas as bpd
249
-
250
- # Note: This starts a thread-local session.
251
- with bpd .option_context (
252
- "bigquery.bq_connection" ,
253
- bq_connection ,
254
- "bigquery.location" ,
255
- "US" ,
256
- ):
257
- model = llm .PaLM2TextEmbeddingGenerator (
258
- model_name = "textembedding-gecko-multilingual"
259
- )
260
- assert model is not None
261
- assert model ._bqml_model is not None
262
-
263
-
264
- @pytest .mark .flaky (retries = 2 )
265
- def test_embedding_generator_predict_success (
266
- palm2_embedding_generator_model , llm_text_df
267
- ):
268
- df = palm2_embedding_generator_model .predict (llm_text_df ).to_pandas ()
269
- assert df .shape == (3 , 4 )
270
- assert "text_embedding" in df .columns
271
- series = df ["text_embedding" ]
272
- value = series [0 ]
273
- assert len (value ) == 768
274
-
275
-
276
- @pytest .mark .flaky (retries = 2 )
277
- def test_embedding_generator_multilingual_predict_success (
278
- palm2_embedding_generator_multilingual_model , llm_text_df
279
- ):
280
- df = palm2_embedding_generator_multilingual_model .predict (llm_text_df ).to_pandas ()
281
- assert df .shape == (3 , 4 )
282
- assert "text_embedding" in df .columns
283
- series = df ["text_embedding" ]
284
- value = series [0 ]
285
- assert len (value ) == 768
286
-
287
-
288
- @pytest .mark .flaky (retries = 2 )
289
- def test_embedding_generator_predict_series_success (
290
- palm2_embedding_generator_model , llm_text_df
291
- ):
292
- df = palm2_embedding_generator_model .predict (llm_text_df ["prompt" ]).to_pandas ()
293
- assert df .shape == (3 , 4 )
294
- assert "text_embedding" in df .columns
295
- series = df ["text_embedding" ]
296
- value = series [0 ]
297
- assert len (value ) == 768
298
-
299
-
300
181
@pytest .mark .parametrize (
301
182
"model_name" ,
302
183
("text-embedding-004" , "text-multilingual-embedding-002" ),
0 commit comments