Skip to content

Commit dc212ab

Browse files
authored
Define DATA_TYPES_ATTENTION if class AttentionConfiguration is reused outside (#1897)
1 parent baf96d9 commit dc212ab

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

mlir/utils/performance/perfRunner.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1233,6 +1233,8 @@ class AttentionConfiguration(PerfConfiguration):
12331233
TABLE_COLUMNS = reportUtils.ATTN_TEST_PARAMETERS + ['TFlops']
12341234
def __init__(self, dtype: str, g: int, seq_len_q: int, seq_len_k: int, num_heads_q: int, num_heads_kv: int, head_dim_qk: int, head_dim_v: int, with_attn_scale: bool, with_attn_bias: bool,
12351235
transQ: bool, transK: bool, transV: bool, transO: bool, causal: bool, return_lse: bool, arch: str, numCU: int, perf_config: str = ''):
1236+
if DATA_TYPES_ATTENTION is None:
1237+
initializeDataTypesAttention()
12361238
if dtype not in DATA_TYPES_ATTENTION:
12371239
raise ValueError(f"Invalid datatype for a: {dtype}")
12381240

0 commit comments

Comments
 (0)