Skip to content

Commit 5316d96

Browse files
committed
infra and config changes for controlling dtype precision
1 parent 69ff7d9 commit 5316d96

File tree

7 files changed

+223
-29
lines changed

7 files changed

+223
-29
lines changed

dpbench/config/benchmark.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ class Init:
1919
func_name: str = ""
2020
package_path: str = ""
2121
module_name: str = ""
22+
types_dict_name: str = ""
23+
precision: str = ""
2224
input_args: List[str] = field(default_factory=list)
2325
output_args: List[str] = field(default_factory=list)
2426

@@ -28,10 +30,18 @@ def from_dict(obj: Any) -> "Init":
2830
_func_name = str(obj.get("func_name") or "")
2931
_package_path = str(obj.get("package_path") or "")
3032
_module_name = str(obj.get("module_name") or "")
33+
_types_dict_name = str(obj.get("types_dict_name") or "")
34+
_precision = str(obj.get("precision") or "")
3135
_input_args = obj.get("input_args")
3236
_output_args = obj.get("output_args")
3337
return Init(
34-
_func_name, _package_path, _module_name, _input_args, _output_args
38+
_func_name,
39+
_package_path,
40+
_module_name,
41+
_types_dict_name,
42+
_precision,
43+
_input_args,
44+
_output_args,
3545
)
3646

3747
def __post_init__(self):

dpbench/config/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,4 @@ class Config:
1818
frameworks: list[Framework]
1919
benchmarks: list[Benchmark]
2020
implementations: list[Implementation]
21+
dtypes: dict[str, dict[str, str]]

dpbench/config/module.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class Module:
1515
benchmark_configs_recursive: bool = False
1616
framework_configs_path: str = ""
1717
impl_postfix_path: str = ""
18+
precision_dtypes_path: str = ""
1819

1920
benchmarks_module: str = ""
2021
path: str = ""

dpbench/config/reader.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def read_configs(
3434
Returns:
3535
Configuration object with populated configurations.
3636
"""
37-
config: Config = Config([], [], [])
37+
config: Config = Config([], [], [], None)
3838

3939
dirname: str = os.path.dirname(__file__)
4040

@@ -51,6 +51,9 @@ def read_configs(
5151
impl_postfix_path=os.path.join(
5252
dirname, "../configs/impl_postfix.toml"
5353
),
54+
precision_dtypes_path=os.path.join(
55+
dirname, "../configs/precision_dtypes.toml"
56+
),
5457
),
5558
]
5659

@@ -96,6 +99,8 @@ def read_configs(
9699
read_frameworks(config, mod.framework_configs_path)
97100
if mod.impl_postfix_path != "":
98101
read_implementation_postfixes(config, mod.impl_postfix_path)
102+
if mod.precision_dtypes_path != "":
103+
read_precision_dtypes(config, mod.precision_dtypes_path)
99104
if mod.path != "":
100105
sys.path.append(mod.path)
101106

@@ -127,8 +132,6 @@ def read_benchmarks(
127132
recursive: Either to load configs recursively.
128133
parent_package: Package that contains benchmark packages.
129134
benchmarks: list of benchmarks to load. None means all.
130-
131-
Returns: nothing.
132135
"""
133136
for bench_info_file in os.listdir(bench_info_dir):
134137
bench_info_file_path = os.path.join(bench_info_dir, bench_info_file)
@@ -172,8 +175,6 @@ def read_frameworks(config: Config, framework_info_dir: str) -> None:
172175
Args:
173176
config: Configuration object where settings should be populated.
174177
framework_info_dir: Path to the directory with configuration files.
175-
176-
Returns: nothing.
177178
"""
178179
for framework_info_file in os.listdir(framework_info_dir):
179180
if not framework_info_file.endswith(".toml"):
@@ -201,8 +202,6 @@ def read_implementation_postfixes(
201202
Args:
202203
config: Configuration object where settings should be populated.
203204
impl_postfix_file: Path to the configuration file.
204-
205-
Returns: nothing.
206205
"""
207206
with open(impl_postfix_file) as file:
208207
file_contents = file.read()
@@ -213,15 +212,26 @@ def read_implementation_postfixes(
213212
config.implementations.append(implementation)
214213

215214

215+
def read_precision_dtypes(config: Config, precision_dtypes_file: str) -> None:
216+
"""Read and populate dtype_obj data types file.
217+
218+
Args:
219+
config: Configuration object where settings should be populated.
220+
precision_dtypes_file: Path to the configuration file.
221+
"""
222+
with open(precision_dtypes_file) as file:
223+
file_contents = file.read()
224+
225+
config.dtypes = tomli.loads(file_contents)
226+
227+
216228
def setup_init(config: Benchmark, modules: list[str]) -> None:
217229
"""Read and discover initialization module and function.
218230
219231
Args:
220232
config: Benchmark configuration object where settings should be
221233
populated.
222234
modules: List of available modules for the benchmark to find init.
223-
224-
Returns: nothing.
225235
"""
226236
if config.init is None:
227237
return
@@ -292,8 +302,6 @@ def read_benchmark_implementations(
292302
available implementations. It does not affect initialization import.
293303
implementations: Prepopulated list of implementations.
294304
295-
Returns: nothing.
296-
297305
Raises:
298306
RuntimeError: Implementation file does not match any known postfix.
299307
"""

0 commit comments

Comments
 (0)