Skip to content

Commit 2440bd0

Browse files
create demo
1 parent 40381f5 commit 2440bd0

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

examples/verify_algo.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -140,14 +140,16 @@ def verify_algos(
140140
else:
141141
unique_result[item['query']] = max(item["score"], unique_result[item['query']])
142142

143-
144-
llm_cfg = AutoConfig.from_pretrained(model_name)
145-
flow = vortex_torch.flow.build_vflow(vortex_module_name)
146-
memory_access_runtime = flow.run_indexer_virtual(
147-
group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads,
148-
page_size=page_size,
149-
head_dim=llm_cfg.head_dim,
150-
)
143+
if sparse_attention:
144+
llm_cfg = AutoConfig.from_pretrained(model_name)
145+
flow = vortex_torch.flow.build_vflow(vortex_module_name)
146+
memory_access_runtime = flow.run_indexer_virtual(
147+
group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads,
148+
page_size=page_size,
149+
head_dim=llm_cfg.head_dim,
150+
)
151+
else:
152+
memory_access_runtime = 0.0
151153

152154
global_summary = {
153155
f'mean@{trials}': total_accuracy / count if count > 0 else 0,
@@ -156,7 +158,7 @@ def verify_algos(
156158
"e2e_time": e2e_time,
157159
"total_tokens": total_tokens,
158160
"throughput": total_tokens / e2e_time,
159-
"memory_access_runtime (per page)": memory_access_runtime
161+
"auxilary memory_access_runtime (bytes per page)": memory_access_runtime
160162
}
161163

162164
return global_summary
@@ -229,3 +231,4 @@ def parse_args():
229231
)
230232
print(summary)
231233

234+
exit(0)

examples/verify_algo.sh

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@
22
set -e
33

44
sparse_algos=(
5-
"gqa_value_aware_sparse_attention"
6-
# "gqa_block_sparse_attention"
7-
# "gqa_quest_sparse_attention"
8-
# "block_sparse_attention"
5+
96
)
107

118
for algo in "${sparse_algos[@]}"; do
@@ -15,6 +12,6 @@ for algo in "${sparse_algos[@]}"; do
1512
--topk-val 30 \
1613
--vortex-module-name "${algo}" \
1714
--model-name Qwen/Qwen3-1.7B \
18-
--mem 0.8
15+
--mem 0.7
1916
done
2017

vortex_torch/abs/context_base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def summary(self) -> None:
113113
Print fields; tensor fields show shape/dtype/device, and append memory totals incl. auxiliary.
114114
"""
115115

116+
return
116117
def _fmt_bytes(n: int) -> str:
117118
units = ("B", "KB", "MB", "GB", "TB", "PB")
118119
f = float(n)

0 commit comments

Comments
 (0)