11#!/usr/bin/env python3
22
3- from gevent .pool import Pool
4- from gevent import sleep
53from cme .logger import setup_logger , setup_debug_logger , CMEAdapter
64from cme .helpers .logger import highlight
75from cme .helpers .misc import identify_target_file
1412from cme .servers .http import CMEServer
1513from cme .first_run import first_run_setup
1614from cme .context import Context
15+ from concurrent .futures import ThreadPoolExecutor
1716from pprint import pformat
17+ from decimal import Decimal
18+ import time
19+ import asyncio
20+ import aioconsole
21+ import functools
1822import configparser
1923import cme .helpers .powershell as powershell
2024import cme
2630import sys
2731import logging
2832
33+ setup_logger ()
34+ logger = CMEAdapter ()
2935
30- def main ():
36+ async def monitor_threadpool (pool , targets ):
37+ logging .debug ('Started thread poller' )
38+
39+ while True :
40+ try :
41+ text = await aioconsole .ainput ("" )
42+ if text == "" :
43+ pool_size = pool ._work_queue .qsize ()
44+ finished_threads = len (targets ) - pool_size
45+ percentage = Decimal (finished_threads ) / Decimal (len (targets )) * Decimal (100 )
46+ logger .info (f"completed: { percentage :.2f} % ({ finished_threads } /{ len (targets )} )" )
47+ except asyncio .CancelledError :
48+ logging .debug ("Stopped thread poller" )
49+ break
50+
51+ async def run_protocol (loop , protocol_obj , args , db , target , jitter ):
52+ try :
53+ if jitter :
54+ value = random .choice (range (jitter [0 ], jitter [1 ]))
55+ logging .debug (f"Doin' the jitterbug for { value } second(s)" )
56+ await asyncio .sleep (value )
57+
58+ thread = loop .run_in_executor (
59+ None ,
60+ functools .partial (
61+ protocol_obj ,
62+ args ,
63+ db ,
64+ str (target )
65+ )
66+ )
67+
68+ await asyncio .wait_for (
69+ thread ,
70+ timeout = args .timeout
71+ )
72+
73+ except asyncio .TimeoutError :
74+ logging .debug ("Thread exceeded timeout" )
75+ except asyncio .CancelledError :
76+ logging .debug ("Stopping thread" )
77+ thread .cancel ()
78+
79+ async def start_threadpool (protocol_obj , args , db , targets , jitter ):
80+ pool = ThreadPoolExecutor (max_workers = args .threads + 1 )
81+ loop = asyncio .get_running_loop ()
82+ loop .set_default_executor (pool )
83+
84+ monitor_task = asyncio .create_task (
85+ monitor_threadpool (pool , targets )
86+ )
87+
88+ jobs = [
89+ run_protocol (
90+ loop ,
91+ protocol_obj ,
92+ args ,
93+ db ,
94+ target ,
95+ jitter
96+ )
97+ for target in targets
98+ ]
3199
32- setup_logger ()
33- logger = CMEAdapter ()
100+ try :
101+ logging .debug ("Running" )
102+ await asyncio .gather (* jobs )
103+ except asyncio .CancelledError :
104+ print ('\n ' )
105+ logger .info ("Shutting down, please wait..." )
106+ logging .debug ("Cancelling scan" )
107+ finally :
108+ monitor_task .cancel ()
109+ pool .shutdown (wait = True )
110+
111+ def main ():
34112 first_run_setup (logger )
35113
36114 args = gen_cli_args ()
@@ -191,27 +269,14 @@ def main():
191269 setattr (protocol_object , 'server' , module_server .server )
192270
193271 try :
194- '''
195- Open all the greenlet (as supposed to redlet??) threads
196- Whoever came up with that name has a fetish for traffic lights
197- '''
198- pool = Pool (args .threads )
199- jobs = []
200- for target in targets :
201- jobs .append (pool .spawn (protocol_object , args , db , str (target )))
202-
203- if jitter :
204- value = random .choice (range (jitter [0 ], jitter [1 ]))
205- logging .debug ("Doin' the Jitterbug for {} seconds" .format (value ))
206- sleep (value )
207-
208- for job in jobs :
209- job .join (timeout = args .timeout )
272+ asyncio .run (
273+ start_threadpool (protocol_object , args , db , targets , jitter )
274+ )
210275 except KeyboardInterrupt :
211- pass
212-
213- if module_server :
214- module_server .shutdown ()
276+ logging . debug ( "Got keyboard interrupt" )
277+ finally :
278+ if module_server :
279+ module_server .shutdown ()
215280
216281if __name__ == '__main__' :
217282 main ()
0 commit comments