1717from __future__ import annotations
1818
1919import datetime
20- from typing import Callable , cast , Iterable , Literal , Mapping , Optional , Union
20+ from typing import Callable , cast , Iterable , Mapping , Optional , Union
2121import uuid
2222
2323from google .cloud import bigquery
@@ -35,11 +35,27 @@ def __init__(self, session: bigframes.Session):
3535 self ._session = session
3636 self ._base_sql_generator = ml_sql .BaseSqlGenerator ()
3737
38- def _apply_sql (
38+
39+ class BqmlModel (BaseBqml ):
40+ """Represents an existing BQML model in BigQuery.
41+
42+ Wraps the BQML API and SQL interface to expose the functionality needed for
43+ BigQuery DataFrames ML.
44+ """
45+
def __init__(self, session: bigframes.Session, model: bigquery.Model):
    """Bind this wrapper to a session and an existing BigQuery ML model.

    Args:
        session (bigframes.Session):
            the session stored on the instance by the base initialiser
        model (bigquery.Model):
            the BigQuery model this object wraps
    """
    # Delegate to BaseBqml.__init__ so that everything it sets up
    # (`_session` and `_base_sql_generator`) also exists on this subclass;
    # previously only `_session` was assigned here by hand.
    super().__init__(session)
    self._model = model
    # `self.model_name` is read while building the generator, so `_model`
    # is assigned first.
    # NOTE(review): presumably `model_name` derives from `_model` — confirm.
    self._model_manipulation_sql_generator = ml_sql.ModelManipulationSqlGenerator(
        self.model_name
    )
52+
def _apply_ml_tvf(
    self,
    input_data: bpd.DataFrame,
    apply_sql_tvf: Callable[[str], str],
) -> bpd.DataFrame:
    """Helper to wrap a dataframe in a SQL query, keeping the index intact.

    Used for predict, transform, distance.

    Args:
        input_data (bpd.DataFrame):
            the dataframe to be wrapped

        apply_sql_tvf (function):
            Takes the SQL of the input table and returns the SQL that
            applies a table-valued function to it. The resulting table
            value must include all input columns, with new columns
            appended to the end.

    Returns:
        The dataframe read back from the generated SQL, indexed by the
        input's index columns and carrying the input's index labels.
    """
    # TODO: Preserve ordering information?
    source_sql, index_ids, index_names = input_data._to_sql_query(
        include_index=True
    )

    tvf_sql = apply_sql_tvf(source_sql)
    result = self._session.read_gbq(tvf_sql, index_col=index_ids)
    result.index.names = index_names

    # Restore column labels
    # NOTE(review): the value returned by rename() is not captured. If
    # rename() returns a new frame (as the pandas API does) instead of
    # renaming in place, this statement leaves `result` unchanged. Before
    # assigning it, confirm that the leading output columns really are the
    # input columns, since the mapping below is purely positional.
    label_map = dict(zip(result.columns.values, input_data.columns.values))
    result.rename(columns=label_map)
    return result
11491
def _keys(self):
    """Return the ``(session, model)`` tuple held by this instance.

    NOTE(review): presumably consumed by equality/hash helpers defined
    elsewhere in the class — confirm against the callers.
    """
    session, model = self._session, self._model
    return session, model
@@ -137,13 +114,13 @@ def model(self) -> bigquery.Model:
137114 return self ._model
138115
def predict(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
    """Apply the generator's ``ml_predict`` SQL to ``input_data``.

    Args:
        input_data (bpd.DataFrame):
            the dataframe handed to ``_apply_ml_tvf``

    Returns:
        Whatever ``_apply_ml_tvf`` returns for ``input_data``.
    """
    predict_tvf = self._model_manipulation_sql_generator.ml_predict
    return self._apply_ml_tvf(input_data, predict_tvf)
144121
def transform(self, input_data: bpd.DataFrame) -> bpd.DataFrame:
    """Apply the generator's ``ml_transform`` SQL to ``input_data``.

    Args:
        input_data (bpd.DataFrame):
            the dataframe handed to ``_apply_ml_tvf``

    Returns:
        Whatever ``_apply_ml_tvf`` returns for ``input_data``.
    """
    transform_tvf = self._model_manipulation_sql_generator.ml_transform
    return self._apply_ml_tvf(input_data, transform_tvf)
@@ -153,10 +130,10 @@ def generate_text(
153130 input_data : bpd .DataFrame ,
154131 options : Mapping [str , int | float ],
155132 ) -> bpd .DataFrame :
156- return self ._apply_sql (
133+ return self ._apply_ml_tvf (
157134 input_data ,
158- lambda source_df : self ._model_manipulation_sql_generator .ml_generate_text (
159- source_df = source_df ,
135+ lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_text (
136+ source_sql = source_sql ,
160137 struct_options = options ,
161138 ),
162139 )
@@ -166,10 +143,10 @@ def generate_embedding(
166143 input_data : bpd .DataFrame ,
167144 options : Mapping [str , int | float ],
168145 ) -> bpd .DataFrame :
169- return self ._apply_sql (
146+ return self ._apply_ml_tvf (
170147 input_data ,
171- lambda source_df : self ._model_manipulation_sql_generator .ml_generate_embedding (
172- source_df = source_df ,
148+ lambda source_sql : self ._model_manipulation_sql_generator .ml_generate_embedding (
149+ source_sql = source_sql ,
173150 struct_options = options ,
174151 ),
175152 )
@@ -179,10 +156,10 @@ def detect_anomalies(
179156 ) -> bpd .DataFrame :
180157 assert self ._model .model_type in ("PCA" , "KMEANS" , "ARIMA_PLUS" )
181158
182- return self ._apply_sql (
159+ return self ._apply_ml_tvf (
183160 input_data ,
184- lambda source_df : self ._model_manipulation_sql_generator .ml_detect_anomalies (
185- source_df = source_df ,
161+ lambda source_sql : self ._model_manipulation_sql_generator .ml_detect_anomalies (
162+ source_sql = source_sql ,
186163 struct_options = options ,
187164 ),
188165 )
@@ -192,7 +169,9 @@ def forecast(self, options: Mapping[str, int | float]) -> bpd.DataFrame:
192169 return self ._session .read_gbq (sql , index_col = "forecast_timestamp" ).reset_index ()
193170
def evaluate(self, input_data: Optional[bpd.DataFrame] = None):
    """Build the generator's ``ml_evaluate`` SQL and read its result.

    Args:
        input_data (Optional[bpd.DataFrame]):
            when given, its ``sql`` attribute is passed to ``ml_evaluate``;
            when omitted, ``None`` is passed instead.

    Returns:
        The result of ``self._session.read_gbq`` on the generated SQL.
    """
    build_evaluate_sql = self._model_manipulation_sql_generator.ml_evaluate

    if input_data is None:
        source_sql = None
    else:
        source_sql = input_data.sql

    evaluate_sql = build_evaluate_sql(source_sql)
    return self._session.read_gbq(evaluate_sql)
198177
@@ -202,7 +181,7 @@ def llm_evaluate(
202181 task_type : Optional [str ] = None ,
203182 ):
204183 sql = self ._model_manipulation_sql_generator .ml_llm_evaluate (
205- input_data , task_type
184+ input_data . sql , task_type
206185 )
207186
208187 return self ._session .read_gbq (sql )
@@ -336,7 +315,7 @@ def create_model(
336315 model_ref = self ._create_model_ref (session ._anonymous_dataset )
337316
338317 sql = self ._model_creation_sql_generator .create_model (
339- source_df = input_data ,
318+ source_sql = input_data . sql ,
340319 model_ref = model_ref ,
341320 transforms = transforms ,
342321 options = options ,
@@ -374,7 +353,7 @@ def create_llm_remote_model(
374353 model_ref = self ._create_model_ref (session ._anonymous_dataset )
375354
376355 sql = self ._model_creation_sql_generator .create_llm_remote_model (
377- source_df = input_data ,
356+ source_sql = input_data . sql ,
378357 model_ref = model_ref ,
379358 options = options ,
380359 connection_name = connection_name ,
@@ -407,7 +386,7 @@ def create_time_series_model(
407386 model_ref = self ._create_model_ref (session ._anonymous_dataset )
408387
409388 sql = self ._model_creation_sql_generator .create_model (
410- source_df = input_data ,
389+ source_sql = input_data . sql ,
411390 model_ref = model_ref ,
412391 transforms = transforms ,
413392 options = options ,
0 commit comments