diff --git a/worker/worker.go b/worker/worker.go index d11dd56f..8e7e22a5 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -3,6 +3,7 @@ package worker import ( "context" "fmt" + "os" "strings" "sync" "time" @@ -16,6 +17,14 @@ type Progress interface { Add(n int) error } +type TestResult struct { + FilePath string + IsBlocked bool + StatusCode int + Error error + ResponseDetails string // 根据需要添加更多字段 +} + type Worker struct { ctx context.Context cancel context.CancelFunc @@ -27,6 +36,7 @@ type Worker struct { jobResultDone chan struct{} result *Result progressBar Progress + results chan TestResult addr string // target addr isHttps bool // is https @@ -103,6 +113,7 @@ func NewWorker( addr: addr, isHttps: isHttps, timeout: 1000, // 1000ms + results: make(chan TestResult, len(fileList)), blockStatusCode: blockStatusCode, jobs: make(chan *Job), @@ -230,7 +241,10 @@ func (w *Worker) runWorker() { return } elap := time.Since(start).Nanoseconds() - (*conn).Close() + err = (*conn).Close() + if err != nil { + return + } job.Result.Success = true if strings.HasSuffix(job.FilePath, "white") { job.Result.IsWhite = true // white case @@ -242,11 +256,34 @@ func (w *Worker) runWorker() { job.Result.IsPass = true } job.Result.TimeCost = elap + + // 生成 TestResult 并发送到 results 通道 + testResult := TestResult{ + FilePath: filePath, + IsBlocked: code == w.blockStatusCode, + StatusCode: code, + Error: nil, + } + if job.Result.Err != "" { + testResult.Error = fmt.Errorf(job.Result.Err) + } + w.results <- testResult }() } } +func parseError(errStr string) error { + if errStr == "" { + return nil + } + return fmt.Errorf(errStr) +} + func (w *Worker) processJobResult() { + // 定义两个切片用于存储误报和漏报的测试用例 + var falsePositives []TestResult + var falseNegatives []TestResult + for job := range w.jobResult { if job.Result.Success { w.result.Success++ @@ -256,10 +293,24 @@ func (w *Worker) processJobResult() { w.result.TN++ } else { w.result.FP++ + // 记录误报的测试用例 + falsePositives = append(falsePositives, TestResult{ + FilePath: job.FilePath, + IsBlocked: true, + StatusCode: job.Result.StatusCode, + Error: parseError(job.Result.Err), + }) } } else { if job.Result.IsPass { w.result.FN++ + // 记录漏报的测试用例 + falseNegatives = append(falseNegatives, TestResult{ + FilePath: job.FilePath, + IsBlocked: false, + StatusCode: job.Result.StatusCode, + Error: parseError(job.Result.Err), + }) } else { w.result.TP++ } @@ -273,6 +324,10 @@ func (w *Worker) processJobResult() { w.resultCh <- &r } } + + // 在所有结果处理完毕后,写入误报和漏报的测试用例到CSV文件 + w.writeResults("false_positives.csv", falsePositives) + w.writeResults("false_negatives.csv", falseNegatives) } func (w *Worker) jobProducer() { @@ -295,10 +350,43 @@ func (w *Worker) jobProducer() { func (w *Worker) generateResult() string { sb := strings.Builder{} - sb.WriteString(fmt.Sprintf("总样本数量: %d 成功: %d 错误: %d\n", w.result.Total, w.result.Success, (w.result.Total - w.result.Success))) + sb.WriteString(fmt.Sprintf("总样本数量: %d 成功: %d 错误: %d\n", w.result.Total, w.result.Success, w.result.Total-w.result.Success)) sb.WriteString(fmt.Sprintf("检出率: %.2f%% (恶意样本总数: %d , 正确拦截: %d , 漏报放行: %d)\n", float64(w.result.TP)*100/float64(w.result.TP+w.result.FN), w.result.TP+w.result.FN, w.result.TP, w.result.FN)) sb.WriteString(fmt.Sprintf("误报率: %.2f%% (正常样本总数: %d , 正确放行: %d , 误报拦截: %d)\n", float64(w.result.FP)*100/float64(w.result.TN+w.result.FP), w.result.TN+w.result.FP, w.result.TN, w.result.FP)) sb.WriteString(fmt.Sprintf("准确率: %.2f%% (正确拦截 + 正确放行)/样本总数 \n", float64(w.result.TP+w.result.TN)*100/float64(w.result.TP+w.result.TN+w.result.FP+w.result.FN))) sb.WriteString(fmt.Sprintf("平均耗时: %.2f毫秒\n", float64(w.result.SuccessTimeCost)/float64(w.result.Success)/1000000)) return sb.String() } + +// 新增的写入CSV文件的函数 +func (w *Worker) writeResults(filename string, results []TestResult) { + f, err := os.Create(filename) + if err != nil { + fmt.Printf("无法创建 %s: %v\n", filename, err) + return + } + defer func(f *os.File) { + err := f.Close() + if err != nil { + + } + }(f) + + // 写入 CSV 头部 + _, err = f.WriteString("FilePath,IsBlocked,StatusCode,Error\n") + if err != nil { + return + } + + for _, res := range results { + errorMsg := "" + if res.Error != nil { + errorMsg = res.Error.Error() + } + line := fmt.Sprintf("\"%s\",%t,%d,\"%s\"\n", res.FilePath, res.IsBlocked, res.StatusCode, errorMsg) + _, err := f.WriteString(line) + if err != nil { + return + } + } +}