Commit 7435a0d (parent: e6b4ddc)

clean up

3 files changed (+2, -42 lines)
tritonbench/operators/gdpa/gdpa.py (0 additions, 1 deletion)

@@ -1952,7 +1952,6 @@ def alloc_fn(size: int, alignment: int, stream: int | None):
     ad_to_request_offset = create_dummy_tensor(query)

     activation_enum_int = activation_string_to_int(activation)
-    # print("activation_enum_int", activation, activation_enum_int)
     kernel_info = capture_triton(kernel_fn)[grid](
         q,
         query_offset,

tritonbench/operators/gdpa/gdpa_utils.py (1 addition, 0 deletions)

@@ -5,6 +5,7 @@
 from functools import lru_cache
 from typing import Any, List, Optional

+# need this for OSS
 import fbgemm_gpu

 import torch
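A note on the hunk above: the fbgemm_gpu import looks unused, and the new "# need this for OSS" comment suggests (this is a reading of the comment, not something the commit states) that the import is kept for its side effect of registering the fbgemm operators with PyTorch, so that torch.ops.fbgemm.* lookups resolve in open-source builds. A minimal sketch of that side effect, using asynchronous_complete_cumsum, one of the ops fbgemm_gpu registers:

    # Sketch only, not part of this commit; assumes fbgemm_gpu is installed.
    import fbgemm_gpu  # noqa: F401  # side effect: registers torch.ops.fbgemm.*
    import torch

    lengths = torch.tensor([3, 1, 2], dtype=torch.int64)
    # Without the import above, this lookup fails in an OSS build because
    # the fbgemm op schemas were never registered with torch.
    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    print(offsets)  # tensor([0, 3, 4, 6])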

tritonbench/operators/gdpa/operator.py (1 addition, 41 deletions)

@@ -84,28 +84,6 @@ def get_attn_config(config_name, dtype=torch.bfloat16):
     return default_config


-def get_cutlass_config(dtype=torch.bfloat16):
-    default_config = {
-        "B": 1152,
-        "max_M": 1000,
-        "D": 512,
-        "H": 4,
-        "dense_q_len": 192,
-        "sparsity": 1.0,
-        "dense_q": False,
-        "dff": None,
-        "bias": False,
-        "dtype": dtype,
-        "fused_kv": False,
-        "window_size": None,
-        "broadcast_q": False,
-        "activation": "fast_gelu",
-    }
-    # per event pffn, pma, self_attn share the same setting
-
-    return default_config
-
-
 all_configs = [
     "_".join([event_size, attn_type])
     for event_size in ["long_event", "short_event"]
@@ -323,8 +301,7 @@ def _inner():

     def get_input_iter(self) -> Generator:
         for config_name in self.config_names:
-            config = get_cutlass_config(self.dtype)
-            # config = get_attn_config(config_name, self.dtype)
+            config = get_attn_config(config_name, self.dtype)
             B = self.batch
             max_M = self.max_seq_len
             D = self.dim
@@ -433,23 +410,6 @@ def gbps(
         memory_bandwidth_gb_per_sec = memory_size_gb / (ms * 1e-3)
         return memory_bandwidth_gb_per_sec

-    @register_metric()
-    def flops(
-        self, fn_name: str, example_inputs: Any, metrics: BenchmarkOperatorMetrics
-    ) -> float:
-        B = self.batch
-        max_M = self.max_seq_len
-        D = self.dim
-        H = self.head
-        config = get_cutlass_config(self.dtype)
-        sparsity = config["sparsity"]
-
-        print("D/dim", D)  # D/self.dim, assume H * dim in script is D
-        total_flops = 4 * B * max_M * sparsity * D * D  # H * self.dim
-        # ms = metrics.latency
-        # print(f"TFLOP/s: {total_flops / 1e9 / ms :.2f}")
-        return total_flops
-
     @register_metric()
     def activation_mb(
         self, fn: Callable, example_inputs: Any, metrics: BenchmarkOperatorMetrics
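For context on the deleted flops metric: plugging in the defaults from the also-deleted get_cutlass_config gives a feel for the magnitude it reported. This is just arithmetic over the removed code, not part of the commit:

    # Deleted estimate: total_flops = 4 * B * max_M * sparsity * D * D
    B, max_M, D, sparsity = 1152, 1000, 512, 1.0  # removed get_cutlass_config defaults
    total_flops = 4 * B * max_M * sparsity * D * D
    print(f"{total_flops / 1e12:.2f} TFLOPs")  # ~1.21 TFLOPs per invocation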
