Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
316 changes: 229 additions & 87 deletions sqlite/graph_net_sample_groups_insert.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,57 @@
import argparse
import sqlite3
import uuid as uuid_module
from collections import defaultdict, namedtuple
from datetime import datetime

from orm_models import GraphNetSampleGroup, get_session


# ── Types ──

# Immutable record for one bucket: the head sample (oldest member) plus a
# comma-joined CSV of every member uid, keyed by (sample_type, op_seq, shapes).
BucketGroup = namedtuple(
    "BucketGroup",
    "head_uid op_seq shapes sample_type all_uids_csv",
)

# One still-ungrouped sample row joined with its bucket ids (policy-v2 input).
Candidate = namedtuple(
    "Candidate",
    "uid sample_type op_seq shapes dtypes",
)


GraphNetSampleUid = str
GraphNetSampleType = str
BucketId = str
# ── Helpers ──


def _new_group_id():
return str(uuid_module.uuid4())


def _merge_stats(dst, src):
for key, val in src.items():
dst[key]["records"] += val["records"]
dst[key]["groups"].update(val["groups"])


def _print_stats(stats):
rule_order = ["rule1", "rule2", "rule4", "rule3"]
sample_types = sorted({st for st, _ in stats})
total_records = 0
total_groups = 0
for sample_type in sample_types:
print(f"\n [{sample_type}]")
for rule in rule_order:
key = (sample_type, rule)
if key in stats:
n_records = stats[key]["records"]
n_groups = len(stats[key]["groups"])
print(f" {rule}: {n_records} records, {n_groups} groups")
total_records += n_records
total_groups += n_groups
print(f"\n Total: {total_records} records, {total_groups} groups.")


# ── Database Queries ──


class DB:
Expand All @@ -23,117 +61,221 @@ def __init__(self, path):
    def connect(self):
        """Open the SQLite connection and create a cursor."""
        self.conn = sqlite3.connect(self.path)
        # sqlite3.Row lets callers access columns by name as well as index.
        self.conn.row_factory = sqlite3.Row
        # NOTE(review): diff residue — both the old `cur` and the new `cursor`
        # attributes are assigned here; `query` below uses `cursor`. Confirm
        # the `self.cur` line should be dropped.
        self.cur = self.conn.cursor()
        self.cursor = self.conn.cursor()

    def query(self, sql, params=None):
        """Execute *sql* with optional *params* and return all rows.

        NOTE(review): uses `self.cur` while connect() also defines
        `self.cursor` — looks like the pre-rename revision (diff residue);
        confirm which attribute survives.
        """
        self.cur.execute(sql, params or ())
        return self.cur.fetchall()

    def exec(self, sql, params=None):
        """Execute a statement and return fetched rows.

        NOTE(review): this body mixes two revisions (old `self.cur` +
        `commit`, new `self.cursor` + `fetchall`) — diff residue; the same
        statement is executed twice. Confirm the intended final body.
        """
        self.cur.execute(sql, params or ())
        self.conn.commit()
        self.cursor.execute(sql, params or ())
        return self.cursor.fetchall()

    def close(self):
        """Close the underlying SQLite connection."""
        self.conn.close()


SampleBucketInfo = namedtuple(
"SampleBucketInfo",
[
"sample_uid",
"op_seq_bucket_id",
"input_shapes_bucket_id",
"input_dtypes_bucket_id",
"sample_type",
"sample_uids",
],
)
def query_bucket_groups(db: DB) -> list[BucketGroup]:
    """Return one BucketGroup per (sample_type, op_seq, shapes) bucket.

    The inner SELECT orders joined samples by (create_at, uuid); the outer
    GROUP BY collapses each bucket, with group_concat collecting every
    member uid (the "all_uids" CSV) in that order.

    NOTE(review): `sub.sample_uid` is a bare column under GROUP BY — SQLite
    fills it from one row of each group (in practice the first row of the
    ordered subquery, i.e. the oldest sample, which becomes the bucket
    head), but this is SQLite-specific behavior — confirm it is relied on
    deliberately.
    """
    sql = """
    SELECT
        sub.sample_uid,
        sub.op_seq_bucket_id,
        sub.input_shapes_bucket_id,
        sub.sample_type,
        group_concat(sub.sample_uid, ',') AS all_uids
    FROM (
        SELECT
            s.uuid AS sample_uid,
            s.sample_type,
            b.op_seq_bucket_id,
            b.input_shapes_bucket_id
        FROM graph_sample s
        JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid
        ORDER BY s.create_at ASC, s.uuid ASC
    ) sub
    GROUP BY sub.sample_type, sub.op_seq_bucket_id, sub.input_shapes_bucket_id;
    """
    # Column order matches BucketGroup's field order exactly.
    return [BucketGroup(*row) for row in db.query(sql)]


