Skip to content

Commit b1bf65d

Browse files
Authored by MrPresent-Han; co-authored by XuanYang-cn
feat: support query group by(#3177) (#3178)
related: #3177 Signed-off-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: MrPresent-Han <chun.han@gmail.com> Co-authored-by: XuanYang-cn <xuan.yang@zilliz.com>
1 parent 5d377b2 commit b1bf65d

File tree

5 files changed

+359
-1
lines changed

5 files changed

+359
-1
lines changed
Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
"""Example: query-level GROUP BY with aggregate output fields (pymilvus ORM API).

Demonstrates the query group-by feature (milvus-io/milvus#3177):
  1. group by a TIMESTAMPTZ field with count()/max()
  2. group by two scalar fields (c1, c6) with min()/max()
  3. group by one field (c1) with avg() over several numeric columns
  4. the same avg() query with an empty filter expression and a limit

Requires a running Milvus instance at localhost:19530.  Set
``prepare_data = True`` on the first run to populate the collection.
"""
import numpy as np
from pymilvus import (
    connections,
    utility,
    FieldSchema, CollectionSchema, DataType,
    Collection)
from collections import Counter
from datetime import datetime, timezone
import random

names = ["Green", "Rachel", "Joe", "Chandler", "Phebe", "Ross", "Monica"]
collection_name = 'test_query_group_by'
clean_exist = False    # drop any pre-existing collection of the same name
prepare_data = False   # insert fresh data (enable on first run)
to_flush = True        # flush after each batch so segments get sealed
batch_num = 3
num_entities, dim = 122, 8
fmt = "\n=== {:30} ===\n"
SHOW_STATS_DETAILS = False

print(fmt.format("start connecting to Milvus"))
connections.connect("default", host="localhost", port="19530")

if clean_exist and utility.has_collection(collection_name):
    utility.drop_collection(collection_name)

TS = "ts"
fields = [
    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
    FieldSchema(name="c1", dtype=DataType.VARCHAR, max_length=512),
    FieldSchema(name="c2", dtype=DataType.INT16),
    FieldSchema(name="c3", dtype=DataType.INT32),
    FieldSchema(name="c4", dtype=DataType.DOUBLE),
    FieldSchema(name=TS, dtype=DataType.TIMESTAMPTZ, description="timestamp with timezone"),
    FieldSchema(name="c5", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="c6", dtype=DataType.VARCHAR, max_length=512),
]

schema = CollectionSchema(fields)

print(fmt.format(f"Create collection `{collection_name}`"))
collection = Collection(collection_name, schema, consistency_level="Strong")

if prepare_data:
    rng = np.random.default_rng(seed=19530)
    print(fmt.format("Start inserting entities"))

    # Keep a small cardinality for TS so GROUP BY results are readable.
    ts_choices = [
        datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc).isoformat(),
        datetime(2025, 1, 1, 1, 0, 0, tzinfo=timezone.utc).isoformat(),
        datetime(2025, 1, 1, 2, 0, 0, tzinfo=timezone.utc).isoformat(),
    ]

    # Minimal stats (avoid printing too much)
    pair_counter = Counter()
    c2_sum = 0
    c3_sum = 0
    c4_sum = 0.0
    c2_lt_2 = 0
    c2_dist = Counter()

    for batch_idx in range(batch_num):
        # generate per-batch data
        c1_data = [random.choice(names) for _ in range(num_entities)]
        c2_data = [random.randint(0, 4) for _ in range(num_entities)]
        c3_data = [random.randint(0, 6) for _ in range(num_entities)]
        c4_data = [random.uniform(0.0, 100.0) for _ in range(num_entities)]
        c6_data = [random.choice(names) for _ in range(num_entities)]
        ts_data = [ts_choices[(batch_idx + j) % len(ts_choices)] for j in range(num_entities)]

        # collect minimal stats so the expected GROUP BY output can be
        # eyeballed against the inserted data
        pair_counter.update(zip(c1_data, c6_data))
        c2_sum += sum(c2_data)
        c3_sum += sum(c3_data)
        c4_sum += sum(c4_data)
        c2_lt_2 += sum(1 for v in c2_data if v < 2)
        c2_dist.update(c2_data)

        # column-oriented insert: one list per field, in schema order
        entities = [
            [str(i) for i in range(num_entities * batch_idx, num_entities * (batch_idx + 1))],
            c1_data,
            c2_data,
            c3_data,
            c4_data,
            ts_data,
            rng.random((num_entities, dim)),
            c6_data,
        ]
        collection.insert(entities)
        if to_flush:
            print(f"flush batch:{batch_idx}")
            collection.flush()
        print(f"inserted batch:{batch_idx}")

    # NOTE: the stats below use counters that only exist when data was
    # prepared in this run, so they must stay inside this branch.
    total_rows = batch_num * num_entities
    print(fmt.format("Quick stats (compact)"))
    print(f"total_rows: {total_rows}")
    print(f"unique (c1,c6) pairs: {len(pair_counter)}")
    print("top 5 (c1,c6) pairs:")
    for (c1v, c6v), cnt in pair_counter.most_common(5):
        print(f"  ({c1v}, {c6v}): {cnt}")
    print(f"sum(c2): {c2_sum}, sum(c3): {c3_sum}, sum(c4): {c4_sum:.2f}")
    print(f"c2 < 2: {c2_lt_2} ({c2_lt_2/total_rows*100:.2f}%)")
    if SHOW_STATS_DETAILS:
        print(f"c2 distribution: {dict(sorted(c2_dist.items()))}")


