@@ -38,7 +38,7 @@ def close(self):
3838
3939V2Candidate = namedtuple (
4040 "V2Candidate" ,
41- ["uid" , "op_seq" , "shapes" , "dtypes" ],
41+ ["uid" , "sample_type" , "op_seq" , "shapes" , "dtypes" ],
4242)
4343
4444
@@ -76,12 +76,14 @@ def generate_v1_groups(bucket_groups: list[BucketGroup]):
7676 for uid in sampled :
7777 yield uid , _new_group_id (), "rule1"
7878
79- # Rule 2: group heads by op_seq
80- op_seq_to_heads = defaultdict (list )
79+ # Rule 2: group heads by (sample_type, op_seq)
80+ type_op_seq_to_heads = defaultdict (list )
8181 for bucket in bucket_groups :
82- op_seq_to_heads [bucket .op_seq ].append (bucket .head_uid )
82+ type_op_seq_to_heads [(bucket .sample_type , bucket .op_seq )].append (
83+ bucket .head_uid
84+ )
8385
84- for heads in op_seq_to_heads .values ():
86+ for heads in type_op_seq_to_heads .values ():
8587 group_id = _new_group_id ()
8688 for uid in heads :
8789 yield uid , group_id , "rule2"
@@ -150,12 +152,12 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
150152
151153 candidates_by_op_seq = defaultdict (list )
152154 for c in candidates :
153- candidates_by_op_seq [c . op_seq ].append (c )
155+ candidates_by_op_seq [( c . sample_type , c . op_seq ) ].append (c )
154156
155157 dtype_covered_uids = set ()
156158
157159 # --- Rule 4: dtype coverage (runs first) ---
158- for op_seq , op_candidates in candidates_by_op_seq .items ():
160+ for key , op_candidates in candidates_by_op_seq .items ():
159161 candidates_by_shape = defaultdict (list )
160162 for c in op_candidates :
161163 candidates_by_shape [c .shapes ].append (c )
@@ -176,7 +178,7 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
176178
177179 # --- Rule 3: global sparse sampling (on remaining candidates) ---
178180 window_size = num_dtypes * 5
179- for op_seq , op_candidates in candidates_by_op_seq .items ():
181+ for key , op_candidates in candidates_by_op_seq .items ():
180182 remaining = [c for c in op_candidates if c .uid not in dtype_covered_uids ]
181183 remaining .sort (key = lambda c : c .uid )
182184
@@ -195,6 +197,7 @@ def query_v2_candidates(db: DB) -> list[V2Candidate]:
195197 sql = """
196198SELECT
197199 s.uuid,
200+ s.sample_type,
198201 b.op_seq_bucket_id,
199202 b.input_shapes_bucket_id,
200203 b.input_dtypes_bucket_id
@@ -208,7 +211,7 @@ def query_v2_candidates(db: DB) -> list[V2Candidate]:
208211 WHERE g.group_policy = 'bucket_policy_v1'
209212 AND g.deleted = 0
210213 )
211- ORDER BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.uuid;
214+ ORDER BY s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.uuid;
212215 """
213216 return [V2Candidate (* row ) for row in db .query (sql )]
214217
0 commit comments