Skip to content

Commit f6282b0

Browse files
authored
chore: implement semantics cluster_by (#1067)
* chore: implement semantics cluster_by * address comments and fix tests
1 parent 99ca0df commit f6282b0

File tree

3 files changed

+342
-33
lines changed

3 files changed

+342
-33
lines changed

bigframes/operations/semantics.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,88 @@ def agg(
194194

195195
return df[column]
196196

197+
def cluster_by(
    self,
    column: str,
    output_column: str,
    model,
    n_clusters: int = 5,
):
    """
    Clusters data based on the semantic similarity of text within a specified column.

    This method leverages a language model to generate text embeddings for each value in
    the given column. These embeddings capture the semantic meaning of the text.
    The data is then grouped into `n_clusters` clusters using the k-means clustering
    algorithm, which groups data points based on the similarity of their embeddings.

    **Examples:**

        >>> import bigframes.pandas as bpd
        >>> bpd.options.display.progress_bar = None
        >>> bpd.options.experiments.semantic_operators = True

        >>> import bigframes.ml.llm as llm
        >>> model = llm.TextEmbeddingGenerator()

        >>> df = bpd.DataFrame({
        ...     "Product": ["Smartphone", "Laptop", "T-shirt", "Jeans"],
        ... })
        >>> df.semantics.cluster_by("Product", "Cluster ID", model, n_clusters=2)
              Product  Cluster ID
        0  Smartphone           2
        1      Laptop           2
        2     T-shirt           1
        3       Jeans           1
        <BLANKLINE>
        [4 rows x 2 columns]

    Args:
        column (str):
            A column name to perform the similarity clustering.

        output_column (str):
            An output column to store the clustering ID.

        model (bigframes.ml.llm.TextEmbeddingGenerator):
            A TextEmbeddingGenerator provided by Bigframes ML package.

        n_clusters (int, default 5):
            Default 5. Number of clusters to be detected. Must be greater than 1.

    Returns:
        bigframes.dataframe.DataFrame: A new DataFrame with the clustering output column.

    Raises:
        NotImplementedError: when the semantic operator experiment is off.
        TypeError: when the model is not a TextEmbeddingGenerator.
        ValueError: when the column refers to a non-existing column, or when
            `n_clusters` is not greater than 1.
    """

    # Imported at call time rather than at module import time
    # (presumably to avoid import cycles -- TODO confirm).
    import bigframes.dataframe
    import bigframes.ml.cluster as cluster
    import bigframes.ml.llm as llm

    if not isinstance(model, llm.TextEmbeddingGenerator):
        raise TypeError(f"Expect a text embedding model, but got: {type(model)}")

    if column not in self._df.columns:
        raise ValueError(f"Column {column} not found.")

    if n_clusters <= 1:
        # The trailing space keeps the two sentences apart once the adjacent
        # string literals are concatenated.
        raise ValueError(
            f"Invalid value for `n_clusters`: {n_clusters}. "
            "It must be greater than 1."
        )

    # Work on a copy so the caller's DataFrame is not given the new column.
    df: bigframes.dataframe.DataFrame = self._df.copy()
    embeddings_df = model.predict(df[column])

    # Fit k-means on the embedding column only, then read each row's assigned
    # centroid from the prediction result.
    cluster_model = cluster.KMeans(n_clusters=n_clusters)
    cluster_model.fit(embeddings_df[["ml_generate_embedding_result"]])
    clustered_result = cluster_model.predict(embeddings_df)
    df[output_column] = clustered_result["CENTROID_ID"]
    return df
278+
197279
def filter(self, instruction: str, model):
198280
"""
199281
Filters the DataFrame with the semantics of the user instruction.

notebooks/experimental/semantic_operators.ipynb

Lines changed: 198 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
"name": "stderr",
3434
"output_type": "stream",
3535
"text": [
36-
"/usr/local/google/home/chelsealin/src/bigframes3/bigframes/_config/experiment_options.py:33: UserWarning: Semantic operators are still under experiments, and are subject to change in the future.\n",
36+
"/usr/local/google/home/chelsealin/src/bigframes/bigframes/_config/experiment_options.py:33: UserWarning: Semantic operators are still under experiments, and are subject to change in the future.\n",
3737
" warnings.warn(\n"
3838
]
3939
}
@@ -51,21 +51,25 @@
5151
},
5252
{
5353
"cell_type": "code",
54-
"execution_count": 3,
54+
"execution_count": 4,
5555
"metadata": {},
5656
"outputs": [
5757
{
58-
"name": "stderr",
59-
"output_type": "stream",
60-
"text": [
61-
"/usr/local/google/home/chelsealin/src/bigframes3/bigframes/pandas/__init__.py:559: DefaultLocationWarning: No explicit location is set, so using location US for the session.\n",
62-
" return global_session.get_global_session()\n"
63-
]
58+
"data": {
59+
"text/html": [
60+
"Query job 13e4b10e-70cf-4b93-8c59-5f6f5fb10aeb is DONE. 0 Bytes processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:13e4b10e-70cf-4b93-8c59-5f6f5fb10aeb&page=queryresults\">Open Job</a>"
61+
],
62+
"text/plain": [
63+
"<IPython.core.display.HTML object>"
64+
]
65+
},
66+
"metadata": {},
67+
"output_type": "display_data"
6468
},
6569
{
6670
"data": {
6771
"text/html": [
68-
"Query job aef2dd7b-bdad-4dda-91be-867e8dac2613 is DONE. 0 Bytes processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:aef2dd7b-bdad-4dda-91be-867e8dac2613&page=queryresults\">Open Job</a>"
72+
"Query job 559dd42c-573d-4b00-8fe9-b7061afdd672 is DONE. 0 Bytes processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:559dd42c-573d-4b00-8fe9-b7061afdd672&page=queryresults\">Open Job</a>"
6973
],
7074
"text/plain": [
7175
"<IPython.core.display.HTML object>"
@@ -77,7 +81,8 @@
7781
],
7882
"source": [
7983
"import bigframes.ml.llm as llm\n",
80-
"gemini_model = llm.GeminiTextGenerator(model_name=llm._GEMINI_1P5_FLASH_001_ENDPOINT)"
84+
"gemini_model = llm.GeminiTextGenerator(model_name=llm._GEMINI_1P5_FLASH_001_ENDPOINT)\n",
85+
"text_embedding_model = llm.TextEmbeddingGenerator(model_name=\"text-embedding-004\")"
8186
]
8287
},
8388
{
@@ -657,28 +662,6 @@
657662
"## Semantic Search"
658663
]
659664
},
660-
{
661-
"cell_type": "code",
662-
"execution_count": 11,
663-
"metadata": {},
664-
"outputs": [
665-
{
666-
"data": {
667-
"text/html": [
668-
"Query job 48aafee2-4948-4677-ab02-a94a71b9f6e2 is DONE. 0 Bytes processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:48aafee2-4948-4677-ab02-a94a71b9f6e2&page=queryresults\">Open Job</a>"
669-
],
670-
"text/plain": [
671-
"<IPython.core.display.HTML object>"
672-
]
673-
},
674-
"metadata": {},
675-
"output_type": "display_data"
676-
}
677-
],
678-
"source": [
679-
"text_embedding_model = llm.TextEmbeddingGenerator(model_name=\"text-embedding-004\")"
680-
]
681-
},
682665
{
683666
"cell_type": "code",
684667
"execution_count": 12,
@@ -1156,6 +1139,188 @@
11561139
"agg_df = df.semantics.agg(\"Find the shared first name of actors in {Movies}. One word answer.\", model=gemini_model)\n",
11571140
"agg_df"
11581141
]
1142+
},
1143+
{
1144+
"cell_type": "markdown",
1145+
"metadata": {},
1146+
"source": [
1147+
"## Semantic Cluster"
1148+
]
1149+
},
1150+
{
1151+
"cell_type": "code",
1152+
"execution_count": 5,
1153+
"metadata": {},
1154+
"outputs": [
1155+
{
1156+
"data": {
1157+
"text/html": [
1158+
"Query job 92ce82b9-c521-42af-a2b7-6114b27a9ce4 is DONE. 0 Bytes processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:92ce82b9-c521-42af-a2b7-6114b27a9ce4&page=queryresults\">Open Job</a>"
1159+
],
1160+
"text/plain": [
1161+
"<IPython.core.display.HTML object>"
1162+
]
1163+
},
1164+
"metadata": {},
1165+
"output_type": "display_data"
1166+
},
1167+
{
1168+
"name": "stderr",
1169+
"output_type": "stream",
1170+
"text": [
1171+
"/usr/local/google/home/chelsealin/src/bigframes/bigframes/core/__init__.py:112: PreviewWarning: Interpreting JSON column(s) as StringDtype. This behavior may change in future versions.\n",
1172+
" warnings.warn(\n"
1173+
]
1174+
},
1175+
{
1176+
"data": {
1177+
"text/html": [
1178+
"Query job 8c4c7391-2889-4cf1-bbfa-5cbf6b144db5 is DONE. 10 Bytes processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:8c4c7391-2889-4cf1-bbfa-5cbf6b144db5&page=queryresults\">Open Job</a>"
1179+
],
1180+
"text/plain": [
1181+
"<IPython.core.display.HTML object>"
1182+
]
1183+
},
1184+
"metadata": {},
1185+
"output_type": "display_data"
1186+
},
1187+
{
1188+
"data": {
1189+
"text/html": [
1190+
"Query job 19ae7cc6-3d61-4c69-9148-1956fafb577a is DONE. 30.8 kB processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:19ae7cc6-3d61-4c69-9148-1956fafb577a&page=queryresults\">Open Job</a>"
1191+
],
1192+
"text/plain": [
1193+
"<IPython.core.display.HTML object>"
1194+
]
1195+
},
1196+
"metadata": {},
1197+
"output_type": "display_data"
1198+
},
1199+
{
1200+
"data": {
1201+
"text/html": [
1202+
"Query job 7c2b62df-3bed-4469-9ffc-131843efe25e is DONE. 30.7 kB processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:7c2b62df-3bed-4469-9ffc-131843efe25e&page=queryresults\">Open Job</a>"
1203+
],
1204+
"text/plain": [
1205+
"<IPython.core.display.HTML object>"
1206+
]
1207+
},
1208+
"metadata": {},
1209+
"output_type": "display_data"
1210+
},
1211+
{
1212+
"data": {
1213+
"text/html": [
1214+
"Query job 74155e34-d8ca-4fba-8b93-33b1b325a5f1 is DONE. 138.9 kB processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:74155e34-d8ca-4fba-8b93-33b1b325a5f1&page=queryresults\">Open Job</a>"
1215+
],
1216+
"text/plain": [
1217+
"<IPython.core.display.HTML object>"
1218+
]
1219+
},
1220+
"metadata": {},
1221+
"output_type": "display_data"
1222+
},
1223+
{
1224+
"data": {
1225+
"text/html": [
1226+
"Query job d9151043-a9c3-4388-8268-ef41162012b7 is DONE. 80 Bytes processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:d9151043-a9c3-4388-8268-ef41162012b7&page=queryresults\">Open Job</a>"
1227+
],
1228+
"text/plain": [
1229+
"<IPython.core.display.HTML object>"
1230+
]
1231+
},
1232+
"metadata": {},
1233+
"output_type": "display_data"
1234+
},
1235+
{
1236+
"data": {
1237+
"text/html": [
1238+
"Query job d2c4ad9a-c637-490e-a2cf-37d7f5a34024 is DONE. 170 Bytes processed. <a target=\"_blank\" href=\"https://console.cloud.google.com/bigquery?project=bigframes-dev&j=bq:US:d2c4ad9a-c637-490e-a2cf-37d7f5a34024&page=queryresults\">Open Job</a>"
1239+
],
1240+
"text/plain": [
1241+
"<IPython.core.display.HTML object>"
1242+
]
1243+
},
1244+
"metadata": {},
1245+
"output_type": "display_data"
1246+
},
1247+
{
1248+
"data": {
1249+
"text/html": [
1250+
"<div>\n",
1251+
"<style scoped>\n",
1252+
" .dataframe tbody tr th:only-of-type {\n",
1253+
" vertical-align: middle;\n",
1254+
" }\n",
1255+
"\n",
1256+
" .dataframe tbody tr th {\n",
1257+
" vertical-align: top;\n",
1258+
" }\n",
1259+
"\n",
1260+
" .dataframe thead th {\n",
1261+
" text-align: right;\n",
1262+
" }\n",
1263+
"</style>\n",
1264+
"<table border=\"1\" class=\"dataframe\">\n",
1265+
" <thead>\n",
1266+
" <tr style=\"text-align: right;\">\n",
1267+
" <th></th>\n",
1268+
" <th>Product</th>\n",
1269+
" <th>Cluster ID</th>\n",
1270+
" </tr>\n",
1271+
" </thead>\n",
1272+
" <tbody>\n",
1273+
" <tr>\n",
1274+
" <th>0</th>\n",
1275+
" <td>Smartphone</td>\n",
1276+
" <td>3</td>\n",
1277+
" </tr>\n",
1278+
" <tr>\n",
1279+
" <th>1</th>\n",
1280+
" <td>Laptop</td>\n",
1281+
" <td>3</td>\n",
1282+
" </tr>\n",
1283+
" <tr>\n",
1284+
" <th>2</th>\n",
1285+
" <td>Coffee Maker</td>\n",
1286+
" <td>1</td>\n",
1287+
" </tr>\n",
1288+
" <tr>\n",
1289+
" <th>3</th>\n",
1290+
" <td>T-shirt</td>\n",
1291+
" <td>2</td>\n",
1292+
" </tr>\n",
1293+
" <tr>\n",
1294+
" <th>4</th>\n",
1295+
" <td>Jeans</td>\n",
1296+
" <td>2</td>\n",
1297+
" </tr>\n",
1298+
" </tbody>\n",
1299+
"</table>\n",
1300+
"<p>5 rows × 2 columns</p>\n",
1301+
"</div>[5 rows x 2 columns in total]"
1302+
],
1303+
"text/plain": [
1304+
" Product Cluster ID\n",
1305+
"0 Smartphone 3\n",
1306+
"1 Laptop 3\n",
1307+
"2 Coffee Maker 1\n",
1308+
"3 T-shirt 2\n",
1309+
"4 Jeans 2\n",
1310+
"\n",
1311+
"[5 rows x 2 columns]"
1312+
]
1313+
},
1314+
"execution_count": 5,
1315+
"metadata": {},
1316+
"output_type": "execute_result"
1317+
}
1318+
],
1319+
"source": [
1320+
"df = bpd.DataFrame({'Product': ['Smartphone', 'Laptop', 'Coffee Maker', 'T-shirt', 'Jeans']})\n",
1321+
"\n",
1322+
"df.semantics.cluster_by(column='Product', output_column='Cluster ID', model=text_embedding_model, n_clusters=3)"
1323+
]
11591324
}
11601325
],
11611326
"metadata": {
@@ -1174,7 +1339,7 @@
11741339
"name": "python",
11751340
"nbconvert_exporter": "python",
11761341
"pygments_lexer": "ipython3",
1177-
"version": "3.11.9"
1342+
"version": "3.12.1"
11781343
}
11791344
},
11801345
"nbformat": 4,

0 commit comments

Comments
 (0)