print(fmt.format("Start Creating index IVF_FLAT"))
index = {
    "index_type": "IVF_FLAT",
    "metric_type": "L2",
    "params": {"nlist": 128},
}
collection.create_index("c5", index)
print(f"Number of entities in Milvus: {collection.num_entities}")  # check the num_entities
collection.load()


# 1. group by TIMESTAMPTZ + max
print(fmt.format("Query: group by TIMESTAMPTZ + max"))
res_ts = collection.query(
    expr="c2 < 10",
    group_by_fields=[TS],
    output_fields=[TS, "count(c2)", f"max({TS})"],
    timeout=120.0,
)
for row in res_ts:
    print(f"res={row}")

# 2. group by (c1,c6) + min/max
print(fmt.format("Query: group by (c1,c6) + min/max"))
res_minmax = collection.query(
    expr="c2 < 10",
    group_by_fields=["c1", "c6"],
    output_fields=["c1", "c6", "min(c2)", "max(c2)"],
    timeout=120.0,
)
for row in res_minmax:
    print(f"res={row}")


# 3. group by c1 + avg(c2, c3, c4)
print(fmt.format("Query: group by c1 + avg(c2, c3, c4)"))
res_avg = collection.query(
    expr="c2 < 10",
    group_by_fields=["c1"],
    output_fields=["c1", "avg(c2)", "avg(c3)", "avg(c4)"],
    timeout=120.0,
)
for row in res_avg:
    print(f"res={row}")

# 4. group by c1 + avg(c2, c3, c4) without expr
print(fmt.format("Query: group by c1 + avg, no expr"))
res_avg = collection.query(
    expr="",
    group_by_fields=["c1"],
    output_fields=["c1", "avg(c2)", "avg(c3)", "avg(c4)"],
    timeout=120.0,
    limit=10,
)
for row in res_avg:
    print(f"res={row}")

examples/query_group_by.py

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
"""Example: query-level GROUP BY with aggregate output fields (MilvusClient API).

Same scenario as the ORM example, expressed through ``MilvusClient``:
  1. group by a TIMESTAMPTZ field with count()/max()
  2. group by two scalar fields (c1, c6) with min()/max()
  3. group by one field (c1) with avg() over several numeric columns
  4. the same avg() query with an empty filter and a limit

Requires a running Milvus instance at localhost:19530.  Set
``prepare_data = True`` on the first run to populate the collection.
"""
import numpy as np
from pymilvus import (
    FieldSchema, CollectionSchema, DataType,
)
from pymilvus.milvus_client import MilvusClient
from pymilvus.milvus_client.index import IndexParams
from collections import Counter
from datetime import datetime, timezone
import random

names = ["Green", "Rachel", "Joe", "Chandler", "Phebe", "Ross", "Monica"]
collection_name = 'test_query_group_by'
clean_exist = False    # drop any pre-existing collection of the same name
prepare_data = False   # insert fresh data (enable on first run)
to_flush = True        # flush after each batch so segments get sealed
batch_num = 3
num_entities, dim = 122, 8
fmt = "\n=== {:30} ===\n"
SHOW_STATS_DETAILS = False

print(fmt.format("start connecting to Milvus"))
client = MilvusClient(uri="http://localhost:19530")

