From 2f0fe5e67fd1d59a8ef1a19844588fc0c341d95e Mon Sep 17 00:00:00 2001 From: seepine Date: Thu, 18 Apr 2024 10:41:20 +0800 Subject: [PATCH] feat: add multiple thread --- README.md | 8 ++- cmd/blazehttp/main.go | 158 ++++++++++++++++++++++++------------------ 2 files changed, 99 insertions(+), 67 deletions(-) diff --git a/README.md b/README.md index 0bed98b74..7f1407037 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,13 @@ git clone https://github.com/chaitin/blazehttp.git && cd blazehttp ## 🚀 一键运行 ``` bash -bash build.sh && ./build/blazehttp -t http://127.0.0.1:8008 +bash build.sh && ./build/blazehttp -t http://127.0.0.1:8080 +``` + +也可配置 `thread` 参数提高测试速度,默认 2 + +``` +bash build.sh && ./build/blazehttp -t http://127.0.0.1:8080 -thread 4 ``` ## 🕹️ 靶机服务 diff --git a/cmd/blazehttp/main.go b/cmd/blazehttp/main.go index 12f71c49d..546ce0722 100644 --- a/cmd/blazehttp/main.go +++ b/cmd/blazehttp/main.go @@ -12,6 +12,7 @@ import ( "regexp" "sort" "strings" + "sync" "time" blazehttp "github.com/chaitin/blazehttp/http" @@ -27,6 +28,7 @@ var ( glob string // use glob expression to select multi files timeout = 1000 // default 1000 ms mHost string // modify host header + thread = 2 // send request thread requestPerSession bool // send request per session ) @@ -38,6 +40,7 @@ func init() { flag.StringVar(&target, "t", "", "target website, example: http://192.168.0.1:8080") flag.StringVar(&glob, "g", "./testcases/", "glob expression, example: *.http") flag.IntVar(&timeout, "timeout", 1000, "connection timeout, default 1000 ms") + flag.IntVar(&thread, "thread", 2, "request thread, default 2") flag.StringVar(&mHost, "H", "", "modify host header") flag.BoolVar(&requestPerSession, "rps", true, "send request per session") flag.Parse() @@ -141,8 +144,56 @@ func getAllFiles(path string) ([]string, error) { return files, nil } +func work(addr string, isHttps bool, blockStatusCode int, f string) (bool, bool, int64, error) { + req := new(blazehttp.Request) + + if err := req.ReadFile(f); err != nil { + return false, false, 0, fmt.Errorf("read request file: %s error: %s", f, err) + } + if mHost != "" { + // 修改host header + req.SetHost(mHost) + } else { + // 不修改会导致域名备案拦截 + req.SetHost(addr) + } + + if requestPerSession { + // one http request one connection + req.SetHeader("Connection", "close") + } + + req.CalculateContentLength() + + start := time.Now() + conn := connect(addr, isHttps, timeout) + if conn == nil { + return false, false, 0, fmt.Errorf("connect to %s failed", addr) + } + nWrite, err := req.WriteTo(*conn) + if err != nil { + return false, false, 0, fmt.Errorf("send request poc: %s length: %d error: %s", f, nWrite, err) + } + + rsp := new(blazehttp.Response) + if err = rsp.ReadConn(*conn); err != nil { + return false, false, 0, fmt.Errorf("read poc file: %s response, error: %s", f, err) + } + (*conn).Close() + isWhite := false // black case + if strings.HasSuffix(f, "white") { + isWhite = true // white case + } + isPass := true + code := rsp.GetStatusCode() + if code == blockStatusCode { + isPass = false + } + return isWhite, isPass, time.Since(start).Nanoseconds(), nil +} + func main() { - // mcl := true + isHttps := false addr := target @@ -193,74 +244,49 @@ func main() { fmt.Println("目标网站未开启waf") os.Exit(1) } - for _, f := range fileList { - _ = bar.Add(1) - req := new(blazehttp.Request) - if err = req.ReadFile(f); err != nil { - fmt.Printf("read request file: %s error: %s\n", f, err) - continue - } - if mHost != "" { - // 修改host header - req.SetHost(mHost) - } else { - // 不修改会导致域名备案拦截 - req.SetHost(addr) - } - - if requestPerSession { - // one http request one connection - req.SetHeader("Connection", "close") - } - - req.CalculateContentLength() - - start := time.Now() - conn := connect(addr, isHttps, timeout) - if conn == nil { - fmt.Printf("connect to %s failed!\n", addr) - continue - } - nWrite, err := req.WriteTo(*conn) - if err != nil { - fmt.Printf("send request poc: %s length: %d error: %s", f, nWrite, err) - continue - } - - rsp := new(blazehttp.Response) - if err = rsp.ReadConn(*conn); err != nil { - fmt.Printf("read poc file: %s response, error: %s", f, err) - continue - } - elap = append(elap, time.Since(start).Nanoseconds()) - (*conn).Close() - success++ - isWhite := false // black case - if strings.HasSuffix(f, "white") { - isWhite = true // white case - } + var mutex sync.Mutex + concurrency := thread + wg := &sync.WaitGroup{} + jobs := make(chan string, len(fileList)) - isPass := true - code := rsp.GetStatusCode() - if code == blockStatusCode { - isPass = false - } - - if isWhite { - if isPass { - TN += 1 - } else { - FP += 1 - } - } else { - if isPass { - FN += 1 - } else { - TP += 1 + for i := 0; i < concurrency; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for f := range jobs { + isWhite, isPass, time, err := work(addr, isHttps, blockStatusCode, f) + mutex.Lock() + bar.Add(1) + if err != nil { + fmt.Printf("%s\n", err) + continue + } + elap = append(elap, time) + success++ + if isWhite { + if isPass { + TN += 1 + } else { + FP += 1 + } + } else { + if isPass { + FN += 1 + } else { + TP += 1 + } + } + mutex.Unlock() } - } + }() + } + + for _, f := range fileList { + jobs <- f } + close(jobs) + wg.Wait() fmt.Printf("总样本数量: %d 成功: %d 错误: %d\n", len(fileList), success, (len(fileList) - success)) fmt.Printf("检出率: %.2f%% (恶意样本总数: %d , 正确拦截: %d , 漏报放行: %d)\n", float64(TP)*100/float64(TP+FN), TP+FN, TP, FN) @@ -273,5 +299,5 @@ func main() { for _, v := range elap { sum += v } - fmt.Printf("平均耗时: %.2f毫秒\n", float64(sum)/float64(all)/1000000) + fmt.Printf("平均耗时: %.2f 毫秒\n", float64(sum)/float64(all)/1000000) }