Skip to content

Commit e56c772

Browse files
committed
Combine pruneInterface() + validateConnectivity()
Validate interfaces during resolution. No need to separately validate interfaces. Thus, validateConnectivity() is removed from Task::init(). Functions are kept (but simplified) for unit testing.
1 parent 10260b9 commit e56c772

File tree

5 files changed

+68
-112
lines changed

5 files changed

+68
-112
lines changed

core/include/moveit/task_constructor/container_p.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,6 @@ class ContainerBasePrivate : public StagePrivate
132132
bool allowed = (required & WRITES_NEXT_START);
133133
child.pimpl()->setNextStarts(allowed ? pending_forward_ : InterfacePtr());
134134
}
135-
// report error about mismatching interface (start or end as determined by mask)
136-
template <unsigned int mask>
137-
void mismatchingInterface(InitStageException& errors, const StagePrivate& child) const;
138135

139136
/// copy external_state to a child's interface and remember the link in internal_to map
140137
void copyState(Interface::iterator external, const InterfacePtr& target, bool updated);
@@ -184,6 +181,10 @@ class SerialContainerPrivate : public ContainerBasePrivate
184181
protected:
185182
// connect two neighbors
186183
void connect(StagePrivate& stage1, StagePrivate& stage2);
184+
185+
// validate that child's interface matches mine (considering start or end only as determined by mask)
186+
template <unsigned int mask>
187+
void validateInterface(const StagePrivate& child, InterfaceFlags required) const;
187188
};
188189
PIMPL_FUNCTIONS(SerialContainer)
189190

@@ -220,6 +221,9 @@ class ParallelContainerBasePrivate : public ContainerBasePrivate
220221

221222
void validateConnectivity() const override;
222223

224+
protected:
225+
void validateInterfaces(const StagePrivate& child, InterfaceFlags& external, bool first = false) const;
226+
223227
private:
224228
/// callback for new externally received states
225229
void onNewExternalState(Interface::Direction dir, Interface::iterator external, bool updated);

core/include/moveit/task_constructor/stage_p.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,8 +222,6 @@ class PropagatingEitherWayPrivate : public ComputeBasePrivate
222222
void initInterface(PropagatingEitherWay::Direction dir);
223223
// prune interface to the given propagation direction
224224
void pruneInterface(InterfaceFlags accepted) override;
225-
// validate that we can propagate in one direction at least
226-
void validateConnectivity() const override;
227225

228226
bool canCompute() const override;
229227
void compute() override;

core/src/container.cpp

Lines changed: 61 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -91,25 +91,9 @@ bool ContainerBasePrivate::traverseStages(const ContainerBase::StageCallback& pr
9191
}
9292

9393
void ContainerBasePrivate::validateConnectivity() const {
94-
InitStageException errors;
9594
// recursively validate all children and accumulate errors
96-
for (const auto& child : children()) {
97-
try {
98-
child->pimpl()->validateConnectivity();
99-
} catch (InitStageException& e) {
100-
errors.append(e);
101-
}
102-
}
103-
if (errors)
104-
throw errors;
105-
}
106-
107-
template <unsigned int mask>
108-
void ContainerBasePrivate::mismatchingInterface(InitStageException& errors, const StagePrivate& child) const {
109-
boost::format desc("%1% interface of '%2%' (%3%) does not match mine (%4%)");
110-
errors.push_back(*me(), (desc % (mask == START_IF_MASK ? "start" : "end") % child.name() %
111-
flowSymbol<mask>(child.interfaceFlags()) % flowSymbol<mask>(interfaceFlags()))
112-
.str());
95+
for (const auto& child : children())
96+
child->pimpl()->validateConnectivity();
11397
}
11498

