Skip to content

Commit 9ce7cd2

Browse files
committed
Added tests for time budget
1 parent c1b447c commit 9ce7cd2

File tree

2 files changed

+84
-7
lines changed

2 files changed

+84
-7
lines changed

test/test_compiler_functions.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,21 @@
1-
from datetime import datetime
21

3-
import numpy as np
42
import ctypes as C
3+
4+
import numpy as np
55
import pytest
66
from pytest import raises
77

88
try:
9-
from mock import patch, Mock
9+
from mock import Mock, patch
1010
except ImportError:
11-
from unittest.mock import patch, Mock
11+
from unittest.mock import Mock, patch
1212

1313
import kernel_tuner
14-
from kernel_tuner.backends.compiler import CompilerFunctions, Argument, is_cupy_array, get_array_module
15-
from kernel_tuner.core import KernelSource, KernelInstance
1614
from kernel_tuner import util
15+
from kernel_tuner.backends.compiler import Argument, CompilerFunctions, get_array_module, is_cupy_array
16+
from kernel_tuner.core import KernelInstance, KernelSource
1717

18-
from .context import skip_if_no_gfortran, skip_if_no_gcc, skip_if_no_openmp, skip_if_no_cupy
18+
from .context import skip_if_no_cupy, skip_if_no_gcc, skip_if_no_gfortran, skip_if_no_openmp
1919
from .test_runners import env as cuda_env # noqa: F401
2020

2121

test/test_time_budgets.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from itertools import product
2+
from time import perf_counter
3+
4+
import numpy as np
5+
import pytest
6+
from pytest import raises
7+
8+
from kernel_tuner import tune_kernel
9+
10+
from .context import skip_if_no_gcc
11+
12+
13+
@pytest.fixture
14+
def env():
15+
kernel_name = "vector_add"
16+
kernel_string = """
17+
#include <time.h>
18+
19+
float vector_add(float *c, float *a, float *b, int n) {
20+
struct timespec start, end;
21+
clock_gettime(CLOCK_MONOTONIC, &start);
22+
23+
for (int i = 0; i < n; i++) {
24+
c[i] = a[i] + b[i];
25+
}
26+
27+
clock_gettime(CLOCK_MONOTONIC, &end);
28+
double elapsed = (end.tv_sec - start.tv_sec) * 1e3 + (end.tv_nsec - start.tv_nsec) / 1e6;
29+
return (float) elapsed;
30+
}"""
31+
32+
size = 100
33+
a = np.random.randn(size).astype(np.float32)
34+
b = np.random.randn(size).astype(np.float32)
35+
c = np.zeros_like(b)
36+
n = np.int32(size)
37+
38+
args = [c, a, b, n]
39+
tune_params = {"nthreads": [1, 2, 4]}
40+
41+
return kernel_name, kernel_string, size, args, tune_params
42+
43+
44+
@skip_if_no_gcc
45+
def test_no_time_budget(env):
46+
"""Ensure that a RuntimeError is raised if the startup takes longer than the time budget."""
47+
with raises(RuntimeError, match='startup time of the tuning process'):
48+
tune_kernel(*env, strategy="random_sample", strategy_options={"strategy": "random_sample", "time_limit": 0.0})
49+
50+
@skip_if_no_gcc
51+
def test_some_time_budget(env):
52+
"""Ensure that the time limit is respected."""
53+
time_limit = 1.0
54+
kernel_name, kernel_string, size, args, tune_params = env
55+
tune_params["bogus"] = list(range(1000))
56+
env = kernel_name, kernel_string, size, args, tune_params
57+
58+
# Ensure that if the tuning takes longer than the time budget, the results are returned early.
59+
start_time = perf_counter()
60+
res, _ = tune_kernel(*env, strategy="random_sample", strategy_options={"time_limit": time_limit})
61+
62+
# Ensure that there are at least some results, but not all.
63+
size_all = len(list(product(*tune_params.values())))
64+
assert 0 < len(res) < size_all
65+
66+
# Ensure that the time limit was respected by some margin.
67+
assert perf_counter() - start_time < time_limit * 2
68+
69+
@skip_if_no_gcc
70+
def test_full_time_budget(env):
71+
"""Ensure that given ample time budget, the entire space is explored."""
72+
res, _ = tune_kernel(*env, strategy="brute_force", strategy_options={"time_limit": 10.0})
73+
74+
# Ensure that the entire space is explored.
75+
tune_params = env[-1]
76+
size_all = len(list(product(*tune_params.values())))
77+
assert len(res) == size_all

0 commit comments

Comments
 (0)