Skip to content

Commit 489af43

Browse files
committed
group by all
1 parent 238d966 commit 489af43

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

sqlite/graph_net_sample_groups_insert.py

Lines changed: 12 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ def close(self):
3838

3939
V2Candidate = namedtuple(
4040
"V2Candidate",
41-
["uid", "op_seq", "shapes", "dtypes"],
41+
["uid", "sample_type", "op_seq", "shapes", "dtypes"],
4242
)
4343

4444

@@ -76,12 +76,14 @@ def generate_v1_groups(bucket_groups: list[BucketGroup]):
7676
for uid in sampled:
7777
yield uid, _new_group_id(), "rule1"
7878

79-
# Rule 2: group heads by op_seq
80-
op_seq_to_heads = defaultdict(list)
79+
# Rule 2: group heads by (sample_type, op_seq)
80+
type_op_seq_to_heads = defaultdict(list)
8181
for bucket in bucket_groups:
82-
op_seq_to_heads[bucket.op_seq].append(bucket.head_uid)
82+
type_op_seq_to_heads[(bucket.sample_type, bucket.op_seq)].append(
83+
bucket.head_uid
84+
)
8385

84-
for heads in op_seq_to_heads.values():
86+
for heads in type_op_seq_to_heads.values():
8587
group_id = _new_group_id()
8688
for uid in heads:
8789
yield uid, group_id, "rule2"
@@ -150,12 +152,12 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
150152

151153
candidates_by_op_seq = defaultdict(list)
152154
for c in candidates:
153-
candidates_by_op_seq[c.op_seq].append(c)
155+
candidates_by_op_seq[(c.sample_type, c.op_seq)].append(c)
154156

155157
dtype_covered_uids = set()
156158

157159
# --- Rule 4: dtype coverage (runs first) ---
158-
for op_seq, op_candidates in candidates_by_op_seq.items():
160+
for key, op_candidates in candidates_by_op_seq.items():
159161
candidates_by_shape = defaultdict(list)
160162
for c in op_candidates:
161163
candidates_by_shape[c.shapes].append(c)
@@ -176,7 +178,7 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int):
176178

177179
# --- Rule 3: global sparse sampling (on remaining candidates) ---
178180
window_size = num_dtypes * 5
179-
for op_seq, op_candidates in candidates_by_op_seq.items():
181+
for key, op_candidates in candidates_by_op_seq.items():
180182
remaining = [c for c in op_candidates if c.uid not in dtype_covered_uids]
181183
remaining.sort(key=lambda c: c.uid)
182184

@@ -195,6 +197,7 @@ def query_v2_candidates(db: DB) -> list[V2Candidate]:
195197
sql = """
196198
SELECT
197199
s.uuid,
200+
s.sample_type,
198201
b.op_seq_bucket_id,
199202
b.input_shapes_bucket_id,
200203
b.input_dtypes_bucket_id
@@ -208,7 +211,7 @@ def query_v2_candidates(db: DB) -> list[V2Candidate]:
208211
WHERE g.group_policy = 'bucket_policy_v1'
209212
AND g.deleted = 0
210213
)
211-
ORDER BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.uuid;
214+
ORDER BY s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.uuid;
212215
"""
213216
return [V2Candidate(*row) for row in db.query(sql)]
214217

0 commit comments

Comments
 (0)