diff --git a/mozilla_aws_cli/listener.py b/mozilla_aws_cli/listener.py index b3adacb..666ac95 100644 --- a/mozilla_aws_cli/listener.py +++ b/mozilla_aws_cli/listener.py @@ -1,15 +1,17 @@ from __future__ import print_function + import errno import logging import os.path import socket +import threading import time - -from flask import Flask, jsonify, request, send_from_directory from operator import itemgetter -from .utils import exit_sigint, get_alias +from flask import Flask, jsonify, request, send_from_directory +from werkzeug.serving import make_server +from .utils import get_alias # These ports must be configured in the IdP's allowed callback URL list # TODO: Move this to the CLI / config section @@ -18,6 +20,8 @@ STATIC_DIR = os.path.join(os.path.dirname( os.path.realpath(__file__)), "static") +is_done = threading.Event() +server = None app = Flask(__name__) logger = logging.getLogger(__name__) login = { @@ -113,17 +117,17 @@ def get_heartbeat(): "status_code": 500, }) - start = time.time() - while time.time() - start < 30: - if login.last_state_check is None: - pass - elif (time.time() - login.last_state_check > - login.max_sleep_no_state_check): - logger.error( - "No response from web interface for {} seconds, shutting " - "down.".format(login.max_sleep_no_state_check)) - exit_sigint() - time.sleep(0.5) + # start = time.time() + # while time.time() - start < 30: + # if login.last_state_check is None: + # pass + # elif (time.time() - login.last_state_check > + # login.max_sleep_no_state_check): + # logger.error( + # "No response from web interface for {} seconds, shutting " + # "down.".format(login.max_sleep_no_state_check)) + # exit_sigint() + # time.sleep(0.5) return jsonify({ "result": "heartbeat_done", "status_code": 200, @@ -217,18 +221,37 @@ def handle_oidc_redirect_callback(): @app.route("/shutdown", methods=["GET"]) def handle_shutdown(): logger.debug("Shutting down Flask") - exit_sigint() - - # this is down to prevent race conditions + is_done.set() + threading.Timer(2, server.shutdown).start() return jsonify({ "result": "shutdown", "status_code": 200, }) -def listen(login): +class ServerThread(threading.Thread): + + def __init__(self, app): + threading.Thread.__init__(self) + self.srv = make_server('127.0.0.1', port, app) + self.ctx = app.app_context() + self.ctx.push() + + def run(self): + self.srv.serve_forever() + logger.debug("Flask done") + + def shutdown(self): + try: + self.srv.shutdown() + except Exception: + pass + + +def listen(login_): # set the global callback - globals()["login"] = login + global server, login + login = login_ debug = True if logger.level == logging.DEBUG else False @@ -237,6 +260,7 @@ def listen(login): os.environ["WERKZEUG_RUN_MAIN"] = "true" logging.getLogger("werkzeug").setLevel(logging.ERROR) - app.run(port=port, debug=debug) - + server = ServerThread(app) + server.start() + is_done.wait() return port