@@ -46,17 +46,22 @@ def _new_group_id():
4646 return str (uuid_module .uuid4 ())
4747
4848
49- def _print_stats (stats , rule_order ):
49+ def _print_stats (stats ):
50+ rule_order = ["rule1" , "rule2" , "rule4" , "rule3" ]
51+ sample_types = sorted ({st for st , _ in stats .keys ()})
5052 total_records = 0
5153 total_groups = 0
52- for rule_name in rule_order :
53- if rule_name in stats :
54- r = stats [rule_name ]["records" ]
55- g = len (stats [rule_name ]["groups" ])
56- print (f" { rule_name } : { r } records, { g } groups" )
57- total_records += r
58- total_groups += g
59- print (f" Total: { total_records } records, { total_groups } groups." )
54+ for sample_type in sample_types :
55+ print (f"\n [{ sample_type } ]" )
56+ for rule_name in rule_order :
57+ key = (sample_type , rule_name )
58+ if key in stats :
59+ r = stats [key ]["records" ]
60+ g = len (stats [key ]["groups" ])
61+ print (f" { rule_name } : { r } records, { g } groups" )
62+ total_records += r
63+ total_groups += g
64+ print (f"\n Total: { total_records } records, { total_groups } groups." )
6065
6166
6267# ═══════════════════════════════════════════════════════════════════
@@ -67,14 +72,14 @@ def _print_stats(stats, rule_order):
6772def generate_v1_groups (bucket_groups : list [BucketGroup ]):
6873 """Rule 1: stride-16 sampling, each sampled uid gets its own group.
6974 Rule 2: group all bucket heads that share the same op_seq.
70- Yields (uid, group_id, rule_name)."""
75+ Yields (sample_type, uid, group_id, rule_name)."""
7176
7277 # Rule 1: stride-16 sampling
7378 for bucket in bucket_groups :
7479 members = bucket .all_uids_csv .split ("," )
75- sampled = [uid for uid in members [::5 ] if uid != bucket .head_uid ]
80+ sampled = [uid for uid in members [::16 ] if uid != bucket .head_uid ]
7681 for uid in sampled :
77- yield uid , _new_group_id (), "rule1"
82+ yield bucket . sample_type , uid , _new_group_id (), "rule1"
7883
7984 # Rule 2: group heads by (sample_type, op_seq)
8085 type_op_seq_to_heads = defaultdict (list )
@@ -83,10 +88,10 @@ def generate_v1_groups(bucket_groups: list[BucketGroup]):
8388 bucket .head_uid
8489 )
8590
86- for heads in type_op_seq_to_heads .values ():
91+ for ( sample_type , _ ), heads in type_op_seq_to_heads .items ():
8792 group_id = _new_group_id ()
8893 for uid in heads :
89- yield uid , group_id , "rule2"
94+ yield sample_type , uid , group_id , "rule2"
9095
9196
9297def query_v1_bucket_groups (db : DB ) -> list [BucketGroup ]:
@@ -113,15 +118,11 @@ def query_v1_bucket_groups(db: DB) -> list[BucketGroup]:
113118
114119
115120def insert_v1_groups (db : DB , session ):
116- print ("=" * 60 )
117- print ("V1: Rule 1 (stride-5) + Rule 2 (cross-shape aggregation)" )
118- print ("=" * 60 )
119-
120121 bucket_groups = query_v1_bucket_groups (db )
121- print (f" Bucket groups: { len (bucket_groups )} " )
122+ print (f"Bucket groups: { len (bucket_groups )} " )
122123
123124 stats = defaultdict (lambda : {"records" : 0 , "groups" : set ()})
124- for uid , group_id , rule_name in generate_v1_groups (bucket_groups ):
125+ for sample_type , uid , group_id , rule_name in generate_v1_groups (bucket_groups ):
125126 session .add (
126127 GraphNetSampleGroup (
127128 sample_uid = uid ,
@@ -133,11 +134,11 @@ def insert_v1_groups(db: DB, session):
133134 deleted = False ,
134135 )
135136 )
136- stats [rule_name ]["records" ] += 1
137- stats [rule_name ]["groups" ].add (group_id )
137+ stats [( sample_type , rule_name ) ]["records" ] += 1
138+ stats [( sample_type , rule_name ) ]["groups" ].add (group_id )
138139
139140 session .commit ()
140- _print_stats ( stats , [ "rule1" , "rule2" ])
141+ return stats
141142
142143
143144# ═══════════════════════════════════════════════════════════════════
@@ -148,7 +149,7 @@ def insert_v1_groups(db: DB, session):
148149def generate_v2_groups (candidates : list [V2Candidate ], num_dtypes : int ):
149150 """Rule 4 runs first to ensure full dtype coverage.
150151 Rule 3 then sparse-samples from the remaining candidates.
151- Yields (uid, group_id, rule_name)."""
152+ Yields (sample_type, uid, group_id, rule_name)."""
152153
153154 candidates_by_op_seq = defaultdict (list )
154155 for c in candidates :
@@ -157,7 +158,7 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
157158 dtype_covered_uids = set ()
158159
159160 # --- Rule 4: dtype coverage (runs first) ---
160- for key , op_candidates in candidates_by_op_seq .items ():
161+ for ( sample_type , _ ) , op_candidates in candidates_by_op_seq .items ():
161162 candidates_by_shape = defaultdict (list )
162163 for c in op_candidates :
163164 candidates_by_shape [c .shapes ].append (c )
@@ -174,11 +175,11 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
174175 if selected_uids :
175176 group_id = _new_group_id ()
176177 for uid in selected_uids :
177- yield uid , group_id , "rule4"
178+ yield sample_type , uid , group_id , "rule4"
178179
179180 # --- Rule 3: global sparse sampling (on remaining candidates) ---
180181 window_size = num_dtypes * 5
181- for key , op_candidates in candidates_by_op_seq .items ():
182+ for ( sample_type , _ ) , op_candidates in candidates_by_op_seq .items ():
182183 remaining = [c for c in op_candidates if c .uid not in dtype_covered_uids ]
183184 remaining .sort (key = lambda c : c .uid )
184185
@@ -190,7 +191,7 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
190191 if selected_uids :
191192 group_id = _new_group_id ()
192193 for uid in selected_uids :
193- yield uid , group_id , "rule3"
194+ yield sample_type , uid , group_id , "rule3"
194195
195196
196197def query_v2_candidates (db : DB ) -> list [V2Candidate ]:
@@ -217,19 +218,17 @@ def query_v2_candidates(db: DB) -> list[V2Candidate]:
217218
218219
219220def insert_v2_groups (db : DB , session , num_dtypes : int ):
220- print ("=" * 60 )
221- print ("V2: Rule 4 (dtype coverage) + Rule 3 (sparse sampling)" )
222- print ("=" * 60 )
223-
224221 candidates = query_v2_candidates (db )
225- print (f" V2 candidates: { len (candidates )} " )
222+ print (f"V2 candidates: { len (candidates )} " )
226223
224+ stats = defaultdict (lambda : {"records" : 0 , "groups" : set ()})
227225 if not candidates :
228- print (" No v2 candidates found. Skipping." )
229- return
226+ print ("No v2 candidates found. Skipping." )
227+ return stats
230228
231- stats = defaultdict (lambda : {"records" : 0 , "groups" : set ()})
232- for uid , group_id , rule_name in generate_v2_groups (candidates , num_dtypes ):
229+ for sample_type , uid , group_id , rule_name in generate_v2_groups (
230+ candidates , num_dtypes
231+ ):
233232 session .add (
234233 GraphNetSampleGroup (
235234 sample_uid = uid ,
@@ -241,11 +240,11 @@ def insert_v2_groups(db: DB, session, num_dtypes: int):
241240 deleted = False ,
242241 )
243242 )
244- stats [rule_name ]["records" ] += 1
245- stats [rule_name ]["groups" ].add (group_id )
243+ stats [( sample_type , rule_name ) ]["records" ] += 1
244+ stats [( sample_type , rule_name ) ]["groups" ].add (group_id )
246245
247246 session .commit ()
248- _print_stats ( stats , [ "rule4" , "rule3" ])
247+ return stats
249248
250249
251250# ═══════════════════════════════════════════════════════════════════
@@ -276,15 +275,24 @@ def main():
276275 session = get_session (args .db_path )
277276
278277 try :
279- insert_v1_groups (db , session )
280- insert_v2_groups (db , session , args .num_dtypes )
278+ v1_stats = insert_v1_groups (db , session )
279+ v2_stats = insert_v2_groups (db , session , args .num_dtypes )
281280 except Exception :
282281 session .rollback ()
283282 raise
284283 finally :
285284 session .close ()
286285 db .close ()
287286
287+ # Merge and print
288+ all_stats = defaultdict (lambda : {"records" : 0 , "groups" : set ()})
289+ for s in (v1_stats , v2_stats ):
290+ for key , val in s .items ():
291+ all_stats [key ]["records" ] += val ["records" ]
292+ all_stats [key ]["groups" ].update (val ["groups" ])
293+
294+ print ("=" * 60 )
295+ _print_stats (all_stats )
288296 print ("\n Done!" )
289297
290298
0 commit comments