From 1a4d1ad3d52be266fa880264fe1893fe4edfe29c Mon Sep 17 00:00:00 2001 From: Stephane Thiell Date: Thu, 7 Aug 2025 22:19:29 -0700 Subject: [PATCH] Tree: implement TreeWorker.abort() (#229) Now that the TreeWorker code allows us to abort commands per gateway channel, we can implement a more general TreeWorker.abort() method. This method aborts all direct workers and also commands handled via gateways. Closes #229. --- lib/ClusterShell/Communication.py | 1 - lib/ClusterShell/Propagation.py | 1 - lib/ClusterShell/Worker/Tree.py | 14 ++- tests/TreeWorkerTest.py | 166 ++++++++++++++++++++++++++++++ 4 files changed, 176 insertions(+), 6 deletions(-) diff --git a/lib/ClusterShell/Communication.py b/lib/ClusterShell/Communication.py index 5cf50707..70290257 100644 --- a/lib/ClusterShell/Communication.py +++ b/lib/ClusterShell/Communication.py @@ -218,7 +218,6 @@ def ev_start(self, worker): def ev_read(self, worker, node, sname, msg): """channel has data to read""" # sname can be either SNAME_READER or self.SNAME_ERROR - if sname == self.SNAME_ERROR: if self.initiator: self.recv(StdErrMessage(node, msg)) diff --git a/lib/ClusterShell/Propagation.py b/lib/ClusterShell/Propagation.py index 3ad22cf4..a53ae010 100644 --- a/lib/ClusterShell/Propagation.py +++ b/lib/ClusterShell/Propagation.py @@ -251,7 +251,6 @@ def recv(self, msg): """process incoming messages""" self.logger.debug("recv: %s", msg) if msg.type == EndMessage.ident: - #??#self.ptree.notify_close() self.logger.debug("got EndMessage; closing") self._close() elif msg.type == StdErrMessage.ident and msg.srcid == 0: diff --git a/lib/ClusterShell/Worker/Tree.py b/lib/ClusterShell/Worker/Tree.py index 9660a09e..ddb42873 100644 --- a/lib/ClusterShell/Worker/Tree.py +++ b/lib/ClusterShell/Worker/Tree.py @@ -429,10 +429,14 @@ def _on_remote_node_msgline(self, node, msg, sname, gateway): def _on_remote_node_close(self, node, rc, gateway): """remote node closing with return code""" - DistantWorker._on_node_close(self, node, rc) self.logger.debug("_on_remote_node_close %s %s via gw %s rc=%s", node, self._close_count, gateway, rc) + # this must be done first to avoid recursion via event handlers + self.gwtargets[str(gateway)].remove(node) + + DistantWorker._on_node_close(self, node, rc) + # finalize rcopy: extract tar data if self.source and self.reverse: if node in self._rcopy_bufs: @@ -458,7 +462,6 @@ def _on_remote_node_close(self, node, rc, gateway): else: self.logger.debug("no rcopy buffer received from %s", node) - self.gwtargets[str(gateway)].remove(node) self._close_count += 1 self._check_fini(gateway) @@ -578,8 +581,11 @@ def _gateway_abort(self, gateway): def abort(self): """Abort processing any action by this worker.""" - # Not yet supported by TreeWorker - raise NotImplementedError("see github issue #229") + self.logger.debug("abort %s" % self) + for worker in self.workers: + worker.abort() + for gateway in self.gwtargets.copy(): + self._gateway_abort(gateway) # TreeWorker's former name (deprecated as of 1.8) diff --git a/tests/TreeWorkerTest.py b/tests/TreeWorkerTest.py index 3c07d806..469e1e79 100644 --- a/tests/TreeWorkerTest.py +++ b/tests/TreeWorkerTest.py @@ -547,6 +547,172 @@ def test_tree_worker_name_compat(self): """test TreeWorker former name (WorkerTree)""" self.assertEqual(TreeWorker, WorkerTree) + def test_tree_run_abort_on_start(self): + """test tree run abort on ev_start""" + class TEventAbortOnStartHandler(TEventHandler): + """Test Event Abort On Start Handler""" + + def __init__(self, testcase): + TEventHandler.__init__(self) + self.testcase = testcase + + def ev_start(self, worker): + TEventHandler.ev_start(self, worker) + worker.abort() + + def ev_hup(self, worker, node, rc): + TEventHandler.ev_hup(self, worker, node, rc) + self.testcase.assertEqual(rc, os.EX_PROTOCOL) + + teh = TEventAbortOnStartHandler(self) + self.task.run('echo Lorem Ipsum', nodes=NODE_DISTANT, handler=teh) + self.assertEqual(teh.ev_start_cnt, 1) + #self.assertEqual(teh.ev_pickup_cnt, 0) # XXX to be improved + self.assertEqual(teh.ev_read_cnt, 0) + self.assertEqual(teh.ev_written_cnt, 0) + self.assertEqual(teh.ev_hup_cnt, 1) + self.assertEqual(teh.ev_timedout_cnt, 0) + self.assertEqual(teh.ev_close_cnt, 1) + self.assertEqual(teh.last_read, None) + + def test_tree_run_abort_on_pickup(self): + """test tree run abort on ev_pickup""" + class TEventAbortOnPickupHandler(TEventHandler): + """Test Event Abort On Pickup Handler""" + + def __init__(self, testcase): + TEventHandler.__init__(self) + self.testcase = testcase + + def ev_pickup(self, worker, node): + TEventHandler.ev_pickup(self, worker, node) + worker.abort() + + def ev_hup(self, worker, node, rc): + TEventHandler.ev_hup(self, worker, node, rc) + self.testcase.assertEqual(rc, os.EX_PROTOCOL) + + teh = TEventAbortOnPickupHandler(self) + self.task.run('echo Lorem Ipsum', nodes=NODE_DISTANT, handler=teh) + self.assertEqual(teh.ev_start_cnt, 1) + self.assertEqual(teh.ev_pickup_cnt, 1) + self.assertEqual(teh.ev_read_cnt, 0) + self.assertEqual(teh.ev_written_cnt, 0) + self.assertEqual(teh.ev_hup_cnt, 1) + self.assertEqual(teh.ev_timedout_cnt, 0) + self.assertEqual(teh.ev_close_cnt, 1) + self.assertEqual(teh.last_read, None) + + def test_tree_run_abort_on_read(self): + """test tree run abort on ev_read""" + class TEventAbortOnReadHandler(TEventHandler): + """Test Event Abort On Start Handler""" + + def __init__(self, testcase): + TEventHandler.__init__(self) + self.testcase = testcase + + def ev_read(self, worker, node, sname, msg): + TEventHandler.ev_read(self, worker, node, sname, msg) + worker.abort() + + def ev_hup(self, worker, node, rc): + TEventHandler.ev_hup(self, worker, node, rc) + self.testcase.assertEqual(rc, os.EX_PROTOCOL) + + teh = TEventAbortOnReadHandler(self) + self.task.run('echo Lorem Ipsum', nodes=NODE_DISTANT, handler=teh) + self.assertEqual(teh.ev_start_cnt, 1) + self.assertEqual(teh.ev_pickup_cnt, 1) + self.assertEqual(teh.ev_read_cnt, 1) + self.assertEqual(teh.ev_written_cnt, 0) + self.assertEqual(teh.ev_hup_cnt, 1) + self.assertEqual(teh.ev_timedout_cnt, 0) + self.assertEqual(teh.ev_close_cnt, 1) + self.assertEqual(teh.last_read, b'Lorem Ipsum') + + def test_tree_run_abort_on_hup(self): + """test tree run abort on ev_hup""" + class TEventAbortOnHupHandler(TEventHandler): + """Test Event Abort On Hup Handler""" + + def __init__(self, testcase): + TEventHandler.__init__(self) + self.testcase = testcase + + def ev_hup(self, worker, node, rc): + TEventHandler.ev_hup(self, worker, node, rc) + worker.abort() + + teh = TEventAbortOnHupHandler(self) + self.task.run('echo Lorem Ipsum', nodes=NODE_DISTANT, handler=teh) + self.assertEqual(teh.ev_start_cnt, 1) + self.assertEqual(teh.ev_pickup_cnt, 1) + self.assertEqual(teh.ev_read_cnt, 1) + self.assertEqual(teh.ev_written_cnt, 0) + self.assertEqual(teh.ev_hup_cnt, 1) + self.assertEqual(teh.ev_timedout_cnt, 0) + self.assertEqual(teh.ev_close_cnt, 1) + self.assertEqual(teh.last_read, b'Lorem Ipsum') + + def test_tree_run_abort_on_close(self): + """test tree run abort on ev_close""" + class TEventAbortOnCloseHandler(TEventHandler): + """Test Event Abort On Close Handler""" + + def __init__(self, testcase): + TEventHandler.__init__(self) + self.testcase = testcase + + def ev_close(self, worker, timedout): + TEventHandler.ev_close(self, worker, timedout) + self.testcase.assertEqual(type(worker), TreeWorker) + worker.abort() + + teh = TEventAbortOnCloseHandler(self) + self.task.run('echo Lorem Ipsum', nodes=NODE_DISTANT, handler=teh) + self.assertEqual(teh.ev_start_cnt, 1) + self.assertEqual(teh.ev_pickup_cnt, 1) + self.assertEqual(teh.ev_read_cnt, 1) + self.assertEqual(teh.ev_written_cnt, 0) + self.assertEqual(teh.ev_hup_cnt, 1) + self.assertEqual(teh.ev_timedout_cnt, 0) + self.assertEqual(teh.ev_close_cnt, 1) + self.assertEqual(teh.last_read, b'Lorem Ipsum') + + def test_tree_run_abort_on_timer(self): + """test tree run abort on timer""" + class TEventAbortOnTimerHandler(TEventHandler): + """Test Event Abort On Timer Handler""" + + def __init__(self, testcase): + TEventHandler.__init__(self) + self.testcase = testcase + self.worker = None + + def ev_timer(self, timer): + self.worker.abort() + + def ev_hup(self, worker, node, rc): + TEventHandler.ev_hup(self, worker, node, rc) + self.testcase.assertEqual(rc, os.EX_PROTOCOL) + + # Test abort from a timer's event handler + teh = TEventAbortOnTimerHandler(self) + # channel might take some time to set up; hard to time it + # we play it safe here and don't expect anything to read + teh.worker = self.task.shell('sleep 10; echo Lorem Ipsum', nodes=NODE_DISTANT, handler=teh) + timer1 = self.task.timer(3, handler=teh) + self.task.run() + self.assertEqual(teh.ev_start_cnt, 1) + self.assertEqual(teh.ev_pickup_cnt, 1) + self.assertEqual(teh.ev_read_cnt, 0) + self.assertEqual(teh.ev_written_cnt, 0) + self.assertEqual(teh.ev_hup_cnt, 1) + self.assertEqual(teh.ev_timedout_cnt, 0) + self.assertEqual(teh.ev_close_cnt, 1) + self.assertEqual(teh.last_read, None) + @unittest.skipIf(HOSTNAME == 'localhost', "does not work with hostname set to 'localhost'") class TreeWorkerGW2Test(TreeWorkerTestBase):