Skip to content

Commit 731ea1d

Browse files
authored
[flang][runtime] Handle NAN(...) in namelist input (llvm#153101) (llvm#4354)
2 parents 8264976 + 1a1bc80 commit 731ea1d

File tree

5 files changed

+97
-38
lines changed

5 files changed

+97
-38
lines changed

flang-rt/include/flang-rt/runtime/io-stmt.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,9 @@ template <>
444444
class ListDirectedStatementState<Direction::Input>
445445
: public FormattedIoStatementState<Direction::Input> {
446446
public:
447-
RT_API_ATTRS bool inNamelistSequence() const { return inNamelistSequence_; }
447+
RT_API_ATTRS const NamelistGroup *namelistGroup() const {
448+
return namelistGroup_;
449+
}
448450
RT_API_ATTRS int EndIoStatement();
449451

450452
// Skips value separators, handles repetition and null values.
@@ -457,14 +459,15 @@ class ListDirectedStatementState<Direction::Input>
457459
// input statement. This member function resets some state so that
458460
// repetition and null values work correctly for each successive
459461
// NAMELIST input item.
460-
RT_API_ATTRS void ResetForNextNamelistItem(bool inNamelistSequence) {
462+
RT_API_ATTRS void ResetForNextNamelistItem(
463+
const NamelistGroup *namelistGroup) {
461464
remaining_ = 0;
462465
if (repeatPosition_) {
463466
repeatPosition_->Cancel();
464467
}
465468
eatComma_ = false;
466469
realPart_ = imaginaryPart_ = false;
467-
inNamelistSequence_ = inNamelistSequence;
470+
namelistGroup_ = namelistGroup;
468471
}
469472

470473
RT_API_ATTRS bool eatComma() const { return eatComma_; }
@@ -473,7 +476,7 @@ class ListDirectedStatementState<Direction::Input>
473476
RT_API_ATTRS void set_hitSlash(bool yes) { hitSlash_ = yes; }
474477

475478
protected:
476-
bool inNamelistSequence_{false};
479+
const NamelistGroup *namelistGroup_{nullptr};
477480

478481
private:
479482
int remaining_{0}; // for "r*" repetition

flang-rt/lib/runtime/edit-input.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ static RT_API_ATTRS ScannedRealInput ScanRealInput(
538538
next = io.NextInField(remaining, edit);
539539
}
540540
if (!next || *next == ')') { // NextInField fails on separators like ')'
541-
std::size_t byteCount{0};
541+
std::size_t byteCount{1};
542542
if (!next) {
543543
next = io.GetCurrentChar(byteCount);
544544
}

flang-rt/lib/runtime/io-stmt.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,7 @@ ChildListIoStatementState<DIR>::ChildListIoStatementState(
11121112
if constexpr (DIR == Direction::Input) {
11131113
if (const auto *listInput{child.parent()
11141114
.get_if<ListDirectedStatementState<Direction::Input>>()}) {
1115-
this->inNamelistSequence_ = listInput->inNamelistSequence();
1115+
this->namelistGroup_ = listInput->namelistGroup();
11161116
this->set_eatComma(listInput->eatComma());
11171117
if (auto *childListInput{child.parent()
11181118
.get_if<ChildListIoStatementState<Direction::Input>>()}) {

flang-rt/lib/runtime/namelist.cpp

Lines changed: 57 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,7 @@ bool IODEF(OutputNamelist)(Cookie cookie, const NamelistGroup &group) {
4444
if ((connection.NeedAdvance(prefixLen) &&
4545
!(io.AdvanceRecord() && EmitAscii(io, " ", 1))) ||
4646
!EmitAscii(io, prefix, prefixLen) ||
47-
(connection.NeedAdvance(
48-
Fortran::runtime::strlen(str) + (suffix != ' ')) &&
47+
(connection.NeedAdvance(runtime::strlen(str) + (suffix != ' ')) &&
4948
!(io.AdvanceRecord() && EmitAscii(io, " ", 1)))) {
5049
return false;
5150
}
@@ -102,8 +101,8 @@ static constexpr RT_API_ATTRS char NormalizeIdChar(char32_t ch) {
102101
return static_cast<char>(ch >= 'A' && ch <= 'Z' ? ch - 'A' + 'a' : ch);
103102
}
104103

105-
static RT_API_ATTRS bool GetLowerCaseName(
106-
IoStatementState &io, char buffer[], std::size_t maxLength) {
104+
static RT_API_ATTRS bool GetLowerCaseName(IoStatementState &io, char buffer[],
105+
std::size_t maxLength, bool crashIfTooLong = true) {
107106
std::size_t byteLength{0};
108107
if (auto ch{io.GetNextNonBlank(byteLength)}) {
109108
if (IsLegalIdStart(*ch)) {
@@ -117,8 +116,10 @@ static RT_API_ATTRS bool GetLowerCaseName(
117116
if (j <= maxLength) {
118117
return true;
119118
}
120-
io.GetIoErrorHandler().SignalError(
121-
"Identifier '%s...' in NAMELIST input group is too long", buffer);
119+
if (crashIfTooLong) {
120+
io.GetIoErrorHandler().SignalError(
121+
"Identifier '%s...' in NAMELIST input group is too long", buffer);
122+
}
122123
}
123124
}
124125
return false;
@@ -383,9 +384,8 @@ static RT_API_ATTRS bool HandleComponent(IoStatementState &io, Descriptor &desc,
383384
const DescriptorAddendum *addendum{source.Addendum()};
384385
if (const typeInfo::DerivedType *
385386
type{addendum ? addendum->derivedType() : nullptr}) {
386-
if (const typeInfo::Component *
387-
comp{type->FindDataComponent(
388-
compName, Fortran::runtime::strlen(compName))}) {
387+
if (const typeInfo::Component *comp{
388+
type->FindDataComponent(compName, runtime::strlen(compName))}) {
389389
bool createdDesc{false};
390390
if (comp->rank() > 0 && source.rank() > 0) {
391391
// If base and component are both arrays, the component name
@@ -510,7 +510,7 @@ bool IODEF(InputNamelist)(Cookie cookie, const NamelistGroup &group) {
510510
handler.SignalError("NAMELIST input group has no name");
511511
return false;
512512
}
513-
if (Fortran::runtime::strcmp(group.groupName, name) == 0) {
513+
if (runtime::strcmp(group.groupName, name) == 0) {
514514
break; // found it
515515
}
516516
SkipNamelistGroup(io);
@@ -529,7 +529,7 @@ bool IODEF(InputNamelist)(Cookie cookie, const NamelistGroup &group) {
529529
}
530530
std::size_t itemIndex{0};
531531
for (; itemIndex < group.items; ++itemIndex) {
532-
if (Fortran::runtime::strcmp(name, group.item[itemIndex].name) == 0) {
532+
if (runtime::strcmp(name, group.item[itemIndex].name) == 0) {
533533
break;
534534
}
535535
}
@@ -604,13 +604,14 @@ bool IODEF(InputNamelist)(Cookie cookie, const NamelistGroup &group) {
604604
if (const auto *addendum{useDescriptor->Addendum()};
605605
addendum && addendum->derivedType()) {
606606
const NonTbpDefinedIoTable *table{group.nonTbpDefinedIo};
607-
listInput->ResetForNextNamelistItem(/*inNamelistSequence=*/true);
607+
listInput->ResetForNextNamelistItem(&group);
608608
if (!IONAME(InputDerivedType)(cookie, *useDescriptor, table) &&
609609
handler.InError()) {
610610
return false;
611611
}
612612
} else {
613-
listInput->ResetForNextNamelistItem(useDescriptor->rank() > 0);
613+
listInput->ResetForNextNamelistItem(
614+
useDescriptor->rank() > 0 ? &group : nullptr);
614615
if (!descr::DescriptorIO<Direction::Input>(io, *useDescriptor) &&
615616
handler.InError()) {
616617
return false;
@@ -640,27 +641,51 @@ bool IODEF(InputNamelist)(Cookie cookie, const NamelistGroup &group) {
640641
}
641642

642643
RT_API_ATTRS bool IsNamelistNameOrSlash(IoStatementState &io) {
643-
if (auto *listInput{
644-
io.get_if<ListDirectedStatementState<Direction::Input>>()}) {
645-
if (listInput->inNamelistSequence()) {
646-
SavedPosition savedPosition{io};
647-
std::size_t byteCount{0};
648-
if (auto ch{io.GetNextNonBlank(byteCount)}) {
649-
if (IsLegalIdStart(*ch)) {
650-
do {
651-
io.HandleRelativePosition(byteCount);
652-
ch = io.GetCurrentChar(byteCount);
653-
} while (ch && IsLegalIdChar(*ch));
654-
ch = io.GetNextNonBlank(byteCount);
655-
// TODO: how to deal with NaN(...) ambiguity?
656-
return ch && (*ch == '=' || *ch == '(' || *ch == '%');
657-
} else {
658-
return *ch == '/' || *ch == '&' || *ch == '$';
659-
}
660-
}
644+
auto *listInput{io.get_if<ListDirectedStatementState<Direction::Input>>()};
645+
if (!listInput || !listInput->namelistGroup()) {
646+
return false; // not namelist
647+
}
648+
SavedPosition savedPosition{io};
649+
std::size_t byteCount{0};
650+
auto ch{io.GetNextNonBlank(byteCount)};
651+
if (!ch) {
652+
return false;
653+
} else if (!IsLegalIdStart(*ch)) {
654+
return *ch == '/' || *ch == '&' || *ch == '$';
655+
}
656+
char id[nameBufferSize];
657+
if (!GetLowerCaseName(io, id, sizeof id, /*crashIfTooLong=*/false)) {
658+
return true; // long name
659+
}
660+
// It looks like a name, but might be "inf" or "nan". Check what
661+
// follows.
662+
ch = io.GetNextNonBlank(byteCount);
663+
if (!ch) {
664+
return false;
665+
} else if (*ch == '=' || *ch == '%') {
666+
return true;
667+
} else if (*ch != '(') {
668+
return false;
669+
} else if (runtime::strcmp(id, "nan") != 0) {
670+
return true;
671+
}
672+
// "nan(" ambiguity
673+
int depth{1};
674+
while (true) {
675+
io.HandleRelativePosition(byteCount);
676+
ch = io.GetNextNonBlank(byteCount);
677+
if (depth == 0) {
678+
// nan(...) followed by '=', '%', or '('?
679+
break;
680+
} else if (!ch) {
681+
return true; // not a valid NaN(...)
682+
} else if (*ch == '(') {
683+
++depth;
684+
} else if (*ch == ')') {
685+
--depth;
661686
}
662687
}
663-
return false;
688+
return ch && (*ch == '=' || *ch == '%' || *ch == '(');
664689
}
665690

666691
RT_OFFLOAD_API_GROUP_END

flang-rt/unittests/Runtime/Namelist.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -334,4 +334,35 @@ TEST(NamelistTests, RealValueForInt) {
334334
EXPECT_EQ(got, expect);
335335
}
336336

337+
TEST(NamelistTests, NanInputAmbiguity) {
338+
OwningPtr<Descriptor> xDesc{// real :: x(5) = 0.
339+
MakeArray<TypeCategory::Real, static_cast<int>(sizeof(float))>(
340+
std::vector<int>{5}, std::vector<float>{{0, 0, 0, 0, 0}})};
341+
OwningPtr<Descriptor> nanDesc{// real :: nan(2) = 0.
342+
MakeArray<TypeCategory::Real, static_cast<int>(sizeof(float))>(
343+
std::vector<int>{2}, std::vector<float>{{0, 0}})};
344+
const NamelistGroup::Item items[]{{"x", *xDesc}, {"nan", *nanDesc}};
345+
const NamelistGroup group{"nml", 2, items};
346+
static char t1[]{"&nml x=1 2 nan(q) 4 nan(1)=5 nan(q)/"};
347+
StaticDescriptor<1, true> statDesc;
348+
Descriptor &internalDesc{statDesc.descriptor()};
349+
internalDesc.Establish(TypeCode{CFI_type_char},
350+
/*elementBytes=*/std::strlen(t1), t1, 0, nullptr, CFI_attribute_pointer);
351+
auto inCookie{IONAME(BeginInternalArrayListInput)(
352+
internalDesc, nullptr, 0, __FILE__, __LINE__)};
353+
ASSERT_TRUE(IONAME(InputNamelist)(inCookie, group));
354+
ASSERT_EQ(IONAME(EndIoStatement)(inCookie), IostatOk)
355+
<< "namelist real input for nans";
356+
char out[40];
357+
internalDesc.Establish(TypeCode{CFI_type_char}, /*elementBytes=*/sizeof out,
358+
out, 0, nullptr, CFI_attribute_pointer);
359+
auto outCookie{IONAME(BeginInternalArrayListOutput)(
360+
internalDesc, nullptr, 0, __FILE__, __LINE__)};
361+
ASSERT_TRUE(IONAME(OutputNamelist)(outCookie, group));
362+
ASSERT_EQ(IONAME(EndIoStatement)(outCookie), IostatOk) << "namelist output";
363+
std::string got{out, sizeof out};
364+
static const std::string expect{" &NML X= 1. 2. NaN 4. 0.,NAN= 5. NaN/ "};
365+
EXPECT_EQ(got, expect);
366+
}
367+
337368
// TODO: Internal NAMELIST error tests

0 commit comments

Comments
 (0)