|
import numpy as np
from pymilvus import (
    FieldSchema, CollectionSchema, DataType,
)
from pymilvus.milvus_client import MilvusClient
from collections import Counter
from datetime import datetime, timezone
import random

# Candidate values for the scalar group-by columns (c1/c6).
names = ["Green", "Rachel", "Joe", "Chandler", "Phebe", "Ross", "Monica"]
collection_name = 'test_query_group_by'
clean_exist = False        # drop an existing collection before the run
prepare_data = False       # insert fresh batches (False = reuse existing data)
to_flush = True            # flush after every insert batch to seal segments
batch_num = 3              # number of insert batches
num_entities, dim = 122, 8 # rows per batch / float-vector dimension
fmt = "\n=== {:30} ===\n"  # banner format used for all section headers
SHOW_STATS_DETAILS = False # also print the full c2 distribution when True
| 19 | + |
print(fmt.format("start connecting to Milvus"))
client = MilvusClient(uri="http://localhost:19530")

# Optionally drop any previous run's collection so we start clean.
if clean_exist and client.has_collection(collection_name):
    client.drop_collection(collection_name)

TS = "ts"  # TIMESTAMPTZ field name, reused by the GROUP BY queries below
fields = [
    FieldSchema(name="pk", dtype=DataType.VARCHAR, is_primary=True, auto_id=False, max_length=100),
    FieldSchema(name="c1", dtype=DataType.VARCHAR, max_length=512),
    FieldSchema(name="c2", dtype=DataType.INT16),
    FieldSchema(name="c3", dtype=DataType.INT32),
    FieldSchema(name="c4", dtype=DataType.DOUBLE),
    FieldSchema(name=TS, dtype=DataType.TIMESTAMPTZ, description="timestamp with timezone"),
    FieldSchema(name="c5", dtype=DataType.FLOAT_VECTOR, dim=dim),
    FieldSchema(name="c6", dtype=DataType.VARCHAR, max_length=512),
]

schema = CollectionSchema(fields)

# FIX: create the collection only when it does not already exist. With the
# default clean_exist=False a re-run calls create_collection on an existing
# collection, which can error/conflict in pymilvus instead of no-opping.
if not client.has_collection(collection_name):
    print(fmt.format(f"Create collection `{collection_name}`"))
    client.create_collection(
        collection_name=collection_name,
        schema=schema,
        consistency_level="Strong",
    )
| 46 | + |
if prepare_data:
    rng = np.random.default_rng(seed=19530)
    print(fmt.format("Start inserting entities"))

    # Keep a small cardinality for TS so GROUP BY results are readable.
    ts_choices = [
        datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc).isoformat(),
        datetime(2025, 1, 1, 1, 0, 0, tzinfo=timezone.utc).isoformat(),
        datetime(2025, 1, 1, 2, 0, 0, tzinfo=timezone.utc).isoformat(),
    ]

    # Minimal stats (avoid printing too much); accumulated over all batches so
    # they can be eyeballed against the GROUP BY query results further below.
    pair_counter = Counter()
    c2_sum = 0
    c3_sum = 0
    c4_sum = 0.0
    c2_lt_2 = 0
    c2_dist = Counter()

    for batch_idx in range(batch_num):
        # generate per-batch data
        c1_data = [random.choice(names) for _ in range(num_entities)]
        c2_data = [random.randint(0, 4) for _ in range(num_entities)]
        c3_data = [random.randint(0, 6) for _ in range(num_entities)]
        c4_data = [random.uniform(0.0, 100.0) for _ in range(num_entities)]
        c6_data = [random.choice(names) for _ in range(num_entities)]
        ts_data = [ts_choices[(batch_idx + j) % len(ts_choices)] for j in range(num_entities)]
        vector_data = rng.random((num_entities, dim))

        # collect minimal stats
        pair_counter.update(zip(c1_data, c6_data))
        c2_sum += sum(c2_data)
        c3_sum += sum(c3_data)
        c4_sum += sum(c4_data)
        c2_lt_2 += sum(1 for v in c2_data if v < 2)
        c2_dist.update(c2_data)

        # Convert to dict format for milvus_client.insert()
        data = [
            {
                # FIX: pk must be unique across batches. Plain str(i) repeated
                # the same keys 0..num_entities-1 in every batch, so the
                # "total_rows" stat below did not match the number of distinct
                # primary keys actually inserted.
                "pk": f"{batch_idx}_{i}",
                "c1": c1_data[i],
                "c2": c2_data[i],
                "c3": c3_data[i],
                "c4": c4_data[i],
                TS: ts_data[i],
                "c5": vector_data[i].tolist(),
                "c6": c6_data[i],
            }
            for i in range(num_entities)
        ]

        client.insert(collection_name, data)
        if to_flush:
            print(f"flush batch:{batch_idx}")
            client.flush(collection_name)
        print(f"inserted batch:{batch_idx}")

    total_rows = batch_num * num_entities
    print(fmt.format("Quick stats (compact)"))
    print(f"total_rows: {total_rows}")
    print(f"unique (c1,c6) pairs: {len(pair_counter)}")
    print("top 5 (c1,c6) pairs:")
    for (c1v, c6v), cnt in pair_counter.most_common(5):
        print(f"  ({c1v}, {c6v}): {cnt}")
    print(f"sum(c2): {c2_sum}, sum(c3): {c3_sum}, sum(c4): {c4_sum:.2f}")
    print(f"c2 < 2: {c2_lt_2} ({c2_lt_2/total_rows*100:.2f}%)")
    if SHOW_STATS_DETAILS:
        print(f"c2 distribution: {dict(sorted(c2_dist.items()))}")

    print(fmt.format("Start Creating index IVF_FLAT"))
    from pymilvus.milvus_client.index import IndexParams
    index_params = IndexParams()
    index_params.add_index("c5", index_type="IVF_FLAT", metric_type="L2", nlist=128)
    client.create_index(collection_name, index_params)
