
Commit 2909f0c

Format python files using Black
1 parent 78ba790 commit 2909f0c

3 files changed: 262 additions & 133 deletions
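The hunks below are consistent with Black's default settings: 88-character lines, double-quoted strings, and long calls reflowed onto indented lines (one argument per line, with a trailing comma, when they still do not fit). As an illustrative sketch, not part of the commit, and assuming the `black` package is installed (its documented `format_str` and `Mode` entry points are used), the reflow of one call from reader.py can be reproduced like this:

import black

# One of the calls reformatted in reader.py, as it read before this commit
# (continuation aligned under the opening parenthesis).
before = (
    'arg = _parse_array_file(entry["file"], data_dir, entry.get("file_hash"),\n'
    "                        dtype, validate_checksum)\n"
)

# Black's default Mode reflows the call into the style seen throughout this
# diff: the statement no longer fits in 88 columns, so the arguments move to
# their own indented line and the closing parenthesis gets its own line.
after = black.format_str(before, mode=black.Mode())
print(after)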

python/kernel_launcher/reader.py

Lines changed: 91 additions & 53 deletions
@@ -163,7 +163,9 @@ def _parse_scalar_argument(entry):
     return np.frombuffer(bytes(data), dtype=dtype)[0]
 
 
-def _parse_array_file(file_name: str, data_dir: str, expect_hash: str, dtype, validate: bool):
+def _parse_array_file(
+    file_name: str, data_dir: str, expect_hash: str, dtype, validate: bool
+):
     file_path = os.path.join(data_dir, file_name)
 
     if file_name.endswith(".gz"):
@@ -190,15 +192,21 @@ def _parse_array_argument(entry: dict, data_dir: str, validate_checksum: bool):
     dtype = _type_name_to_dtype(type_name[:-1])
 
     if dtype is None:
-        logger.warning(f"unknown type \"{type_name}\", falling back to byte array")
+        logger.warning(f'unknown type "{type_name}", falling back to byte array')
         dtype = np.byte
 
-    arg = _parse_array_file(entry["file"], data_dir, entry.get("file_hash"),
-                            dtype, validate_checksum)
+    arg = _parse_array_file(
+        entry["file"], data_dir, entry.get("file_hash"), dtype, validate_checksum
+    )
 
     if "reference_file" in entry:
-        answer = _parse_array_file(entry["reference_file"], data_dir, entry.get("reference_hash"),
-                                   dtype, validate_checksum)
+        answer = _parse_array_file(
+            entry["reference_file"],
+            data_dir,
+            entry.get("reference_hash"),
+            dtype,
+            validate_checksum,
+        )
     else:
         answer = None
 
@@ -230,7 +238,15 @@ def __init__(self, obj, data_dir: str, validate_checksum: bool = True):
             self.args.append(arg)
             self.answers.append(answer)
 
-    def _tune_options(self, working_dir=None, lang="cupy", compiler_options=None, defines=None, device=0, **kwargs):
+    def _tune_options(
+        self,
+        working_dir=None,
+        lang="cupy",
+        compiler_options=None,
+        defines=None,
+        device=0,
+        **kwargs,
+    ):
         if working_dir is None:
             working_dir = os.getcwd()
 
@@ -277,20 +293,21 @@ def grid_size(config):
         restrictions = [e.resolve(**context) for e in self.space.restrictions]
 
         options = dict(
-                kernel_name=self.kernel.generate_name(),
-                kernel_source=self.kernel.generate_source(working_dir),
-                arguments=self.args,
-                problem_size=grid_size,
-                restrictions=lambda config: all(f(config) for f in restrictions),
-                defines=all_defines,
-                compiler_options=compiler_options,
-                block_size_names=block_size_names,
-                grid_div_x=[],
-                grid_div_y=[],
-                grid_div_z=[],
-                lang=lang,
-                device=device,
-                **kwargs)
+            kernel_name=self.kernel.generate_name(),
+            kernel_source=self.kernel.generate_source(working_dir),
+            arguments=self.args,
+            problem_size=grid_size,
+            restrictions=lambda config: all(f(config) for f in restrictions),
+            defines=all_defines,
+            compiler_options=compiler_options,
+            block_size_names=block_size_names,
+            grid_div_x=[],
+            grid_div_y=[],
+            grid_div_z=[],
+            lang=lang,
+            device=device,
+            **kwargs,
+        )
 
         os.chdir(working_dir)
         return extra_params, options
@@ -351,16 +368,25 @@ def tune(self, params=None, **kwargs):
         strategy = "brute_force" if total_configs < 100 else "bayes_opt"
 
         return kernel_tuner.tune_kernel(
-                tune_params=params,
-                strategy=strategy,
-                answer=answer,
-                verify=verify,
-                **options)
+            tune_params=params,
+            strategy=strategy,
+            answer=answer,
+            verify=verify,
+            **options,
+        )
 
 
 def _fancy_verify(answers, outputs, *, atol=None):
-    INTEGRAL_DTYPES = [np.int8, np.int16, np.int32, np.int64,
-                       np.uint8, np.uint16, np.uint32, np.uint64]
+    INTEGRAL_DTYPES = [
+        np.int8,
+        np.int16,
+        np.int32,
+        np.int64,
+        np.uint8,
+        np.uint16,
+        np.uint32,
+        np.uint64,
+    ]
     FLOATING_DTYPES = [np.float16, np.float32, np.float64]
     PRINT_TOP_VALUES = 25
     DEFAULT_ATOL = 1e-8
@@ -378,7 +404,9 @@ def _fancy_verify(answers, outputs, *, atol=None):
             continue
 
         if output.dtype != expected.dtype or output.shape != expected.shape:
-            raise RuntimeError(f"arrays data type or shape do not match: {output} and {expected}")
+            raise RuntimeError(
+                f"arrays data type or shape do not match: {output} and {expected}"
+            )
 
         if output.dtype in INTEGRAL_DTYPES:
             matches = output == expected
@@ -401,14 +429,18 @@ def _fancy_verify(answers, outputs, *, atol=None):
         # indices = indices[np.argsort(errors[indices], kind="stable")][::-1]
 
         percentage = nerrors / len(output) * 100
-        print(f"argument {index + 1} fails validation: {nerrors} incorrect values" +
-              f"({percentage:.5}%)")
+        print(
+            f"argument {index + 1} fails validation: {nerrors} incorrect values"
+            + f"({percentage:.5}%)"
+        )
 
         errors = np.abs(output - expected)
 
         for index in indices[:PRINT_TOP_VALUES]:
-            print(f" * at index {index}: {output[index]} != {expected[index]} " +
-                  f"(error: {errors[index]})")
+            print(
+                f" * at index {index}: {output[index]} != {expected[index]} "
+                + f"(error: {errors[index]})"
+            )
 
         if nerrors > PRINT_TOP_VALUES:
             print(f" * ({nerrors - PRINT_TOP_VALUES} more entries have been omitted)")
@@ -583,22 +615,27 @@ def resolve(self, problem_size, **kwargs):
 
 class DeviceAttributeExpr(Expr):
     # Map cuda.h names to cupy names
-    NAME_MAPPING = dict([
-        ('MAX_THREADS_PER_BLOCK', 'MaxThreadsPerBlock'),
-        ('MAX_BLOCK_DIM_X', 'MaxBlockDimX'),
-        ('MAX_BLOCK_DIM_Y', 'MaxBlockDimY'),
-        ('MAX_BLOCK_DIM_Z', 'MaxBlockDimZ'),
-        ('MAX_GRID_DIM_X', 'MaxGridDimX'),
-        ('MAX_GRID_DIM_Y', 'MaxGridDimY'),
-        ('MAX_GRID_DIM_Z', 'MaxGridDimZ'),
-        ('MAX_SHARED_MEMORY_PER_BLOCK', 'MaxSharedMemoryPerBlock'),
-        ('WARP_SIZE', 'WarpSize'),
-        ('MAX_REGISTERS_PER_BLOCK', 'MaxRegistersPerBlock'),
-        ('MULTIPROCESSOR_COUNT', 'MultiProcessorCount'),
-        ('MAX_THREADS_PER_MULTIPROCESSOR', 'MaxThreadsPerMultiProcessor'),
-        ('MAX_SHARED_MEMORY_PER_MULTIPROCESSOR', 'MaxSharedMemoryPerMultiprocessor'),
-        ('MAX_REGISTERS_PER_MULTIPROCESSOR', 'MaxRegistersPerMultiprocessor'),
-    ])
+    NAME_MAPPING = dict(
+        [
+            ("MAX_THREADS_PER_BLOCK", "MaxThreadsPerBlock"),
+            ("MAX_BLOCK_DIM_X", "MaxBlockDimX"),
+            ("MAX_BLOCK_DIM_Y", "MaxBlockDimY"),
+            ("MAX_BLOCK_DIM_Z", "MaxBlockDimZ"),
+            ("MAX_GRID_DIM_X", "MaxGridDimX"),
+            ("MAX_GRID_DIM_Y", "MaxGridDimY"),
+            ("MAX_GRID_DIM_Z", "MaxGridDimZ"),
+            ("MAX_SHARED_MEMORY_PER_BLOCK", "MaxSharedMemoryPerBlock"),
+            ("WARP_SIZE", "WarpSize"),
+            ("MAX_REGISTERS_PER_BLOCK", "MaxRegistersPerBlock"),
+            ("MULTIPROCESSOR_COUNT", "MultiProcessorCount"),
+            ("MAX_THREADS_PER_MULTIPROCESSOR", "MaxThreadsPerMultiProcessor"),
+            (
+                "MAX_SHARED_MEMORY_PER_MULTIPROCESSOR",
+                "MaxSharedMemoryPerMultiprocessor",
+            ),
+            ("MAX_REGISTERS_PER_MULTIPROCESSOR", "MaxRegistersPerMultiprocessor"),
+        ]
+    )
 
     def __init__(self, name):
         self.name = name
@@ -630,8 +667,10 @@ def evaluate(self, config):
         index = self.condition.evaluate(config)
 
         if not is_int_like(index) or index < 0 or index >= len(self.options):
-            raise RuntimeError("expression must yield an integer in " +
-                               f"range 0..{len(self.options)}: {self}")
+            raise RuntimeError(
+                "expression must yield an integer in "
+                + f"range 0..{len(self.options)}: {self}"
+            )
 
         return self.options[int(index)].evaluate(config)
 
@@ -641,8 +680,7 @@ def visit_children(self, fun):
 
 def _parse_expr(entry) -> Expr:
     # literal int, str or float becomes ValueExpr.
-    if isinstance(entry, (int, str, float)) or \
-            entry is None:
+    if isinstance(entry, (int, str, float)) or entry is None:
         return ValueExpr(entry)
 
     # Otherwise it must be an operator expression
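As an aside on the _fancy_verify hunks above: the diff shows the exact-equality branch for integer dtypes and a DEFAULT_ATOL of 1e-8, but the floating-point branch itself is not visible here. A stand-alone sketch of that comparison rule, where the absolute-tolerance check is an assumption rather than this repository's actual code:

import numpy as np


def matches_reference(output, expected, atol=1e-8):
    # Mirrors the checks visible in the diff: dtypes and shapes must agree.
    if output.dtype != expected.dtype or output.shape != expected.shape:
        raise RuntimeError("arrays data type or shape do not match")
    # Integer data must match exactly; floating-point data is compared with an
    # absolute tolerance (assumed here; the actual branch is not shown above).
    if np.issubdtype(output.dtype, np.integer):
        return output == expected
    return np.isclose(output, expected, rtol=0.0, atol=atol)


out = np.array([1.0, 2.0, 3.0 + 5e-9])
ref = np.array([1.0, 2.0, 3.0])
print(matches_reference(out, ref))  # all True: differences are below atol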

python/kernel_launcher/wisdom.py

Lines changed: 57 additions & 25 deletions
@@ -17,21 +17,39 @@
 WISDOM_OBJECTIVE = "time"
 
 
-def write_wisdom_for_problem(path: str, problem: TuningProblem, results: list,
-                             env: dict, **kwargs):
-    """Write the results of ``TuningProblem.tune`` to a wisdom file. This function calls ``write_wisdom``
+def write_wisdom_for_problem(
+    path: str, problem: TuningProblem, results: list, env: dict, **kwargs
+):
+    """Write the results of ``TuningProblem.tune`` to a wisdom file. This function just wraps ``write_wisdom``
 
     :param path: Directory were wisdom files are stored. Alternatively, this can be a file name ending with ``.wisdom``.
     :param problem: The ``TuningProblem``.
     :param results: The results returned by ``kernel_tuner.tune_kernel``.
     :param env: The environment returned by ``kernel_tuner.tune_kernel``.
     :param kwargs: Additional keyword arguments passed to `write_wisdom`.
     """
-    return write_wisdom(path, problem.key, problem.space.params, problem.problem_size, results, env, **kwargs)
-
-
-def write_wisdom(path: str, key: str, params: dict, problem_size: list, results: list, env: dict, *,
-                 max_results=5, merge_existing_results=False):
+    return write_wisdom(
+        path,
+        problem.key,
+        problem.space.params,
+        problem.problem_size,
+        results,
+        env,
+        **kwargs,
+    )
+
+
+def write_wisdom(
+    path: str,
+    key: str,
+    params: dict,
+    problem_size: list,
+    results: list,
+    env: dict,
+    *,
+    max_results=5,
+    merge_existing_results=False,
+):
     """Write the results of ``kernel_tuner.tune_kernel`` to a wisdom file.
 
     :param path: Directory were wisdom files are stored. Alternatively, this can be a file name ending with ``.wisdom``.
@@ -42,7 +60,8 @@ def write_wisdom(path: str, key: str, params: dict, problem_size: list, results:
     :param env: The environment returned by ``kernel_tuner.tune_kernel``
     :param max_results: Only the top ``max_results`` results are written in the wisdom file.
     :param merge_existing_results: If ``True``, existing results in the wisdom file for the same problem size and
-        environment are merged with the provided ``results``.
+        environment are merged with the provided ``results``. If ``False``, existing
+        results for the same problem size and environment are overwritten.
 
     """
     device_name = env["device_name"]
@@ -60,9 +79,11 @@ def write_wisdom(path: str, key: str, params: dict, problem_size: list, results:
 
         for line, record in _parse_wisdom(handle):
             # Skip lines that have a matching problem_size and device_name
-            if not record or \
-                    record["problem_size"] != problem_size or \
-                    record["environment"].get("device_name") != device_name:
+            if (
+                not record
+                or record["problem_size"] != problem_size
+                or record["environment"].get("device_name") != device_name
+            ):
                 lines.append(line)
             elif merge_existing_results:
                 index = tuple(record["config"])
@@ -105,7 +126,9 @@ def write_wisdom(path: str, key: str, params: dict, problem_size: list, results:
         handle.writelines(line + "\n" for line in lines)
 
 
-def read_wisdom(path: str, key: str = None, params: dict = None, *, error_if_missing: bool = True) -> list:
+def read_wisdom(
+    path: str, key: str = None, params: dict = None, *, error_if_missing: bool = True
+) -> list:
     """
     Read the results of a wisdom file.
 
@@ -168,26 +191,31 @@ def _check_header(line: str, key: str, params: dict):
     if key is not None:
         if data.get("key") != key:
             print(data)
-            raise RuntimeError(f"invalid key in wisdom file: {key} != {data.get('key')}")
+            raise RuntimeError(
+                f"invalid key in wisdom file: {key} != {data.get('key')}"
+            )
 
     keys = data.get("tunable_parameters", [])
     if params is not None:
         if set(keys) != set(params.keys()):
-            raise RuntimeError("invalid tunable parameters in wisdom file: " +
-                               f"{list(params.keys())} != {keys}")
+            raise RuntimeError(
+                "invalid tunable parameters in wisdom file: "
+                + f"{list(params.keys())} != {keys}"
+            )
 
     return keys
 
 
 def _create_header(key: str, param_keys: list) -> str:
-    """Header of wisdom file (ie, the first line).
-    """
-    return json.dumps({
-        "version": WISDOM_VERSION,
-        "objective": WISDOM_OBJECTIVE,
-        "tunable_parameters": list(param_keys),
-        "key": key,
-    })
+    """Header of wisdom file (ie, the first line)."""
+    return json.dumps(
+        {
+            "version": WISDOM_VERSION,
+            "objective": WISDOM_OBJECTIVE,
+            "tunable_parameters": list(param_keys),
+            "key": key,
+        }
+    )
 
 
 def _is_valid_config(config):
@@ -230,7 +258,9 @@ def _wisdom_file(path, key):
         filename = re.sub("[^0-9a-zA-Z_.-]", "_", key) + ".wisdom"
         return os.path.join(path, filename)
     else:
-        raise ValueError(f"path must be a directory or a file ending with .wisdom: {path}")
+        raise ValueError(
+            f"path must be a directory or a file ending with .wisdom: {path}"
+        )
 
 
 def _build_environment(env):
@@ -243,13 +273,15 @@ def _build_environment(env):
     # Kernel tuner related
     try:
         import kernel_tuner
+
         env["kernel_tuner_version"] = kernel_tuner.__version__
     except AttributeError as e:
         logger.warning(f"ignore error: kernel_tuner.__version__ is not available: {e}")
 
     # CUDA related
     try:
         import pycuda.driver
+
         env["cuda_driver_version"] = pycuda.driver.get_driver_version()
 
         major, minor, patch = pycuda.driver.get_version()
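For orientation, the wisdom functions reformatted above form a small write/read API around kernel_tuner results: write_wisdom_for_problem is a thin wrapper that forwards a TuningProblem's key, parameters, and problem size to write_wisdom, and read_wisdom loads the stored configurations back. A hypothetical usage sketch of the lower-level pair, with made-up stand-ins for the results and environment that kernel_tuner.tune_kernel would normally return; the import path is assumed from the file layout above, and only the fields the diff shows being read, such as env["device_name"], are filled in:

from kernel_launcher.wisdom import read_wisdom, write_wisdom  # assumed import path

# Stand-ins for kernel_tuner.tune_kernel output: a list of per-configuration
# records (tunable parameters plus the "time" objective) and an environment dict.
params = {"block_size_x": [128, 256], "block_size_y": [1, 2]}
results = [
    {"block_size_x": 128, "block_size_y": 2, "time": 0.42},
    {"block_size_x": 256, "block_size_y": 1, "time": 0.39},
]
env = {"device_name": "NVIDIA A100"}  # hypothetical device

write_wisdom(
    "wisdom/",                     # existing directory, or a file ending with .wisdom
    key="vector_add",              # hypothetical problem key
    params=params,
    problem_size=[1024, 1024],
    results=results,
    env=env,
    max_results=5,                 # keep only the top configurations
    merge_existing_results=True,   # merge with earlier runs for this size and device
)

# Later: read the stored configurations back for the same key and parameters.
configs = read_wisdom("wisdom/", key="vector_add", params=params)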