def get_ai4c_group_members(sample_bucket_infos: list[SampleBucketInfo]):
    """Yield (sample_uid, new_group_uuid) pairs, one fresh group per pick.

    NOTE(review): pre-refactor helper — it references the removed
    SampleBucketInfo namedtuple and appears to be deleted-diff residue;
    its role is now covered by generate_v1_groups.
    """
    for bucket_info in sample_bucket_infos:
        head_sample_uid = bucket_info.sample_uid
        # Members arrive as a comma-joined uid list; take every 5th entry
        # and skip the bucket head itself.
        sample_uids = bucket_info.sample_uids.split(",")
        selected_other_sample_uids = [
            other_sample_uid
            for other_sample_uid in sample_uids[::5]
            if other_sample_uid != head_sample_uid
        ]
        for sample_uid in selected_other_sample_uids:
            # Each selected sample gets its own brand-new group uuid.
            new_uuid = str(uuid_module.uuid4())
            yield sample_uid, new_uuid
def query_v2_candidates(db: DB) -> list[Candidate]:
    """Return live, non-full-graph samples not yet grouped by policy v1.

    Filters: sample not soft-deleted, sample_type != 'full_graph', and the
    uuid must not already belong to an active 'bucket_policy_v1' group.
    The explicit ORDER BY makes downstream sampling deterministic across
    runs.
    """
    sql = """
    SELECT
        s.uuid,
        s.sample_type,
        b.op_seq_bucket_id,
        b.input_shapes_bucket_id,
        b.input_dtypes_bucket_id
    FROM graph_sample s
    JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid
    WHERE s.deleted = 0
      AND s.sample_type != 'full_graph'
      AND s.uuid NOT IN (
          SELECT g.sample_uid
          FROM graph_net_sample_groups g
          WHERE g.group_policy = 'bucket_policy_v1'
            AND g.deleted = 0
      )
    ORDER BY s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id,
             b.input_dtypes_bucket_id, s.uuid;
    """
    # Column order matches Candidate's field order exactly.
    return [Candidate(*row) for row in db.query(sql)]

grouped = defaultdict(list)
for bucket_info in sample_bucket_infos:
key = (bucket_info.op_seq_bucket_id, bucket_info.input_dtypes_bucket_id)
grouped[key].append(bucket_info.sample_uid)

grouped = dict(grouped)
for key, sample_uids in grouped.items():
new_uuid = str(uuid_module.uuid4())
for sample_uid in sample_uids:
yield sample_uid, new_uuid
# ═══════════════════════════════════════════════════════════════════
# V1: Rule 1 (bucket-internal stride sampling) + Rule 2 (cross-shape)
# ═══════════════════════════════════════════════════════════════════


def generate_v1_groups(bucket_groups: list[BucketGroup]):
    """Yield (sample_type, uid, group_id, rule_name) tuples for policy v1.

    Rule 1: stride-16 sampling within each bucket; every picked member
        gets its own single-sample group (the bucket head is skipped,
        since Rule 2 represents it).
    Rule 2: one group per (sample_type, op_seq) aggregating all bucket
        heads that share the same op sequence.

    Note: this span of the diff view was corrupted by leftover lines from
    the removed old ``main()`` (including an unterminated ``query_str``
    triple-quoted string); this is the clean post-refactor function.
    """
    # Rule 1: walk each bucket's members (CSV, already in creation order)
    # and take every 16th uid.
    for bucket in bucket_groups:
        members = bucket.all_uids_csv.split(",")
        for uid in members[::16]:
            if uid != bucket.head_uid:
                yield bucket.sample_type, uid, _new_group_id(), "rule1"

    # Rule 2: cluster bucket heads by (sample_type, op_seq) so samples
    # with identical op sequences but different shapes share one group.
    heads_by_type_op = defaultdict(list)
    for bucket in bucket_groups:
        heads_by_type_op[(bucket.sample_type, bucket.op_seq)].append(bucket.head_uid)
    for (sample_type, _op), heads in heads_by_type_op.items():
        gid = _new_group_id()
        for uid in heads:
            yield sample_type, uid, gid, "rule2"


# ═══════════════════════════════════════════════════════════════════
# V2: Rule 4 (dtype coverage) + Rule 3 (sparse sampling on remainder)
# ═══════════════════════════════════════════════════════════════════


