Skip to content

Commit 01fede0

Browse files
authored
Limit TK conv to f16xf32 (#41)
We doesn't support other datatypes yet. Signed-off-by: Ivan Butygin <[email protected]>
1 parent f8d6ecd commit 01fede0

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

convbench/conv_bench.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import sys
1010
from utils import *
1111
from conv_utils import *
12-
from problems import get_conv_configs, get_conv_test_configs
12+
from problems import get_conv_configs, get_tk_conv_configs, get_conv_test_configs
1313

1414
from wave_conv_utils import compile_wave_conv_config
1515

@@ -56,7 +56,7 @@ def compile_conv_wave(tag, config, kernel_dir, vmfb_dir, extra_compiler_args):
5656
sys.exit()
5757

5858
# configs = get_conv_test_configs()
59-
configs = get_conv_configs()
59+
configs = get_tk_conv_configs() if args.tk else get_conv_configs()
6060
print(f"Generated {len(configs)} conv configs.")
6161

6262
num_cpus = max(1, cpu_count() - 20)

convbench/problems.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,14 @@ def get_conv_configs() -> list[tuple[str, ConvConfig]]:
8181

8282
return configs
8383

84+
def get_tk_conv_configs() -> list[tuple[str, ConvConfig]]:
85+
def check(config_tuple: tuple[str, ConvConfig]) -> bool:
86+
config = config_tuple[1]
87+
return config.input_dtype == "f16" and config.output_dtype == "f32"
88+
89+
return list(filter(check, get_conv_configs()))
90+
91+
8492
# Test function to run only a few chosen shapes
8593
def get_conv_test_configs() -> list[tuple[str, ConvConfig]]:
8694
configs: list[tuple[str, ConvConfig]] = []

0 commit comments

Comments
 (0)