Skip to content

Commit 05e5d0f

Browse files
committed
group by all
1 parent 489af43 commit 05e5d0f

File tree

1 file changed

+50
-42
lines changed

1 file changed

+50
-42
lines changed

sqlite/graph_net_sample_groups_insert.py

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -46,17 +46,22 @@ def _new_group_id():
4646
return str(uuid_module.uuid4())
4747

4848

49-
def _print_stats(stats, rule_order):
49+
def _print_stats(stats):
50+
rule_order = ["rule1", "rule2", "rule4", "rule3"]
51+
sample_types = sorted({st for st, _ in stats.keys()})
5052
total_records = 0
5153
total_groups = 0
52-
for rule_name in rule_order:
53-
if rule_name in stats:
54-
r = stats[rule_name]["records"]
55-
g = len(stats[rule_name]["groups"])
56-
print(f" {rule_name}: {r} records, {g} groups")
57-
total_records += r
58-
total_groups += g
59-
print(f" Total: {total_records} records, {total_groups} groups.")
54+
for sample_type in sample_types:
55+
print(f"\n [{sample_type}]")
56+
for rule_name in rule_order:
57+
key = (sample_type, rule_name)
58+
if key in stats:
59+
r = stats[key]["records"]
60+
g = len(stats[key]["groups"])
61+
print(f" {rule_name}: {r} records, {g} groups")
62+
total_records += r
63+
total_groups += g
64+
print(f"\n Total: {total_records} records, {total_groups} groups.")
6065

6166

6267
# ═══════════════════════════════════════════════════════════════════
@@ -67,14 +72,14 @@ def _print_stats(stats, rule_order):
6772
def generate_v1_groups(bucket_groups: list[BucketGroup]):
6873
"""Rule 1: stride-16 sampling, each sampled uid gets its own group.
6974
Rule 2: group all bucket heads that share the same op_seq.
70-
Yields (uid, group_id, rule_name)."""
75+
Yields (sample_type, uid, group_id, rule_name)."""
7176

7277
# Rule 1: stride-16 sampling
7378
for bucket in bucket_groups:
7479
members = bucket.all_uids_csv.split(",")
75-
sampled = [uid for uid in members[::5] if uid != bucket.head_uid]
80+
sampled = [uid for uid in members[::16] if uid != bucket.head_uid]
7681
for uid in sampled:
77-
yield uid, _new_group_id(), "rule1"
82+
yield bucket.sample_type, uid, _new_group_id(), "rule1"
7883

7984
# Rule 2: group heads by (sample_type, op_seq)
8085
type_op_seq_to_heads = defaultdict(list)
@@ -83,10 +88,10 @@ def generate_v1_groups(bucket_groups: list[BucketGroup]):
8388
bucket.head_uid
8489
)
8590

86-
for heads in type_op_seq_to_heads.values():
91+
for (sample_type, _), heads in type_op_seq_to_heads.items():
8792
group_id = _new_group_id()
8893
for uid in heads:
89-
yield uid, group_id, "rule2"
94+
yield sample_type, uid, group_id, "rule2"
9095

9196

9297
def query_v1_bucket_groups(db: DB) -> list[BucketGroup]:
@@ -113,15 +118,11 @@ def query_v1_bucket_groups(db: DB) -> list[BucketGroup]:
113118

114119

115120
def insert_v1_groups(db: DB, session):
116-
print("=" * 60)
117-
print("V1: Rule 1 (stride-5) + Rule 2 (cross-shape aggregation)")
118-
print("=" * 60)
119-
120121
bucket_groups = query_v1_bucket_groups(db)
121-
print(f" Bucket groups: {len(bucket_groups)}")
122+
print(f"Bucket groups: {len(bucket_groups)}")
122123

123124
stats = defaultdict(lambda: {"records": 0, "groups": set()})
124-
for uid, group_id, rule_name in generate_v1_groups(bucket_groups):
125+
for sample_type, uid, group_id, rule_name in generate_v1_groups(bucket_groups):
125126
session.add(
126127
GraphNetSampleGroup(
127128
sample_uid=uid,
@@ -133,11 +134,11 @@ def insert_v1_groups(db: DB, session):
133134
deleted=False,
134135
)
135136
)
136-
stats[rule_name]["records"] += 1
137-
stats[rule_name]["groups"].add(group_id)
137+
stats[(sample_type, rule_name)]["records"] += 1
138+
stats[(sample_type, rule_name)]["groups"].add(group_id)
138139

139140
session.commit()
140-
_print_stats(stats, ["rule1", "rule2"])
141+
return stats
141142

142143

143144
# ═══════════════════════════════════════════════════════════════════
@@ -148,7 +149,7 @@ def insert_v1_groups(db: DB, session):
148149
def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
149150
"""Rule 4 runs first to ensure full dtype coverage.
150151
Rule 3 then sparse-samples from the remaining candidates.
151-
Yields (uid, group_id, rule_name)."""
152+
Yields (sample_type, uid, group_id, rule_name)."""
152153