11599
bool ContainerBasePrivate::canCompute() const {
@@ -416,13 +400,27 @@ void SerialContainerPrivate::connect(StagePrivate& stage1, StagePrivate& stage2)
416400
else if ((flags1 & READS_END) && (flags2 & WRITES_PREV_END))
417401
stage2.setPrevEnds(stage1.ends());
418402
else {
419-
boost::format desc("end interface of '%1%' (%2%) does not match start interface of '%3%' (%4%)");
403+
boost::format desc("cannot connect end interface of '%1%' (%2%) to start interface of '%3%' (%4%)");
420404
desc % stage1.name() % flowSymbol<END_IF_MASK>(flags1);
421405
desc % stage2.name() % flowSymbol<START_IF_MASK>(flags2);
422406
throw InitStageException(*me(), desc.str());
423407
}
424408
}
425409

410+
template <unsigned int mask>
411+
void SerialContainerPrivate::validateInterface(const StagePrivate& child, InterfaceFlags required) const {
412+
required = required & mask;
413+
if (required == UNKNOWN)
414+
return; // cannot yet validate
415+
InterfaceFlags child_interface = child.interfaceFlags() & mask;
416+
if (required != child_interface) {
417+
boost::format desc("%1% interface (%3%) of '%2%' does not match mine (%4%)");
418+
desc % (mask == START_IF_MASK ? "start" : "end") % child.name();
419+
desc % flowSymbol<mask>(child_interface) % flowSymbol<mask>(required);
420+
throw InitStageException(*me_, desc.str());
421+
}
422+
}
423+
426424
// called by parent asking for pruning of this' interface
427425
void SerialContainerPrivate::pruneInterface(InterfaceFlags accepted) {
428426
// we need to have some children to do the actual work
@@ -439,6 +437,8 @@ void SerialContainerPrivate::pruneInterface(InterfaceFlags accepted) {
439437
first.pimpl()->pruneInterface(accepted & START_IF_MASK);
440438
// connect first child's (start) push interface
441439
setChildsPushBackwardInterface(first);
440+
// validate that first child's and this container's start interfaces match
441+
validateInterface<START_IF_MASK>(*first.pimpl(), accepted);
442442
// connect first child's (start) pull interface
443443
if (const InterfacePtr& target = first.pimpl()->starts())
444444
starts_.reset(new Interface(
@@ -454,6 +454,8 @@ void SerialContainerPrivate::pruneInterface(InterfaceFlags accepted) {
454454

455455
// connect last child's (end) push interface
456456
setChildsPushForwardInterface(last);
457+
// validate that last child's and this container's end interfaces match
458+
validateInterface<END_IF_MASK>(*last.pimpl(), accepted);
457459
// connect last child's (end) pull interface
458460
if (const InterfacePtr& target = last.pimpl()->ends())
459461
ends_.reset(new Interface(
@@ -463,30 +465,14 @@ void SerialContainerPrivate::pruneInterface(InterfaceFlags accepted) {
463465
}
464466

465467
void SerialContainerPrivate::validateConnectivity() const {
466-
InitStageException errors;
467-
468-
// recursively validate children
469-
try {
470-
ContainerBasePrivate::validateConnectivity();
471-
} catch (InitStageException& e) {
472-
errors.append(e);
473-
}
468+
ContainerBasePrivate::validateConnectivity();
474469

470+
InterfaceFlags mine = interfaceFlags();
475471
// check that input / output interface of first / last child matches this' resp. interface
476-
if (!children().empty()) {
477-
const StagePrivate* start = children().front()->pimpl();
478-
const auto my_flags = this->interfaceFlags();
479-
auto child_flags = start->interfaceFlags() & START_IF_MASK;
480-
if (child_flags != (my_flags & START_IF_MASK))
481-
mismatchingInterface<START_IF_MASK>(errors, *start);
482-
483-
const StagePrivate* last = children().back()->pimpl();
484-
child_flags = last->interfaceFlags() & END_IF_MASK;
485-
if (child_flags != (my_flags & END_IF_MASK))
486-
mismatchingInterface<END_IF_MASK>(errors, *last);
487-
}
472+
validateInterface<START_IF_MASK>(*children().front()->pimpl(), mine);
473+
validateInterface<END_IF_MASK>(*children().back()->pimpl(), mine);
488474

489-
// validate connectivity of children amongst each other
475+
// validate connectivity of children between each other
490476
// ContainerBasePrivate::validateConnectivity() ensures that required push interfaces are present,
491477
// that is, neighbouring stages have a corresponding pull interface.
492478
// Here, it remains to check that - if a child has a pull interface - it's indeed feeded.
@@ -503,16 +489,13 @@ void SerialContainerPrivate::validateConnectivity() const {
503489
// start pull interface fed?
504490
if (cur != children().begin() && // first child has not a previous one
505491
(required & READS_START) && !(*prev)->pimpl()->nextStarts())
506-
errors.push_back(**cur, "start interface is not fed");
492+
throw InitStageException(**cur, "start interface is not fed");
507493

508494
// end pull interface fed?
509495
if (next != end && // last child has not a next one
510496
(required & READS_END) && !(*next)->pimpl()->prevEnds())
511-
errors.push_back(**cur, "end interface is not fed");
497+
throw InitStageException(**cur, "end interface is not fed");
512498
}
513-
514-
if (errors)
515-
throw errors;
516499
}
517500

518501
bool SerialContainer::canCompute() const {
@@ -575,80 +558,68 @@ void ParallelContainerBasePrivate::pruneInterface(InterfaceFlags accepted) {
575558
throw InitStageException(*me(), "no children");
576559

577560
InitStageException exceptions;
578-
InterfaceFlags interface;
579561

562+
bool first = true;
580563
for (const Stage::pointer& child : children()) {
581564
try {
582-
child->pimpl()->pruneInterface(accepted);
565+
auto child_impl = child->pimpl();
566+
child_impl->pruneInterface(accepted);
567+
validateInterfaces(*child_impl, accepted, first);
568+
// initialize push connections of children according to their demands
569+
setChildsPushForwardInterface(*child);
570+
setChildsPushBackwardInterface(*child);
571+
first = false;
583572
} catch (InitStageException& e) {
584573
exceptions.append(e);
585574
continue;
586575
}
587-
588-
InterfaceFlags child_interface = child->pimpl()->requiredInterface();
589-
if (interface == UNKNOWN)
590-
interface = child_interface;
591-
else if ((interface & child_interface) != child_interface) {
592-
boost::format desc("inferred interface of stage '%1%' (%2%/%3%) does not agree with the inferred interface of "
593-
"its siblings (%4%/%5%).");
594-
desc % child->name();
595-
desc % flowSymbol<START_IF_MASK>(child_interface) % flowSymbol<END_IF_MASK>(child_interface);
596-
desc % flowSymbol<START_IF_MASK>(interface) % flowSymbol<END_IF_MASK>(interface);
597-
exceptions.push_back(*me(), desc.str());
598-
}
599-
}
600-
601-
if ((interface & accepted) != accepted) {
602-
boost::format desc("required interface (%1%/%2%) does not match children (%3%/%4%).");
603-
desc % flowSymbol<START_IF_MASK>(accepted) % flowSymbol<END_IF_MASK>(accepted);
604-
desc % flowSymbol<START_IF_MASK>(interface) % flowSymbol<END_IF_MASK>(interface);
605-
exceptions.push_back(*me(), desc.str());
606576
}
607577

608578
if (exceptions)
609579
throw exceptions;
610580

611581
// States received by the container need to be copied to all children's pull interfaces.
612-
if (interface & READS_START)
582+
if (accepted & READS_START)
613583
starts().reset(new Interface([this](Interface::iterator external, bool updated) {
614584
this->onNewExternalState(Interface::FORWARD, external, updated);
615585
}));
616-
if (interface & READS_END)
586+
if (accepted & READS_END)
617587
ends().reset(new Interface([this](Interface::iterator external, bool updated) {
618588
this->onNewExternalState(Interface::BACKWARD, external, updated);
619589
}));
620590

621-
// initialize push connections of children according to their demands
622-
for (const Stage::pointer& stage : children()) {
623-
setChildsPushForwardInterface(*stage);
624-
setChildsPushBackwardInterface(*stage);
591+
required_interface_ = accepted;
592+
}
593+
594+
void ParallelContainerBasePrivate::validateInterfaces(const StagePrivate& child, InterfaceFlags& external,
595+
bool first) const {
596+
const InterfaceFlags child_interface = child.requiredInterface();
597+
bool valid = true;
598+
for (InterfaceFlags mask : { START_IF_MASK, END_IF_MASK }) {
599+
if ((external & mask) == UNKNOWN)
600+
external |= child_interface & mask;
601+
602+
valid = valid & ((external & mask) == (child_interface & mask));
625603
}
626604

627-
required_interface_ = interface;
605+
if (!valid) {
606+
boost::format desc("interface of '%1%' (%3% %4%) does not match %2% (%5% %6%).");
607+
desc % child.name();
608+
desc % (first ? "external one" : "other children's");
609+
desc % flowSymbol<START_IF_MASK>(child_interface) % flowSymbol<END_IF_MASK>(child_interface);
610+
desc % flowSymbol<START_IF_MASK>(external) % flowSymbol<END_IF_MASK>(external);
611+
throw InitStageException(*me_, desc.str());
612+
}
628613
}
629614

630615
void ParallelContainerBasePrivate::validateConnectivity() const {
631-
InitStageException errors;
632616
InterfaceFlags my_interface = interfaceFlags();
633617

634618
// check that input / output interfaces of all children are handled by my interface
635-
for (const auto& child : children()) {
636-
InterfaceFlags current = child->pimpl()->interfaceFlags();
637-
if ((current & my_interface & START_IF_MASK) != (current & START_IF_MASK))
638-
mismatchingInterface<START_IF_MASK>(errors, *child->pimpl());
639-
if ((current & my_interface & END_IF_MASK) != (current & END_IF_MASK))
640-
mismatchingInterface<END_IF_MASK>(errors, *child->pimpl());
641-
}
619+
for (const auto& child : children())
620+
validateInterfaces(*child->pimpl(), my_interface);
642621

643-
// recursively validate children
644-
try {
645-
ContainerBasePrivate::validateConnectivity();
646-
} catch (InitStageException& e) {
647-
errors.append(e);
648-
}
649-
650-
if (errors)
651-
throw errors;
622+
ContainerBasePrivate::validateConnectivity();
652623
}
653624

654625
void ParallelContainerBasePrivate::onNewExternalState(Interface::Direction dir, Interface::iterator external,

core/src/stage.cpp

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -443,21 +443,6 @@ void PropagatingEitherWayPrivate::pruneInterface(InterfaceFlags accepted) {
443443
initInterface(dir);
444444
}
445445

446-
void PropagatingEitherWayPrivate::validateConnectivity() const {
447-
InterfaceFlags actual = interfaceFlags();
448-
if (actual == UNKNOWN)
449-
throw InitStageException(*me(), "not connected in any direction");
450-
451-
InitStageException errors;
452-
if ((actual & READS_START) && !(actual & WRITES_NEXT_START))
453-
errors.push_back(*me(), "Cannot push forwards");
454-
if ((actual & READS_END) && !(actual & WRITES_PREV_END))
455-
errors.push_back(*me(), "Cannot push backwards");
456-
457-
if (errors)
458-
throw errors;
459-
}
460-
461446
InterfaceFlags PropagatingEitherWayPrivate::requiredInterface() const {
462447
return required_interface_;
463448
}

core/src/task.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,6 @@ void Task::init() {
262262
stages()->init(impl->robot_model_);
263263
// task expects its wrapped child to push to both ends, this triggers interface resolution
264264
stages()->pimpl()->pruneInterface(InterfaceFlags({ GENERATE }));
265-
// and *finally* validate connectivity
266-
stages()->pimpl()->validateConnectivity();
267265

268266
// provide introspection instance to all stages
269267
impl->setIntrospection(impl->introspection_.get());

0 commit comments

Comments
 (0)