Skip to content

Commit 2c6a397

Browse files
authored
[SYCL] Add validation + exception handling to SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS (#19381)
Bad / nonsensical values in `SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS` result in division by 0 crashes. This PR: - Adds guards to check that range rounding factors are set to sensical values - Adds test to ensure only valid range rounding factors can be set
1 parent 92a5652 commit 2c6a397

File tree

2 files changed

+107
-12
lines changed

2 files changed

+107
-12
lines changed

sycl/source/detail/config.hpp

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -199,35 +199,64 @@ template <> class SYCLConfig<SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS> {
199199
private:
200200
public:
201201
static void GetSettings(size_t &MinFactor, size_t &GoodFactor,
202-
size_t &MinRange) {
203-
static const char *RoundParams = BaseT::getRawValue();
202+
size_t &MinRange, bool ForceUpdate = false) {
203+
const char *RoundParams = BaseT::getRawValue();
204204
if (RoundParams == nullptr)
205205
return;
206206

207207
static bool ProcessedFactors = false;
208+
static bool FactorsAreValid = false;
208209
static size_t MF;
209210
static size_t GF;
210211
static size_t MR;
211-
if (!ProcessedFactors) {
212+
if (!ProcessedFactors || ForceUpdate) {
213+
auto GuardedStoi = [](size_t &val, const std::string &str) {
214+
try {
215+
int ParsedResult = std::stoi(str);
216+
if (ParsedResult < 0)
217+
return false;
218+
val = ParsedResult;
219+
return true;
220+
// Ignore parsing exceptions, but throw on unexpected exceptions:
221+
} catch (const std::invalid_argument &) {
222+
} catch (const std::out_of_range &) {
223+
}
224+
return false;
225+
};
226+
212227
// Parse optional parameters of this form (all values required):
213228
// MinRound:PreferredRound:MinRange
214229
std::string Params(RoundParams);
215230
size_t Pos = Params.find(':');
216-
if (Pos != std::string::npos) {
217-
MF = std::stoi(Params.substr(0, Pos));
231+
if (Pos != std::string::npos && GuardedStoi(MF, Params.substr(0, Pos)) &&
232+
MF > 0) {
218233
Params.erase(0, Pos + 1);
219234
Pos = Params.find(':');
220-
if (Pos != std::string::npos) {
221-
GF = std::stoi(Params.substr(0, Pos));
235+
if (Pos != std::string::npos &&
236+
GuardedStoi(GF, Params.substr(0, Pos)) && GF > 0) {
222237
Params.erase(0, Pos + 1);
223-
MR = std::stoi(Params);
238+
// Factors are valid only if all parsed successfully:
239+
FactorsAreValid = GuardedStoi(MR, Params);
240+
// Note that MinRange = 0 is considered valid.
224241
}
225242
}
226-
ProcessedFactors = true;
243+
if (FactorsAreValid) {
244+
ProcessedFactors = true;
245+
} else {
246+
std::cerr
247+
<< "WARNING: Invalid value passed for "
248+
<< "SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS (Expected format "
249+
<< "MinRound:PreferredRound:MinRange, where MinRound, "
250+
"PreferredRound"
251+
<< " > 0, MinRange >= 0). Provided parameters will be ignored."
252+
<< std::endl;
253+
}
254+
}
255+
if (FactorsAreValid) {
256+
MinFactor = MF;
257+
GoodFactor = GF;
258+
MinRange = MR;
227259
}
228-
MinFactor = MF;
229-
GoodFactor = GF;
230-
MinRange = MR;
231260
}
232261
};
233262

sycl/unittests/config/ConfigTests.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,3 +449,69 @@ TEST(ConfigTests, CheckPersistentCacheEvictionThresholdTest) {
449449
OnDiskEvicType::reset();
450450
TestConfig(0);
451451
}
452+
453+
// SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS accepts ...
454+
TEST(ConfigTests, CheckParallelForRangeRoundingParams) {
455+
456+
// Lambda to set SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS.
457+
auto SetRoundingParams = [](const char *value) {
458+
#ifdef _WIN32
459+
_putenv_s("SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS", value);
460+
#else
461+
setenv("SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS", value, 1);
462+
#endif
463+
sycl::detail::readConfig(true);
464+
};
465+
466+
// Lambda to assert test parameters are as expected.
467+
auto AssertRoundingParams = [](size_t MF, size_t GF, size_t MR,
468+
const char *errMsg, bool ForceUpdate = false) {
469+
size_t ResultMF = 0, ResultGF = 0, ResultMR = 0;
470+
SYCLConfig<SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS>::GetSettings(
471+
ResultMF, ResultGF, ResultMR, ForceUpdate);
472+
EXPECT_EQ(MF, ResultMF) << errMsg;
473+
EXPECT_EQ(GF, ResultGF) << errMsg;
474+
EXPECT_EQ(MR, ResultMR) << errMsg;
475+
};
476+
477+
// Lambda to test invalid input -- factors should remain unchanged.
478+
auto TestBadInput = [&](const char *value, const char *errMsg) {
479+
// Original factor values are stored as its own variable as size of size_t
480+
// varies depending on system and architecture:
481+
constexpr size_t MF = 1, GF = 2, MR = 3;
482+
size_t TestMF = MF, TestGF = GF, TestMR = MR;
483+
SetRoundingParams(value);
484+
SYCLConfig<SYCL_PARALLEL_FOR_RANGE_ROUNDING_PARAMS>::GetSettings(
485+
TestMF, TestGF, TestMR, true);
486+
EXPECT_EQ(TestMF, MF) << errMsg;
487+
EXPECT_EQ(TestGF, GF) << errMsg;
488+
EXPECT_EQ(TestMR, MR) << errMsg;
489+
};
490+
491+
// Test malformed input:
492+
constexpr char MalformedErr[] =
493+
"Rounding parameters should be ignored on malformed input";
494+
TestBadInput("abc", MalformedErr);
495+
TestBadInput("42", MalformedErr);
496+
TestBadInput(":7", MalformedErr);
497+
TestBadInput("7:", MalformedErr);
498+
TestBadInput("1:2", MalformedErr);
499+
TestBadInput("1:2:", MalformedErr);
500+
TestBadInput("1:abc:3", MalformedErr);
501+
502+
// Test well-formed input, but bad parameters:
503+
constexpr char BadParamsErr[] = "Rounding parameters should be ignored if "
504+
"parameters provided are invalid";
505+
TestBadInput("0:1:2", BadParamsErr);
506+
TestBadInput("1:0:2", BadParamsErr);
507+
TestBadInput("-1:2:3", BadParamsErr);
508+
TestBadInput("1:2:31415926535897932384626433832795028841971", BadParamsErr);
509+
510+
// Test valid values.
511+
SetRoundingParams("8:16:32");
512+
AssertRoundingParams(8, 16, 32,
513+
"Failed to read rounding parameters properly");
514+
SetRoundingParams("8:16:0");
515+
AssertRoundingParams(8, 16, 0, "0 is a valid value for MinRange",
516+
/*ForceUpdate =*/true);
517+
}

0 commit comments

Comments
 (0)