@@ -84,28 +84,6 @@ def get_attn_config(config_name, dtype=torch.bfloat16):
84
84
return default_config
85
85
86
86
87
- def get_cutlass_config (dtype = torch .bfloat16 ):
88
- default_config = {
89
- "B" : 1152 ,
90
- "max_M" : 1000 ,
91
- "D" : 512 ,
92
- "H" : 4 ,
93
- "dense_q_len" : 192 ,
94
- "sparsity" : 1.0 ,
95
- "dense_q" : False ,
96
- "dff" : None ,
97
- "bias" : False ,
98
- "dtype" : dtype ,
99
- "fused_kv" : False ,
100
- "window_size" : None ,
101
- "broadcast_q" : False ,
102
- "activation" : "fast_gelu" ,
103
- }
104
- # per event pffn, pma, self_attn share the same setting
105
-
106
- return default_config
107
-
108
-
109
87
all_configs = [
110
88
"_" .join ([event_size , attn_type ])
111
89
for event_size in ["long_event" , "short_event" ]
@@ -175,7 +153,7 @@ def __init__(
175
153
):
176
154
super ().__init__ (tb_args , extra_args = extra_args )
177
155
args = parse_args (self .extra_args )
178
- self .config_names = [ "default" ] # args.config.split(",")
156
+ self .config_names = args .config .split ("," )
179
157
self .sparsity = args .sparsity
180
158
self .batch = args .batch
181
159
self .max_seq_len = args .max_seq_len
@@ -323,8 +301,7 @@ def _inner():
323
301
324
302
def get_input_iter (self ) -> Generator :
325
303
for config_name in self .config_names :
326
- config = get_cutlass_config (self .dtype )
327
- # config = get_attn_config(config_name, self.dtype)
304
+ config = get_attn_config (config_name , self .dtype )
328
305
B = self .batch
329
306
max_M = self .max_seq_len
330
307
D = self .dim
@@ -433,23 +410,6 @@ def gbps(
433
410
memory_bandwidth_gb_per_sec = memory_size_gb / (ms * 1e-3 )
434
411
return memory_bandwidth_gb_per_sec
435
412
436
- @register_metric ()
437
- def flops (
438
- self , fn_name : str , example_inputs : Any , metrics : BenchmarkOperatorMetrics
439
- ) -> float :
440
- B = self .batch
441
- max_M = self .max_seq_len
442
- D = self .dim
443
- H = self .head
444
- config = get_cutlass_config (self .dtype )
445
- sparsity = config ["sparsity" ]
446
-
447
- print ("D/dim" , D ) # D/self.dim, assume H * dim in script is D
448
- total_flops = 4 * B * max_M * sparsity * D * D # H * self.dim
449
- # ms = metrics.latency
450
- # print(f"TFLOP/s: {total_flops / 1e9 / ms :.2f}")
451
- return total_flops
452
-
453
413
@register_metric ()
454
414
def activation_mb (
455
415
self , fn : Callable , example_inputs : Any , metrics : BenchmarkOperatorMetrics
0 commit comments