def generate_v2_groups(candidates: list[Candidate], num_dtypes: int):
    """Yield (sample_type, uid, group_id, rule_name) tuples for policy v2.

    Rule 4 (runs first): per (sample_type, op_seq), one group holding up
        to num_dtypes samples with distinct dtypes for each input shape.
    Rule 3 (runs second): on samples Rule 4 did not cover, window-based
        sparse sampling — from each window of num_dtypes * 5 uids (sorted
        by uid) keep the first num_dtypes.

    Note: this span of the diff view was corrupted by leftover lines from
    the removed old ``main()``; this is the clean post-refactor function.
    """
    # Degenerate budget: nothing can ever be picked, and it would cause a
    # modulo-by-zero in the Rule 3 window arithmetic below.
    if num_dtypes <= 0:
        return

    by_type_op = defaultdict(list)
    for cand in candidates:
        by_type_op[(cand.sample_type, cand.op_seq)].append(cand)

    covered_uids = set()

    # Rule 4: dtype coverage — one group per (sample_type, op_seq).
    for (sample_type, _op), group in by_type_op.items():
        by_shape = defaultdict(list)
        for cand in group:
            by_shape[cand.shapes].append(cand)

        picked = []
        for _shape, shape_group in by_shape.items():
            seen_dtypes = set()
            for cand in shape_group:
                # First occurrence of each dtype wins, capped at
                # num_dtypes distinct dtypes per shape.
                if cand.dtypes not in seen_dtypes and len(seen_dtypes) < num_dtypes:
                    seen_dtypes.add(cand.dtypes)
                    picked.append(cand.uid)
                    covered_uids.add(cand.uid)

        if picked:
            gid = _new_group_id()
            for uid in picked:
                yield sample_type, uid, gid, "rule4"

    # Rule 3: sparse sampling over whatever Rule 4 left behind.
    window_size = num_dtypes * 5
    for (sample_type, _op), group in by_type_op.items():
        remaining = sorted(
            (cand for cand in group if cand.uid not in covered_uids),
            key=lambda cand: cand.uid,
        )
        picked = [
            cand.uid
            for idx, cand in enumerate(remaining)
            if (idx % window_size) < num_dtypes
        ]
        if picked:
            gid = _new_group_id()
            for uid in picked:
                yield sample_type, uid, gid, "rule3"


# ═══════════════════════════════════════════════════════════════════
# Insert
# ═══════════════════════════════════════════════════════════════════


def _insert_groups(session, rows, policy):
    """Persist generated group rows and return per-(sample_type, rule) stats.

    Consumes an iterable of (sample_type, uid, group_id, rule_name), adds
    one GraphNetSampleGroup per tuple under *policy*, and commits once at
    the end so the whole batch is atomic.

    Fix: the diff residue left both the old hard-coded
    ``group_policy="bucket_policy_v1"`` and the new ``group_policy=policy``
    keyword in the constructor call (a duplicate-keyword SyntaxError); only
    the parameterized form is kept.

    Returns: dict keyed by (sample_type, rule_name) with "records" (int)
    and "groups" (set of group ids).
    """
    stats = defaultdict(lambda: {"records": 0, "groups": set()})
    for sample_type, uid, group_id, rule_name in rows:
        session.add(
            GraphNetSampleGroup(
                sample_uid=uid,
                group_uid=group_id,
                group_type="ai4c",
                group_policy=policy,
                policy_version="1.0",
                create_at=datetime.now(),
                deleted=False,
            )
        )
        stats[(sample_type, rule_name)]["records"] += 1
        stats[(sample_type, rule_name)]["groups"].add(group_id)
    session.commit()
    return stats


session.add(new_group)
session.commit()
# ═══════════════════════════════════════════════════════════════════
# Main
# ═══════════════════════════════════════════════════════════════════


def main():
    """CLI entry point: run the v1 and v2 grouping passes, then report stats."""
    arg_parser = argparse.ArgumentParser(
        description="Generate graph_net_sample_groups (v1 + v2)"
    )
    arg_parser.add_argument("--db_path", type=str, required=True)
    arg_parser.add_argument("--num_dtypes", type=int, default=3)
    cli_args = arg_parser.parse_args()

    database = DB(cli_args.db_path)
    database.connect()
    orm_session = get_session(cli_args.db_path)

    combined_stats = defaultdict(lambda: {"records": 0, "groups": set()})

    try:
        # V1 pass: stride sampling + cross-shape head aggregation.
        bucket_rows = query_bucket_groups(database)
        print(f"Bucket groups: {len(bucket_rows)}")
        v1_stats = _insert_groups(
            orm_session, generate_v1_groups(bucket_rows), "bucket_policy_v1"
        )
        _merge_stats(combined_stats, v1_stats)

        # V2 pass: dtype coverage + sparse sampling on the remainder.
        candidate_rows = query_v2_candidates(database)
        print(f"V2 candidates: {len(candidate_rows)}")
        if not candidate_rows:
            print("No V2 candidates found. Skipping.")
        else:
            v2_stats = _insert_groups(
                orm_session,
                generate_v2_groups(candidate_rows, cli_args.num_dtypes),
                "bucket_policy_v2",
            )
            _merge_stats(combined_stats, v2_stats)
    except Exception:
        # Undo any partial batch before surfacing the error.
        orm_session.rollback()
        raise
    finally:
        orm_session.close()
        database.close()

    print("=" * 60)
    _print_stats(combined_stats)
    print("\nDone!")


if __name__ == "__main__":
Expand Down
Loading