if clean_exist and client.has_collection(collection_name):
    client.drop_collection(collection_name)

TS = "ts"
fields = [
    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
    FieldSchema(name="c1", dtype=DataType.VARCHAR, max_length=512),
    FieldSchema(name="c2", dtype=DataType.INT16),
    FieldSchema(name="c3", dtype=DataType.INT32),
    FieldSchema(name="c4", dtype=DataType.DOUBLE),
    FieldSchema(name=TS, dtype=DataType.TIMESTAMPTZ, description="timestamp with timezone"),
    FieldSchema(name="c5", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="c6", dtype=DataType.VARCHAR, max_length=512),
]

schema = CollectionSchema(fields)

print(fmt.format(f"Create collection `{collection_name}`"))
client.create_collection(
    collection_name=collection_name,
    schema=schema,
    consistency_level="Strong"
)

if prepare_data:
    rng = np.random.default_rng(seed=19530)
    print(fmt.format("Start inserting entities"))

    # Keep a small cardinality for TS so GROUP BY results are readable.
    ts_choices = [
        datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc).isoformat(),
        datetime(2025, 1, 1, 1, 0, 0, tzinfo=timezone.utc).isoformat(),
        datetime(2025, 1, 1, 2, 0, 0, tzinfo=timezone.utc).isoformat(),
    ]

    # Minimal stats (avoid printing too much)
    pair_counter = Counter()
    c2_sum = 0
    c3_sum = 0
    c4_sum = 0.0
    c2_lt_2 = 0
    c2_dist = Counter()

    for batch_idx in range(batch_num):
        # generate per-batch data
        c1_data = [random.choice(names) for _ in range(num_entities)]
        c2_data = [random.randint(0, 4) for _ in range(num_entities)]
        c3_data = [random.randint(0, 6) for _ in range(num_entities)]
        c4_data = [random.uniform(0.0, 100.0) for _ in range(num_entities)]
        c6_data = [random.choice(names) for _ in range(num_entities)]
        ts_data = [ts_choices[(batch_idx + j) % len(ts_choices)] for j in range(num_entities)]
        vector_data = rng.random((num_entities, dim))

        # collect minimal stats so the expected GROUP BY output can be
        # eyeballed against the inserted data
        pair_counter.update(zip(c1_data, c6_data))
        c2_sum += sum(c2_data)
        c3_sum += sum(c3_data)
        c4_sum += sum(c4_data)
        c2_lt_2 += sum(1 for v in c2_data if v < 2)
        c2_dist.update(c2_data)

        # Convert to dict format for milvus_client.insert()
        data = [
            {
                "pk": str(i),
                "c1": c1_data[i],
                "c2": c2_data[i],
                "c3": c3_data[i],
                "c4": c4_data[i],
                TS: ts_data[i],
                "c5": vector_data[i].tolist(),
                "c6": c6_data[i],
            }
            for i in range(num_entities)
        ]

        client.insert(collection_name, data)
        if to_flush:
            print(f"flush batch:{batch_idx}")
            client.flush(collection_name)
        print(f"inserted batch:{batch_idx}")

    # NOTE: the stats below use counters that only exist when data was
    # prepared in this run, so they must stay inside this branch.
    total_rows = batch_num * num_entities
    print(fmt.format("Quick stats (compact)"))
    print(f"total_rows: {total_rows}")
    print(f"unique (c1,c6) pairs: {len(pair_counter)}")
    print("top 5 (c1,c6) pairs:")
    for (c1v, c6v), cnt in pair_counter.most_common(5):
        print(f"  ({c1v}, {c6v}): {cnt}")
    print(f"sum(c2): {c2_sum}, sum(c3): {c3_sum}, sum(c4): {c4_sum:.2f}")
    print(f"c2 < 2: {c2_lt_2} ({c2_lt_2/total_rows*100:.2f}%)")
    if SHOW_STATS_DETAILS:
        print(f"c2 distribution: {dict(sorted(c2_dist.items()))}")


print(fmt.format("Start Creating index IVF_FLAT"))
index_params = IndexParams()
index_params.add_index("c5", index_type="IVF_FLAT", metric_type="L2", nlist=128)
client.create_index(collection_name, index_params)

stats = client.get_collection_stats(collection_name)
print(f"Number of entities in Milvus: {stats.get('row_count', 0)}")  # check the num_entities
client.load_collection(collection_name)


