Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 90 additions & 2 deletions worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package worker
import (
"context"
"fmt"
"os"
"strings"
"sync"
"time"
Expand All @@ -16,6 +17,14 @@ type Progress interface {
Add(n int) error
}

type TestResult struct {
FilePath string
IsBlocked bool
StatusCode int
Error error
ResponseDetails string // 根据需要添加更多字段
}

type Worker struct {
ctx context.Context
cancel context.CancelFunc
Expand All @@ -27,6 +36,7 @@ type Worker struct {
jobResultDone chan struct{}
result *Result
progressBar Progress
results chan TestResult

addr string // target addr
isHttps bool // is https
Expand Down Expand Up @@ -103,6 +113,7 @@ func NewWorker(
addr: addr,
isHttps: isHttps,
timeout: 1000, // 1000ms
results: make(chan TestResult, len(fileList)),
blockStatusCode: blockStatusCode,

jobs: make(chan *Job),
Expand Down Expand Up @@ -230,7 +241,10 @@ func (w *Worker) runWorker() {
return
}
elap := time.Since(start).Nanoseconds()
(*conn).Close()
err = (*conn).Close()
if err != nil {
return
}
job.Result.Success = true
if strings.HasSuffix(job.FilePath, "white") {
job.Result.IsWhite = true // white case
Expand All @@ -242,11 +256,34 @@ func (w *Worker) runWorker() {
job.Result.IsPass = true
}
job.Result.TimeCost = elap

// 生成 TestResult 并发送到 results 通道
testResult := TestResult{
FilePath: filePath,
IsBlocked: code == w.blockStatusCode,
StatusCode: code,
Error: nil,
}
if job.Result.Err != "" {
testResult.Error = fmt.Errorf(job.Result.Err)
}
w.results <- testResult
}()
}
}

func parseError(errStr string) error {
if errStr == "" {
return nil
}
return fmt.Errorf(errStr)
}

func (w *Worker) processJobResult() {
// 定义两个切片用于存储误报和漏报的测试用例
var falsePositives []TestResult
var falseNegatives []TestResult

for job := range w.jobResult {
if job.Result.Success {
w.result.Success++
Expand All @@ -256,10 +293,24 @@ func (w *Worker) processJobResult() {
w.result.TN++
} else {
w.result.FP++
// 记录误报的测试用例
falsePositives = append(falsePositives, TestResult{
FilePath: job.FilePath,
IsBlocked: true,
StatusCode: job.Result.StatusCode,
Error: parseError(job.Result.Err),
})
}
} else {
if job.Result.IsPass {
w.result.FN++
// 记录漏报的测试用例
falseNegatives = append(falseNegatives, TestResult{
FilePath: job.FilePath,
IsBlocked: false,
StatusCode: job.Result.StatusCode,
Error: parseError(job.Result.Err),
})
} else {
w.result.TP++
}
Expand All @@ -273,6 +324,10 @@ func (w *Worker) processJobResult() {
w.resultCh <- &r
}
}

// 在所有结果处理完毕后,写入误报和漏报的测试用例到CSV文件
w.writeResults("false_positives.csv", falsePositives)
w.writeResults("false_negatives.csv", falseNegatives)
}

func (w *Worker) jobProducer() {
Expand All @@ -295,10 +350,43 @@ func (w *Worker) jobProducer() {

func (w *Worker) generateResult() string {
sb := strings.Builder{}
sb.WriteString(fmt.Sprintf("总样本数量: %d 成功: %d 错误: %d\n", w.result.Total, w.result.Success, (w.result.Total - w.result.Success)))
sb.WriteString(fmt.Sprintf("总样本数量: %d 成功: %d 错误: %d\n", w.result.Total, w.result.Success, w.result.Total-w.result.Success))
sb.WriteString(fmt.Sprintf("检出率: %.2f%% (恶意样本总数: %d , 正确拦截: %d , 漏报放行: %d)\n", float64(w.result.TP)*100/float64(w.result.TP+w.result.FN), w.result.TP+w.result.FN, w.result.TP, w.result.FN))
sb.WriteString(fmt.Sprintf("误报率: %.2f%% (正常样本总数: %d , 正确放行: %d , 误报拦截: %d)\n", float64(w.result.FP)*100/float64(w.result.TN+w.result.FP), w.result.TN+w.result.FP, w.result.TN, w.result.FP))
sb.WriteString(fmt.Sprintf("准确率: %.2f%% (正确拦截 + 正确放行)/样本总数 \n", float64(w.result.TP+w.result.TN)*100/float64(w.result.TP+w.result.TN+w.result.FP+w.result.FN)))
sb.WriteString(fmt.Sprintf("平均耗时: %.2f毫秒\n", float64(w.result.SuccessTimeCost)/float64(w.result.Success)/1000000))
return sb.String()
}

// 新增的写入CSV文件的函数
func (w *Worker) writeResults(filename string, results []TestResult) {
f, err := os.Create(filename)
if err != nil {
fmt.Printf("无法创建 %s: %v\n", filename, err)
return
}
defer func(f *os.File) {
err := f.Close()
if err != nil {

}
}(f)

// 写入 CSV 头部
_, err = f.WriteString("FilePath,IsBlocked,StatusCode,Error\n")
if err != nil {
return
}

for _, res := range results {
errorMsg := ""
if res.Error != nil {
errorMsg = res.Error.Error()
}
line := fmt.Sprintf("\"%s\",%t,%d,\"%s\"\n", res.FilePath, res.IsBlocked, res.StatusCode, errorMsg)
_, err := f.WriteString(line)
if err != nil {
return
}
}
}