153154
candidates_by_op_seq = defaultdict(list)
154155
for c in candidates:
@@ -157,7 +158,7 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
157158
dtype_covered_uids = set()
158159

159160
# --- Rule 4: dtype coverage (runs first) ---
160-
for key, op_candidates in candidates_by_op_seq.items():
161+
for (sample_type, _), op_candidates in candidates_by_op_seq.items():
161162
candidates_by_shape = defaultdict(list)
162163
for c in op_candidates:
163164
candidates_by_shape[c.shapes].append(c)
@@ -174,11 +175,11 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
174175
if selected_uids:
175176
group_id = _new_group_id()
176177
for uid in selected_uids:
177-
yield uid, group_id, "rule4"
178+
yield sample_type, uid, group_id, "rule4"
178179

179180
# --- Rule 3: global sparse sampling (on remaining candidates) ---
180181
window_size = num_dtypes * 5
181-
for key, op_candidates in candidates_by_op_seq.items():
182+
for (sample_type, _), op_candidates in candidates_by_op_seq.items():
182183
remaining = [c for c in op_candidates if c.uid not in dtype_covered_uids]
183184
remaining.sort(key=lambda c: c.uid)
184185

@@ -190,7 +191,7 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
190191
if selected_uids:
191192
group_id = _new_group_id()
192193
for uid in selected_uids:
193-
yield uid, group_id, "rule3"
194+
yield sample_type, uid, group_id, "rule3"
194195

195196

196197
def query_v2_candidates(db: DB) -> list[V2Candidate]:
@@ -217,19 +218,17 @@ def query_v2_candidates(db: DB) -> list[V2Candidate]:
217218

218219

219220
def insert_v2_groups(db: DB, session, num_dtypes: int):
220-
print("=" * 60)
221-
print("V2: Rule 4 (dtype coverage) + Rule 3 (sparse sampling)")
222-
print("=" * 60)
223-
224221
candidates = query_v2_candidates(db)
225-
print(f" V2 candidates: {len(candidates)}")
222+
print(f"V2 candidates: {len(candidates)}")
226223

224+
stats = defaultdict(lambda: {"records": 0, "groups": set()})
227225
if not candidates:
228-
print(" No v2 candidates found. Skipping.")
229-
return
226+
print("No v2 candidates found. Skipping.")
227+
return stats
230228

231-
stats = defaultdict(lambda: {"records": 0, "groups": set()})
232-
for uid, group_id, rule_name in generate_v2_groups(candidates, num_dtypes):
229+
for sample_type, uid, group_id, rule_name in generate_v2_groups(
230+
candidates, num_dtypes
231+
):
233232
session.add(
234233
GraphNetSampleGroup(
235234
sample_uid=uid,
@@ -241,11 +240,11 @@ def insert_v2_groups(db: DB, session, num_dtypes: int):
241240
deleted=False,
242241
)
243242
)
244-
stats[rule_name]["records"] += 1
245-
stats[rule_name]["groups"].add(group_id)
243+
stats[(sample_type, rule_name)]["records"] += 1
244+
stats[(sample_type, rule_name)]["groups"].add(group_id)
246245

247246
session.commit()
248-
_print_stats(stats, ["rule4", "rule3"])
247+
return stats
249248

250249

251250
# ═══════════════════════════════════════════════════════════════════
@@ -276,15 +275,24 @@ def main():
276275
session = get_session(args.db_path)
277276

278277
try:
279-
insert_v1_groups(db, session)
280-
insert_v2_groups(db, session, args.num_dtypes)
278+
v1_stats = insert_v1_groups(db, session)
279+
v2_stats = insert_v2_groups(db, session, args.num_dtypes)
281280
except Exception:
282281
session.rollback()
283282
raise
284283
finally:
285284
session.close()
286285
db.close()
287286

287+
# Merge and print
288+
all_stats = defaultdict(lambda: {"records": 0, "groups": set()})
289+
for s in (v1_stats, v2_stats):
290+
for key, val in s.items():
291+
all_stats[key]["records"] += val["records"]
292+
all_stats[key]["groups"].update(val["groups"])
293+
294+
print("=" * 60)
295+
_print_stats(all_stats)
288296
print("\nDone!")
289297

290298

0 commit comments

Comments
 (0)