Skip to content

Commit ad13e93

Browse files
committed
simplify preset implementation and fix normal preset
1 parent 1c7b189 commit ad13e93

File tree

3 files changed

+30
-33
lines changed

3 files changed

+30
-33
lines changed

devops/scripts/benchmarks/main.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from history import BenchmarkHistory
1818
from utils.utils import prepare_workdir
1919
from utils.compute_runtime import *
20-
from presets import Presets
20+
from presets import preset_get_by_name, presets
2121

2222
import argparse
2323
import re
@@ -440,9 +440,9 @@ def validate_and_parse_env_args(env_args):
440440
parser.add_argument(
441441
"--preset",
442442
type=str,
443-
choices=[p.name for p in Presets],
443+
choices=[p.name() for p in presets],
444444
help="Benchmark preset to run.",
445-
default="FULL",
445+
default=options.preset.name(),
446446
)
447447

448448
args = parser.parse_args()
@@ -469,7 +469,7 @@ def validate_and_parse_env_args(env_args):
469469
options.current_run_name = args.relative_perf
470470
options.cudnn_directory = args.cudnn_directory
471471
options.cublas_directory = args.cublas_directory
472-
options.preset = Presets[args.preset].value()
472+
options.preset = preset_get_by_name(args.preset)
473473

474474
if args.build_igc and args.compute_runtime is None:
475475
parser.error("--build-igc requires --compute-runtime to be set")

devops/scripts/benchmarks/options.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass, field
22
from enum import Enum
3-
from presets import Preset
3+
from presets import Preset, presets
44

55

66
class Compare(Enum):
@@ -40,7 +40,7 @@ class Options:
4040
compute_runtime_tag: str = "25.05.32567.18"
4141
build_igc: bool = False
4242
current_run_name: str = "This PR"
43-
preset: Preset = None
43+
preset: Preset = presets[0]
4444

4545

4646
options = Options()

devops/scripts/benchmarks/presets.py

Lines changed: 24 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -3,22 +3,23 @@
33
# See LICENSE.TXT
44
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
55

6-
from enum import Enum
7-
6+
from typing import List, Type
87

98
class Preset:
10-
def description(self):
11-
pass
9+
def description(self) -> str:
10+
raise NotImplementedError
1211

13-
def suites(self) -> list[str]:
14-
return []
12+
def name(self) -> str:
13+
return self.__class__.__name__
1514

15+
def suites(self) -> List[str]:
16+
raise NotImplementedError
1617

1718
class Full(Preset):
18-
def description(self):
19+
def description(self) -> str:
1920
return "All available benchmarks."
2021

21-
def suites(self) -> list[str]:
22+
def suites(self) -> List[str]:
2223
return [
2324
"Compute Benchmarks",
2425
"llama.cpp bench",
@@ -27,42 +28,38 @@ def suites(self) -> list[str]:
2728
"UMF",
2829
]
2930

30-
3131
class SYCL(Preset):
32-
def description(self):
32+
def description(self) -> str:
3333
return "All available benchmarks related to SYCL."
3434

35-
def suites(self) -> list[str]:
35+
def suites(self) -> List[str]:
3636
return ["Compute Benchmarks", "llama.cpp bench", "SYCL-Bench", "Velocity Bench"]
3737

38-
3938
class Minimal(Preset):
40-
def description(self):
39+
def description(self) -> str:
4140
return "Short microbenchmarks."
4241

43-
def suites(self) -> list[str]:
42+
def suites(self) -> List[str]:
4443
return ["Compute Benchmarks"]
4544

46-
4745
class Normal(Preset):
48-
def description(self):
46+
def description(self) -> str:
4947
return "Comprehensive mix of microbenchmarks and real applications."
5048

51-
def suites(self) -> list[str]:
52-
return ["Compute Benchmarks"]
53-
49+
def suites(self) -> List[str]:
50+
return ["Compute Benchmarks", "llama.cpp bench", "Velocity Bench"]
5451

5552
class Test(Preset):
56-
def description(self):
53+
def description(self) -> str:
5754
return "Noop benchmarks for framework testing."
5855

59-
def suites(self) -> list[str]:
56+
def suites(self) -> List[str]:
6057
return ["Test Suite"]
6158

59+
presets = [Full(), SYCL(), Minimal(), Normal(), Test()]
6260

63-
class Presets(Enum):
64-
FULL = Full
65-
SYCL = SYCL # Nightly
66-
NORMAL = Normal # PR
67-
MINIMAL = Minimal # Quick smoke tests
68-
TEST = Test
61+
def preset_get_by_name(name: str) -> Preset:
62+
for p in presets:
63+
if p.name().upper() == name.upper():
64+
return p
65+
raise ValueError(f"Preset '{name}' not found.")

0 commit comments

Comments
 (0)