| 123 | + |
# Report the server-side row count, then load the collection for querying.
stats = client.get_collection_stats(collection_name)
row_count = stats.get("row_count", 0)
print(f"Number of entities in Milvus: {row_count}")  # check the num_entities
client.load_collection(collection_name)
| 127 | + |
| 128 | + |
# 1. GROUP BY on the TIMESTAMPTZ column, with count/max aggregates.
print(fmt.format("Query: group by TIMESTAMPTZ + max"))
ts_query_args = dict(
    collection_name=collection_name,
    filter="c2 < 10",
    output_fields=[TS, "count(c2)", f"max({TS})"],
    timeout=120.0,
    group_by_fields=[TS],
)
res_ts = client.query(**ts_query_args)
for record in res_ts:
    print(f"res={record}")
| 140 | + |
# 2. GROUP BY on the (c1, c6) pair, asking for min/max of c2 per group.
print(fmt.format("Query: group by (c1,c6) + min/max"))
minmax_outputs = ["c1", "c6", "min(c2)", "max(c2)"]
res_minmax = client.query(
    collection_name=collection_name,
    filter="c2 < 10",
    group_by_fields=["c1", "c6"],
    output_fields=minmax_outputs,
    timeout=120.0,
)
for entry in res_minmax:
    print(f"res={entry}")
| 152 | + |
| 153 | + |
# 3. GROUP BY on c1 with averages over the three numeric columns.
print(fmt.format("Query: group by c1 + avg(c2, c3, c4)"))
avg_outputs = ["c1", "avg(c2)", "avg(c3)", "avg(c4)"]
res_avg = client.query(
    collection_name=collection_name,
    filter="c2 < 10",
    group_by_fields=["c1"],
    output_fields=avg_outputs,
    timeout=120.0,
)
for group_row in res_avg:
    print(f"res={group_row}")
| 165 | + |
# 4. group by c1 + avg(c2, c3, c4) without a filter expression, limited to 10 groups.
# FIX: the banner previously reused query #3's text verbatim, making the two
# output sections indistinguishable; it now reflects this query's variant.
print(fmt.format("Query: group by c1 + avg(c2, c3, c4) without expr"))
res_avg = client.query(
    collection_name=collection_name,
    filter="",
    output_fields=["c1", "avg(c2)", "avg(c3)", "avg(c4)"],
    timeout=120.0,
    limit=10,
    group_by_fields=["c1"],
)
for row in res_avg:
    print(f"res={row}")
# (removed stray "0 commit comments" page footer — scrape residue, not part of the script)