Skip to content

Commit ec91015

Browse files
authored
Increased Tuner Safety (CNugteren#616)
* Added safety to ensure the user does not supply too little threads from the --threads parameter. * Adjusted default arg type based on the change in utilities/utilities.hpp * Fixed compilation errors
1 parent 366cff2 commit ec91015

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

src/tuning/tuning.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ void Tuner(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDef
220220
args.device_id =
221221
GetArgument(command_line_args, help, kArgDevice, ConvertArgument(std::getenv("CLBLAST_DEVICE"), size_t{0}));
222222
args.precision = GetArgument(command_line_args, help, kArgPrecision, Precision::kSingle);
223-
args.extra_threads = GetArgument(command_line_args, help, kArgNumThreads, size_t{1}) - 1;
223+
args.extra_threads = GetArgument(command_line_args, help, kArgNumThreads, 1) - 1;
224224
for (auto& o : defaults.options) {
225225
if (o == kArgM) {
226226
args.m = GetArgument(command_line_args, help, kArgM, defaults.default_m);
@@ -272,6 +272,12 @@ void Tuner(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDef
272272
const TunerSettings settings = GetTunerSettings(V, args);
273273

274274
// Tests validity of the given arguments
275+
if (args.extra_threads < 0) {
276+
printf("Tuners cannot run without threads or negative threads. Provided %lli threads. Exiting...\n",
277+
args.extra_threads);
278+
return;
279+
}
280+
275281
TestValidArguments(V, args);
276282

277283
// Initializes OpenCL
@@ -401,7 +407,7 @@ void Tuner(int argc, char* argv[], const int V, GetTunerDefaultsFunc GetTunerDef
401407
std::vector<ThreadInfo> thread_infos(configurations.size());
402408
std::vector<std::thread> threads;
403409
threads.reserve(args.extra_threads);
404-
for (size_t i = 0; i < std::min(size_t{args.extra_threads}, configurations.size()); ++i) {
410+
for (size_t i = 0; i < std::min(static_cast<size_t>(args.extra_threads), configurations.size()); ++i) {
405411
threads.push_back(std::thread(&kernelCompilationThread<T>, std::ref(thread_infos), std::cref(configurations), i,
406412
std::cref(settings), std::cref(args), std::cref(device), std::cref(context),
407413
args.extra_threads));

src/utilities/utilities.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ struct Arguments {
266266
// Common arguments
267267
size_t platform_id = 0;
268268
size_t device_id = 0;
269-
size_t extra_threads = 1;
269+
long long extra_threads = 1;
270270
Precision precision = Precision::kSingle;
271271
bool print_help = false;
272272
bool silent = false;

0 commit comments

Comments
 (0)