# 1. group by TIMESTAMPTZ + max
print(fmt.format("Query: group by TIMESTAMPTZ + max"))
res_ts = client.query(
    collection_name=collection_name,
    filter="c2 < 10",
    output_fields=[TS, "count(c2)", f"max({TS})"],
    timeout=120.0,
    group_by_fields=[TS],
)
for row in res_ts:
    print(f"res={row}")

# 2. group by (c1,c6) + min/max
print(fmt.format("Query: group by (c1,c6) + min/max"))
res_minmax = client.query(
    collection_name=collection_name,
    filter="c2 < 10",
    output_fields=["c1", "c6", "min(c2)", "max(c2)"],
    timeout=120.0,
    group_by_fields=["c1", "c6"],
)
for row in res_minmax:
    print(f"res={row}")


# 3. group by c1 + avg(c2, c3, c4)
print(fmt.format("Query: group by c1 + avg(c2, c3, c4)"))
res_avg = client.query(
    collection_name=collection_name,
    filter="c2 < 10",
    output_fields=["c1", "avg(c2)", "avg(c3)", "avg(c4)"],
    timeout=120.0,
    group_by_fields=["c1"],
)
for row in res_avg:
    print(f"res={row}")

# 4. group by c1 + avg(c2, c3, c4) without expr
print(fmt.format("Query: group by c1 + avg, no expr"))
res_avg = client.query(
    collection_name=collection_name,
    filter="",
    output_fields=["c1", "avg(c2)", "avg(c3)", "avg(c4)"],
    timeout=120.0,
    limit=10,
    group_by_fields=["c1"],
)
for row in res_avg:
    print(f"res={row}")

pymilvus/client/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
JSON_TYPE = "json_type"
2020
STRICT_CAST = "strict_cast"
2121
ITERATOR_FIELD = "iterator"
22+
QUERY_GROUP_BY_FIELDS = "group_by_fields"
2223
ITERATOR_SESSION_TS_FIELD = "iterator_session_ts"
2324
ITER_SEARCH_V2_KEY = "search_iter_v2"
2425
ITER_SEARCH_BATCH_SIZE_KEY = "search_iter_batch_size"

pymilvus/client/entity_helper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -697,11 +697,12 @@ def entity_to_field_data(entity: Dict, field_info: Any, num_rows: int) -> schema
697697
field_data.scalars.int_data.data.extend(entity_values)
698698
elif entity_type == DataType.INT64:
699699
field_data.scalars.long_data.data.extend(entity_values)
700+
elif entity_type == DataType.TIMESTAMPTZ:
701+
field_data.scalars.string_data.data.extend(entity_values)
700702
elif entity_type == DataType.FLOAT:
701703
field_data.scalars.float_data.data.extend(entity_values)
702704
elif entity_type == DataType.DOUBLE:
703705
field_data.scalars.double_data.data.extend(entity_values)
704-
705706
elif entity_type == DataType.FLOAT_VECTOR:
706707
if len(entity_values) > 0:
707708
field_data.vectors.dim = len(entity_values[0])

pymilvus/client/prepare.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
JSON_PATH,
4242
JSON_TYPE,
4343
PAGE_RETAIN_ORDER_FIELD,
44+
QUERY_GROUP_BY_FIELDS,
4445
RANK_GROUP_SCORER,
4546
REDUCE_STOP_FOR_BEST,
4647
STRICT_CAST,
@@ -2074,6 +2075,20 @@ def query_request(
20742075
req.query_params.append(
20752076
common_types.KeyValuePair(key=REDUCE_STOP_FOR_BEST, value=str(stop_reduce_for_best))
20762077
)
2078+
2079+
# parse query group-by fields
2080+
query_group_by_fields = kwargs.get(QUERY_GROUP_BY_FIELDS, [])
2081+
if not isinstance(query_group_by_fields, list):
2082+
msg = "group_by_fields must be a list"
2083+
raise TypeError(msg)
2084+
if len(query_group_by_fields) > 0:
2085+
query_group_by_fields_str = ",".join(query_group_by_fields)
2086+
req.query_params.append(
2087+
common_types.KeyValuePair(
2088+
key=QUERY_GROUP_BY_FIELDS, value=query_group_by_fields_str
2089+
)
2090+
)
2091+
20772092
return req
20782093

20792094
@classmethod

0 commit comments

Comments
 (0)