11#!/usr/bin/env python3
2- from asyncio import get_event_loop , gather
2+ from asyncio import get_event_loop , gather , Semaphore
33from collections import defaultdict
44from datetime import datetime
55from json import dumps
66from logging import getLogger , basicConfig , INFO
7- from os import environ
7+ from os import environ , cpu_count
88from pathlib import Path
99from random import choice
1010from sys import path
2424from s3recon .mongodb import MongoDB , Hit , Access
2525
2626filterwarnings ("ignore" , category = InsecureRequestWarning )
27+ cpus = cpu_count () or 1
2728
2829logger = getLogger (__name__ )
2930
@@ -52,21 +53,22 @@ def bucket_exists(url, timeout):
5253 return exists , public
5354
5455
async def find_bucket(url, timeout, db, sem):
    """Probe one candidate bucket URL and record it when it exists.

    Args:
        url: Bucket URL to probe.
        timeout: Timeout forwarded unchanged to ``bucket_exists``.
        db: Database handle exposing ``update(filter, document)``, or a
            falsy value to skip persistence.
        sem: ``asyncio.Semaphore`` limiting how many probes are in flight
            at the same time.

    Returns:
        A ``Hit`` for an existing bucket, otherwise ``None``.
    """
    async with sem:
        # bucket_exists() is a plain blocking call. Running it directly in
        # this coroutine would occupy the event-loop thread for the whole
        # request, so no other coroutine could start and the semaphore would
        # never see more than one holder. Handing it to the loop's default
        # executor lets up to the semaphore's limit of probes overlap.
        # NOTE(review): the default executor has its own worker cap, which
        # also bounds effective concurrency -- confirm that is acceptable.
        exists, public = await get_event_loop().run_in_executor(
            None, bucket_exists, url, timeout
        )

    if not exists:
        return None

    access = Access.PUBLIC if public else Access.PRIVATE
    access_key = repr(access)
    access_word = str(access).upper()
    logger.info(f"{access_key} {access_word} {url}")

    hit = Hit(url, access)
    # The shared db handle is only touched from the event-loop thread, so
    # it is never used from two executor threads at once.
    if db and hit.is_valid():
        db.update({"url": url}, dict(hit))
    # Return the instance built above instead of constructing a second one.
    return hit
7072
7173
7274def collect_results (hits ):
@@ -105,7 +107,7 @@ def json_output_template(key, total, hits, exclude):
105107 return {} if exclude else {key : {"total" : total , "hits" : hits }}
106108
107109
108- def main (words , timeout , output , use_db , only_public ):
110+ def main (words , timeout , concurrency , output , use_db , only_public ):
109111 start = datetime .now ()
110112 loop = get_event_loop ()
111113
@@ -129,16 +131,16 @@ def main(words, timeout, output, use_db, only_public):
129131 for env in environments
130132 }
131133
134+ db = MongoDB (host = database ["host" ], port = database ["port" ]) if use_db else None
135+ sem = Semaphore (concurrency )
136+
132137 tasks = gather (
133138 * [
134- loop .run_in_executor (
135- None ,
136- find_bucket ,
139+ find_bucket (
137140 url ,
138141 timeout ,
139- MongoDB (host = database ["host" ], port = database ["port" ])
140- if use_db
141- else None ,
142+ db ,
143+ sem
142144 )
143145 for url in url_list
144146 ]
@@ -198,6 +200,14 @@ def cli():
198200 parser .add_argument (
199201 "-v" , "--version" , action = "version" , version = f"%(prog)s { __version__ } "
200202 )
203+ parser .add_argument (
204+ "-c" ,
205+ "--concurrency" ,
206+ type = int ,
207+ metavar = "num" ,
208+ default = cpus ,
209+ help = f"maximum <num> of concurrent requests (default: { cpus } )" ,
210+ )
201211 # parser.add_argument("words", nargs="?", type=argparse.FileType("r"), default=stdin, help="list of words to permute")
202212 parser .add_argument (
203213 "word_list" ,
@@ -210,10 +220,11 @@ def cli():
210220 output = args .output
211221 db = args .db
212222 timeout = args .timeout
223+ concurrency = args .concurrency
213224 public = args .public
214225 words = {l .strip () for f in args .word_list for l in f }
215226
216- main (words = words , timeout = timeout , output = output , use_db = db , only_public = public )
227+ main (words = words , timeout = timeout , concurrency = max ( 1 , concurrency ), output = output , use_db = db , only_public = public )
217228
218229
219230if __name__ == "__main__" :
0 commit comments