
Commit 89ab33e

docs: deprecate default model in TextEmbeddingGenerator, GeminiTextGenerator, and other bigframes.ml.llm classes (#1570)
* docs: add remove-default-model warning
* remove unnecessary warnings
* add a new warning
* update warning methods
* change the default model to None; when None is provided, fall back to the previous default model and raise a warning
* add a test case for the warning message
1 parent 1ba72ea commit 89ab33e
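In effect, constructing any of these bigframes.ml.llm estimators without an explicit model_name keeps working for now, but falls back to the previous default and emits a FutureWarning; naming the model explicitly avoids the warning. A minimal caller-side sketch of the new behavior (assuming a configured BigQuery session and connection; the model name shown is the current default from this change):

import warnings

from bigframes.ml import llm

# Implicit default: still works, but now emits a FutureWarning because the
# default model will be removed in BigFrames 3.0.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    llm.GeminiTextGenerator()
assert any(issubclass(w.category, FutureWarning) for w in caught)

# Explicit model name: no deprecation warning is emitted.
llm.GeminiTextGenerator(model_name="gemini-2.0-flash-001")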

2 files changed (+71 lines, -24 lines)


bigframes/ml/llm.py

Lines changed: 55 additions & 24 deletions
@@ -103,6 +103,8 @@
     "You should use this model name only if you are sure that it is supported in BigQuery."
 )
 
+_REMOVE_DEFAULT_MODEL_WARNING = "Since upgrading the default model can cause unintended breakages, the default model will be removed in BigFrames 3.0. Please supply an explicit model to avoid this message."
+
 
 @log_adapter.class_logger
 class TextEmbeddingGenerator(base.RetriableRemotePredictor):
@@ -113,7 +115,8 @@ class TextEmbeddingGenerator(base.RetriableRemotePredictor):
             The model for text embedding. Possible values are "text-embedding-005", "text-embedding-004"
             or "text-multilingual-embedding-002". text-embedding models returns model embeddings for text inputs.
             text-multilingual-embedding models returns model embeddings for text inputs which support over 100 languages.
-            Default to "text-embedding-004".
+            If no setting is provided, "text-embedding-004" will be used by
+            default and a warning will be issued.
         session (bigframes.Session or None):
             BQ session to create the model. If None, use the global default session.
         connection_name (str or None):
@@ -124,14 +127,20 @@ class TextEmbeddingGenerator(base.RetriableRemotePredictor):
     def __init__(
         self,
         *,
-        model_name: Literal[
-            "text-embedding-005",
-            "text-embedding-004",
-            "text-multilingual-embedding-002",
-        ] = "text-embedding-004",
+        model_name: Optional[
+            Literal[
+                "text-embedding-005",
+                "text-embedding-004",
+                "text-multilingual-embedding-002",
+            ]
+        ] = None,
         session: Optional[bigframes.Session] = None,
         connection_name: Optional[str] = None,
     ):
+        if model_name is None:
+            model_name = "text-embedding-004"
+            msg = exceptions.format_message(_REMOVE_DEFAULT_MODEL_WARNING)
+            warnings.warn(msg, category=FutureWarning, stacklevel=2)
         self.model_name = model_name
         self.session = session or global_session.get_global_session()
         self.connection_name = connection_name
@@ -256,7 +265,8 @@ class MultimodalEmbeddingGenerator(base.RetriableRemotePredictor):
     Args:
         model_name (str, Default to "multimodalembedding@001"):
             The model for multimodal embedding. Can set to "multimodalembedding@001". Multimodal-embedding models returns model embeddings for text, image and video inputs.
-            Default to "multimodalembedding@001".
+            If no setting is provided, "multimodalembedding@001" will be used by
+            default and a warning will be issued.
         session (bigframes.Session or None):
             BQ session to create the model. If None, use the global default session.
         connection_name (str or None):
@@ -267,12 +277,16 @@ class MultimodalEmbeddingGenerator(base.RetriableRemotePredictor):
     def __init__(
         self,
         *,
-        model_name: Literal["multimodalembedding@001"] = "multimodalembedding@001",
+        model_name: Optional[Literal["multimodalembedding@001"]] = None,
         session: Optional[bigframes.Session] = None,
         connection_name: Optional[str] = None,
     ):
         if not bigframes.options.experiments.blob:
             raise NotImplementedError()
+        if model_name is None:
+            model_name = "multimodalembedding@001"
+            msg = exceptions.format_message(_REMOVE_DEFAULT_MODEL_WARNING)
+            warnings.warn(msg, category=FutureWarning, stacklevel=2)
         self.model_name = model_name
         self.session = session or global_session.get_global_session()
         self.connection_name = connection_name
@@ -408,7 +422,8 @@ class GeminiTextGenerator(base.RetriableRemotePredictor):
             "gemini-1.5-pro-001", "gemini-1.5-pro-002", "gemini-1.5-flash-001",
             "gemini-1.5-flash-002", "gemini-2.0-flash-exp",
             "gemini-2.0-flash-lite-001", and "gemini-2.0-flash-001".
-            Default to "gemini-2.0-flash-001".
+            If no setting is provided, "gemini-2.0-flash-001" will be used by
+            default and a warning will be issued.
 
         .. note::
             "gemini-2.0-flash-exp", "gemini-1.5-pro-preview-0514" and "gemini-1.5-flash-preview-0514" is subject to the "Pre-GA Offerings Terms" in the General Service Terms section of the
@@ -429,17 +444,19 @@ class GeminiTextGenerator(base.RetriableRemotePredictor):
     def __init__(
         self,
         *,
-        model_name: Literal[
-            "gemini-1.5-pro-preview-0514",
-            "gemini-1.5-flash-preview-0514",
-            "gemini-1.5-pro-001",
-            "gemini-1.5-pro-002",
-            "gemini-1.5-flash-001",
-            "gemini-1.5-flash-002",
-            "gemini-2.0-flash-exp",
-            "gemini-2.0-flash-001",
-            "gemini-2.0-flash-lite-001",
-        ] = "gemini-2.0-flash-001",
+        model_name: Optional[
+            Literal[
+                "gemini-1.5-pro-preview-0514",
+                "gemini-1.5-flash-preview-0514",
+                "gemini-1.5-pro-001",
+                "gemini-1.5-pro-002",
+                "gemini-1.5-flash-001",
+                "gemini-1.5-flash-002",
+                "gemini-2.0-flash-exp",
+                "gemini-2.0-flash-001",
+                "gemini-2.0-flash-lite-001",
+            ]
+        ] = None,
         session: Optional[bigframes.Session] = None,
         connection_name: Optional[str] = None,
         max_iterations: int = 300,
@@ -454,6 +471,10 @@ def __init__(
                 "(https://cloud.google.com/products#product-launch-stages)."
             )
             warnings.warn(msg, category=exceptions.PreviewWarning)
+        if model_name is None:
+            model_name = "gemini-2.0-flash-001"
+            msg = exceptions.format_message(_REMOVE_DEFAULT_MODEL_WARNING)
+            warnings.warn(msg, category=FutureWarning, stacklevel=2)
         self.model_name = model_name
         self.session = session or global_session.get_global_session()
         self.max_iterations = max_iterations
@@ -803,7 +824,8 @@ class Claude3TextGenerator(base.RetriableRemotePredictor):
             "claude-3-5-sonnet" is Anthropic's most powerful AI model and maintains the speed and cost of Claude 3 Sonnet, which is a mid-tier model.
             "claude-3-opus" is Anthropic's second-most powerful AI model, with strong performance on highly complex tasks.
             https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/use-claude#available-claude-models
-            Default to "claude-3-sonnet".
+            If no setting is provided, "claude-3-sonnet" will be used by default
+            and a warning will be issued.
         session (bigframes.Session or None):
             BQ session to create the model. If None, use the global default session.
         connection_name (str or None):
@@ -815,12 +837,21 @@ class Claude3TextGenerator(base.RetriableRemotePredictor):
     def __init__(
         self,
         *,
-        model_name: Literal[
-            "claude-3-sonnet", "claude-3-haiku", "claude-3-5-sonnet", "claude-3-opus"
-        ] = "claude-3-sonnet",
+        model_name: Optional[
+            Literal[
+                "claude-3-sonnet",
+                "claude-3-haiku",
+                "claude-3-5-sonnet",
+                "claude-3-opus",
+            ]
+        ] = None,
         session: Optional[bigframes.Session] = None,
         connection_name: Optional[str] = None,
     ):
+        if model_name is None:
+            model_name = "claude-3-sonnet"
+            msg = exceptions.format_message(_REMOVE_DEFAULT_MODEL_WARNING)
+            warnings.warn(msg, category=FutureWarning, stacklevel=2)
         self.model_name = model_name
         self.session = session or global_session.get_global_session()
         self.connection_name = connection_name
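Because the constructors emit a FutureWarning here, downstream projects can make the deprecation fail fast before BigFrames 3.0 removes the defaults. The following is a sketch using only the standard-library warnings filter (not a bigframes API); note that it escalates every FutureWarning raised in the process, not just this one:

import warnings

from bigframes.ml import llm

# Escalate FutureWarning to an error so implicit model defaults fail in CI.
warnings.simplefilter("error", FutureWarning)

# Passes: model supplied explicitly, so no deprecation warning is raised.
embedder = llm.TextEmbeddingGenerator(model_name="text-embedding-004")

# Would raise FutureWarning as an error under the filter above:
# llm.TextEmbeddingGenerator()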

tests/system/small/ml/test_llm.py

Lines changed: 16 additions & 0 deletions
@@ -762,3 +762,19 @@ def test_llm_gemini_pro_score_params(llm_fine_tune_df_default_index, model_name)
 def test_gemini_preview_model_warnings(model_name):
     with pytest.warns(exceptions.PreviewWarning):
         llm.GeminiTextGenerator(model_name=model_name)
+
+
+@pytest.mark.parametrize(
+    "model_class",
+    [
+        llm.TextEmbeddingGenerator,
+        llm.MultimodalEmbeddingGenerator,
+        llm.GeminiTextGenerator,
+        llm.Claude3TextGenerator,
+    ],
+)
+def test_text_embedding_generator_no_default_model_warning(model_class):
+    message = "Since upgrading the default model can cause unintended breakages, the\ndefault model will be removed in BigFrames 3.0. Please supply an\nexplicit model to avoid this message."
+    bigframes.options.experiments.blob = True
+    with pytest.warns(FutureWarning, match=message):
+        model_class(model_name=None)
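For context, pytest.warns(..., match=...) applies re.search to the text of each captured warning, so the literal message above works as a pattern; the embedded \n characters suggest exceptions.format_message wraps the warning text. A hypothetical standalone equivalent of the same check, outside pytest (assuming a configured BigQuery session and connection):

import re
import warnings

from bigframes.ml import llm

pattern = (
    "Since upgrading the default model can cause unintended breakages, the\n"
    "default model will be removed in BigFrames 3.0. Please supply an\n"
    "explicit model to avoid this message."
)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    llm.TextEmbeddingGenerator(model_name=None)

# At least one FutureWarning whose message matches the expected wrapped text.
assert any(
    issubclass(w.category, FutureWarning) and re.search(pattern, str(w.message))
    for w in caught
)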
