diff --git a/.github/workflows/issues.yml b/.github/workflows/issues.yml index 2eee1c7e..913a7df8 100644 --- a/.github/workflows/issues.yml +++ b/.github/workflows/issues.yml @@ -1,7 +1,7 @@ name: "Close Stale Issues" on: schedule: - - cron: "0 0 * * 3" + - cron: "0 0 * * 0" workflow_dispatch: jobs: @@ -20,7 +20,7 @@ jobs:

Спасибо! close-issue-message: "Issue был закрыт из-за отсутствия активности." - days-before-stale: 120 + days-before-stale: 30 days-before-close: 7 operations-per-run: 1000 ascending: true diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3013fe95..3e1bc29f 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -78,7 +78,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v6 with: - go-version: "1.25.5" + go-version: "1.26.2" # Install Android NDK only for android rows - name: Setup Android NDK diff --git a/.github/workflows/.golangci.yml b/.golangci.yml similarity index 88% rename from .github/workflows/.golangci.yml rename to .golangci.yml index db48597c..59006559 100644 --- a/.github/workflows/.golangci.yml +++ b/.golangci.yml @@ -1,11 +1,18 @@ version: "2" run: + modules-download-mode: readonly + relative-path-mode: gomod tests: true linters: enable: - bodyclose + - errcheck + - govet + - ineffassign - misspell - revive + - staticcheck + - unused settings: errcheck: check-type-assertions: true @@ -13,12 +20,12 @@ linters: exclude-functions: - (net.PacketConn).WriteTo - (net.Conn).Write + - (net.Conn).SetDeadline - encoding/json.MarshalIndent - (*github.com/pion/dtls/v3.Conn).SetDeadline govet: disable: - fieldalignment - enable-all: true revive: rules: - name: blank-imports @@ -52,10 +59,7 @@ linters: issues: max-issues-per-linter: 0 max-same-issues: 0 - exclude-rules: - - linters: - - errcheck - source: "doRequest|packetPool\\.Get" + formatters: exclusions: generated: lax diff --git a/client/main.go b/client/main.go index 71a86251..b7ec2a95 100644 --- a/client/main.go +++ b/client/main.go @@ -11,6 +11,7 @@ import ( "encoding/base64" "encoding/hex" "encoding/json" + "errors" "flag" "fmt" "io" @@ -21,18 +22,21 @@ import ( neturl "net/url" "os" "os/signal" + "path/filepath" "strconv" "strings" "sync" "sync/atomic" "syscall" "time" + "unicode" fhttp "github.com/bogdanfinn/fhttp" tlsclient "github.com/bogdanfinn/tls-client" "github.com/bogdanfinn/tls-client/profiles" "github.com/bschaatsbergen/dnsdialer" + "github.com/cacggghp/vk-turn-proxy/internal/cliutil" "github.com/cacggghp/vk-turn-proxy/tcputil" "github.com/cbeuw/connutil" "github.com/google/uuid" @@ -62,6 +66,7 @@ var ( activeLocalPeer atomic.Value globalCaptchaLockout atomic.Int64 connectedStreams atomic.Int32 + configuredStreams atomic.Int32 globalAppCancel context.CancelFunc handshakeSem = make(chan struct{}, 3) isDebug bool @@ -70,6 +75,7 @@ var ( ) type captchaSolveMode int +type captchaFailureCountContextKey struct{} const ( captchaSolveModeAuto captchaSolveMode = iota @@ -77,6 +83,62 @@ const ( captchaSolveModeManual ) +type clientOptions struct { + host string + port string + listen string + vklink string + yalink string + peerAddr string + n int + udp bool + direct bool + vlessMode bool + debug bool + manualCaptcha bool +} + +func newClientFlagSet(program string, output io.Writer) (*flag.FlagSet, *clientOptions) { + fs := flag.NewFlagSet(program, flag.ContinueOnError) + fs.SetOutput(output) + + opts := &clientOptions{} + fs.StringVar(&opts.host, "turn", "", "override TURN server ip") + fs.StringVar(&opts.port, "port", "", "override TURN port") + fs.StringVar(&opts.listen, "listen", "127.0.0.1:9000", "listen on ip:port") + fs.StringVar(&opts.vklink, "vk-link", "", "VK calls invite link \"https://vk.com/call/join/...\"") + fs.StringVar(&opts.yalink, "yandex-link", "", "Yandex Telemost invite link \"https://telemost.yandex.ru/j/...\"") + fs.StringVar(&opts.peerAddr, "peer", "", "peer server address (host:port)") + fs.IntVar(&opts.n, "n", 0, "connections to TURN (default 10 for VK, 1 for Yandex)") + fs.BoolVar(&opts.udp, "udp", false, "connect to TURN with UDP") + fs.BoolVar(&opts.direct, "no-dtls", false, "connect without obfuscation. DO NOT USE") + fs.BoolVar(&opts.vlessMode, "vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets") + fs.BoolVar(&opts.debug, "debug", false, "enable debug logging") + fs.BoolVar(&opts.manualCaptcha, "manual-captcha", false, "skip auto captcha solving, use manual mode immediately") + fs.Usage = func() { + cliutil.Fprintf(fs.Output(), "Usage:\n %s -peer -vk-link [flags]\n %s -peer -yandex-link [flags]\n\n", program, program) + cliutil.Fprintln(fs.Output(), "Examples:") + cliutil.Fprintf(fs.Output(), " %s -listen 127.0.0.1:9000 -peer 203.0.113.10:56000 -vk-link https://vk.com/call/join/...\n", program) + cliutil.Fprintf(fs.Output(), " %s -udp -turn 5.255.211.241 -peer 203.0.113.10:56000 -yandex-link https://telemost.yandex.ru/j/... -listen 127.0.0.1:9000\n\n", program) + cliutil.Fprintln(fs.Output(), "Flags:") + fs.PrintDefaults() + } + + return fs, opts +} + +func parseClientOptions(args []string, program string, stdout, stderr io.Writer) (clientOptions, int) { + return cliutil.Parse(args, program, stdout, stderr, newClientFlagSet, func(opts *clientOptions) error { + if opts.peerAddr == "" { + return fmt.Errorf("-peer is required") + } + if (opts.vklink == "") == (opts.yalink == "") { + return fmt.Errorf("exactly one of -vk-link or -yandex-link is required") + } + return nil + }) +} + func captchaSolveModeForAttempt(attempt int, manualOnly bool, enableSliderPOC bool) (captchaSolveMode, bool) { if manualOnly { return captchaSolveModeManual, attempt == 0 @@ -112,6 +174,35 @@ func captchaSolveModeLabel(mode captchaSolveMode) string { } } +func withCaptchaFailureCount(ctx context.Context, count int) context.Context { + if count <= 0 { + return ctx + } + return context.WithValue(ctx, captchaFailureCountContextKey{}, count) +} + +func captchaFailureCountFromContext(ctx context.Context) int { + if ctx == nil { + return 0 + } + count, ok := ctx.Value(captchaFailureCountContextKey{}).(int) + if !ok { + return 0 + } + return count +} + +func captchaLogAttempt(ctx context.Context, solveMode captchaSolveMode, attempt int) int { + displayAttempt := attempt + 1 + if solveMode == captchaSolveModeManual { + bucketAttempt := captchaFailureCountFromContext(ctx) + 1 + if bucketAttempt > displayAttempt { + displayAttempt = bucketAttempt + } + } + return displayAttempt +} + type UDPPacket struct { Data []byte N int @@ -611,24 +702,71 @@ type TurnCredentials struct { } type StreamCredentialsCache struct { - creds TurnCredentials - mutex sync.RWMutex - errorCount atomic.Int32 - lastErrorTime atomic.Int64 + creds TurnCredentials + mutex sync.RWMutex + errorCount atomic.Int32 + lastErrorTime atomic.Int64 + captchaFailures int + disabled bool + disableAnnounced bool + retryAfter time.Time + retryErr error + retryLink string } const ( credentialLifetime = 10 * time.Minute cacheSafetyMargin = 60 * time.Second maxCacheErrors = 3 + maxCaptchaFailures = 2 errorWindow = 10 * time.Second streamsPerCache = 10 + cacheRetryDelay = 60 * time.Second ) +var errCaptchaBucketDisabled = errors.New("CAPTCHA_BUCKET_DISABLED") + func getCacheID(streamID int) int { return streamID / streamsPerCache } +func bucketStreamCount(cacheID int, numStreams int) int { + start := cacheID * streamsPerCache + if start >= numStreams { + return 0 + } + end := start + streamsPerCache + if end > numStreams { + end = numStreams + } + return end - start +} + +func activeConfiguredStreamCount() int { + numStreams := int(configuredStreams.Load()) + if numStreams <= 0 { + return 0 + } + + disabledStreams := 0 + credentialsStore.mu.RLock() + defer credentialsStore.mu.RUnlock() + for cacheID, cache := range credentialsStore.caches { + cache.mutex.RLock() + disabled := cache.disabled + cache.mutex.RUnlock() + if disabled { + disabledStreams += bucketStreamCount(cacheID, numStreams) + } + } + + activeStreams := numStreams - disabledStreams + if activeStreams < 0 { + return 0 + } + return activeStreams +} + func vkDelayRandom(minMs, maxMs int) { ms := minMs + rand.Intn(maxMs-minMs+1) time.Sleep(time.Duration(ms) * time.Millisecond) @@ -702,6 +840,12 @@ func handleAuthError(streamID int) bool { func (c *StreamCredentialsCache) invalidate(streamID int) { c.mutex.Lock() c.creds = TurnCredentials{} + c.captchaFailures = 0 + c.disabled = false + c.disableAnnounced = false + c.retryAfter = time.Time{} + c.retryErr = nil + c.retryLink = "" c.mutex.Unlock() c.errorCount.Store(0) @@ -710,6 +854,33 @@ func (c *StreamCredentialsCache) invalidate(streamID int) { log.Printf("[STREAM %d] [VK Auth] Credentials cache invalidated", streamID) } +func isCacheBucketDisabled(streamID int) bool { + cache := getStreamCache(streamID) + cache.mutex.RLock() + defer cache.mutex.RUnlock() + return cache.disabled +} + +func announceDisabledBucket(streamID int) { + cache := getStreamCache(streamID) + cacheID := getCacheID(streamID) + + cache.mutex.Lock() + if cache.disableAnnounced { + cache.mutex.Unlock() + return + } + cache.disableAnnounced = true + cache.mutex.Unlock() + + log.Printf( + "[VK Auth] Cache bucket %d disabled after %d captcha failures. Continuing with %d streams.", + cacheID, + maxCaptchaFailures, + activeConfiguredStreamCount(), + ) +} + func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dnsdialer.Dialer) (string, string, string, error) { cache := getStreamCache(streamID) cacheID := getCacheID(streamID) @@ -733,19 +904,48 @@ func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dn if cache.creds.Link == link && time.Now().Before(cache.creds.ExpiresAt) { return cache.creds.Username, cache.creds.Password, cache.creds.ServerAddr, nil } + if cache.disabled { + return "", "", "", fmt.Errorf("%w: cache=%d", errCaptchaBucketDisabled, cacheID) + } + if cache.retryErr != nil && cache.retryLink == link && time.Now().Before(cache.retryAfter) { + if isDebug { + log.Printf("[STREAM %d] [VK Auth] Reusing cache=%d captcha cooldown (%v remaining)", streamID, cacheID, time.Until(cache.retryAfter).Truncate(time.Millisecond)) + } + return "", "", "", cache.retryErr + } - user, pass, addr, err := fetchVkCredsSerialized(ctx, link, streamID, dialer) + user, pass, addr, err := fetchVkCredsSerializedFunc(withCaptchaFailureCount(ctx, cache.captchaFailures), link, streamID, dialer) if err != nil { + if strings.Contains(err.Error(), "CAPTCHA_WAIT_REQUIRED") { + cache.captchaFailures++ + cache.retryAfter = time.Now().Add(cacheRetryDelay) + cache.retryErr = err + cache.retryLink = link + if cache.captchaFailures >= maxCaptchaFailures { + cache.disabled = true + cache.retryAfter = time.Time{} + cache.retryErr = nil + cache.retryLink = "" + return "", "", "", fmt.Errorf("%w: cache=%d", errCaptchaBucketDisabled, cacheID) + } + } return "", "", "", err } + cache.captchaFailures = 0 + cache.disabled = false + cache.disableAnnounced = false + cache.retryAfter = time.Time{} + cache.retryErr = nil + cache.retryLink = "" cache.creds = TurnCredentials{Username: user, Password: pass, ServerAddr: addr, ExpiresAt: time.Now().Add(credentialLifetime - cacheSafetyMargin), Link: link} return user, pass, addr, nil } var ( - vkRequestMu sync.Mutex - globalLastVkFetchTime time.Time + fetchVkCredsSerializedFunc = fetchVkCredsSerialized + vkRequestMu sync.Mutex + globalLastVkFetchTime time.Time ) func fetchVkCredsSerialized(ctx context.Context, link string, streamID int, dialer *dnsdialer.Dialer) (string, string, string, error) { @@ -958,40 +1158,23 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede log.Printf("[STREAM %d] [Captcha] Triggering manual captcha fallback...", streamID) manualCtx, manualCancel := context.WithTimeout(ctx, 60*time.Second) - type manualRes struct { - token string - key string - err error + if captchaErr.RedirectURI != "" { + successToken, solveErr = solveCaptchaViaProxy(manualCtx, captchaErr.RedirectURI, dialer) + } else if captchaErr.CaptchaImg != "" { + captchaKey, solveErr = solveCaptchaViaHTTP(manualCtx, captchaErr.CaptchaImg) + } else { + solveErr = fmt.Errorf("no redirect_uri or captcha_img") } - resCh := make(chan manualRes, 1) - - go func() { - var t, k string - var e error - if captchaErr.RedirectURI != "" { - t, e = solveCaptchaViaProxy(captchaErr.RedirectURI, dialer) - } else if captchaErr.CaptchaImg != "" { - k, e = solveCaptchaViaHTTP(captchaErr.CaptchaImg) - } else { - e = fmt.Errorf("no redirect_uri or captcha_img") - } - resCh <- manualRes{t, k, e} - }() - - select { - case res := <-resCh: - successToken = res.token - captchaKey = res.key - solveErr = res.err - case <-manualCtx.Done(): + deadlineExceeded := errors.Is(manualCtx.Err(), context.DeadlineExceeded) + manualCancel() + if solveErr != nil && deadlineExceeded { solveErr = fmt.Errorf("manual captcha timed out after 60s") } - manualCancel() } // If solving failed (auto or manual) or timed out if solveErr != nil { - log.Printf("[STREAM %d] [Captcha] %s failed (attempt %d): %v", streamID, captchaSolveModeLabel(solveMode), attempt+1, solveErr) + log.Printf("[STREAM %d] [Captcha] %s failed (attempt %d): %v", streamID, capitalizeFirstLetter(captchaSolveModeLabel(solveMode)), captchaLogAttempt(ctx, solveMode, attempt), solveErr) nextSolveMode, hasNextSolveMode := captchaSolveModeForAttempt(attempt+1, manualCaptcha, autoCaptchaSliderPOC) if hasNextSolveMode { @@ -1424,6 +1607,7 @@ func dtlsFunc(ctx context.Context, conn net.PacketConn, peer *net.UDPAddr) (net. } if err := dtlsConn.HandshakeContext(ctx1); err != nil { + _ = dtlsConn.Close() return nil, err } return dtlsConn, nil @@ -1436,6 +1620,15 @@ func oneDtlsConnection(ctx context.Context, peer *net.UDPAddr, listenConn net.Pa defer dtlscancel() conn1, conn2 := connutil.AsyncPacketPipe() + pipeConn := conn1 + defer func() { + if pipeConn == nil { + return + } + if closeErr := pipeConn.Close(); closeErr != nil { + log.Printf("[STREAM %d] Failed to close DTLS pipe: %s", streamID, closeErr) + } + }() go func() { for { select { @@ -1449,6 +1642,7 @@ func oneDtlsConnection(ctx context.Context, peer *net.UDPAddr, listenConn net.Pa if err1 != nil { return fmt.Errorf("failed to connect DTLS: %s", err1) } + pipeConn = nil defer func() { if closeErr := dtlsConn.Close(); closeErr != nil { log.Printf("[STREAM %d] failed to close DTLS connection: %s", streamID, closeErr) @@ -1713,6 +1907,10 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD func oneDtlsConnectionLoop(ctx context.Context, peer *net.UDPAddr, listenConn net.PacketConn, inboundChan <-chan *UDPPacket, connchan chan<- net.PacketConn, okchan chan<- struct{}, streamID int) { for { + if isCacheBucketDisabled(streamID) { + announceDisabledBucket(streamID) + return + } select { case <-ctx.Done(): return @@ -1747,6 +1945,10 @@ func oneTurnConnectionLoop(ctx context.Context, turnParams *turnParams, peer *ne go oneTurnConnection(ctx, turnParams, peer, conn2, streamID, c) if err := <-c; err != nil { + if strings.Contains(err.Error(), errCaptchaBucketDisabled.Error()) { + announceDisabledBucket(streamID) + return + } if strings.Contains(err.Error(), "FATAL_CAPTCHA") { log.Printf("[STREAM %d] Fatal manual captcha error. Shutting down application.", streamID) if globalAppCancel != nil { @@ -1784,6 +1986,11 @@ func oneTurnConnectionLoop(ctx context.Context, turnParams *turnParams, peer *ne } func main() { + opts, exitCode := parseClientOptions(os.Args[1:], filepath.Base(os.Args[0]), os.Stdout, os.Stderr) + if exitCode != cliutil.ContinueExecution { + os.Exit(exitCode) + } + ctx, cancel := context.WithCancel(context.Background()) globalAppCancel = cancel defer cancel() @@ -1800,38 +2007,19 @@ func main() { log.Fatalf("Exit...\n") }() - host := flag.String("turn", "", "override TURN server ip") - port := flag.String("port", "", "override TURN port") - listen := flag.String("listen", "127.0.0.1:9000", "listen on ip:port") - vklink := flag.String("vk-link", "", "VK calls invite link \"https://vk.com/call/join/...\"") - yalink := flag.String("yandex-link", "", "Yandex telemost invite link \"https://telemost.yandex.ru/j/...\"") - peerAddr := flag.String("peer", "", "peer server address (host:port)") - n := flag.Int("n", 0, "connections to TURN (default 10 for VK, 1 for Yandex)") - udp := flag.Bool("udp", false, "connect to TURN with UDP") - direct := flag.Bool("no-dtls", false, "connect without obfuscation. DO NOT USE") - vlessMode := flag.Bool("vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets") - debugFlag := flag.Bool("debug", false, "enable debug logging") - manualCaptchaFlag := flag.Bool("manual-captcha", false, "skip auto captcha solving, use manual mode immediately") - flag.Parse() - if *peerAddr == "" { - log.Panicf("Need peer address!") - } - peer, err := net.ResolveUDPAddr("udp", *peerAddr) + peer, err := net.ResolveUDPAddr("udp", opts.peerAddr) if err != nil { panic(err) } - if (*vklink == "") == (*yalink == "") { - log.Panicf("Need either vk-link or yandex-link!") - } - isDebug = *debugFlag - manualCaptcha = *manualCaptchaFlag + isDebug = opts.debug + manualCaptcha = opts.manualCaptcha autoCaptchaSliderPOC = !manualCaptcha var link string var getCreds getCredsFunc - if *vklink != "" { - parts := strings.Split(*vklink, "join/") + if opts.vklink != "" { + parts := strings.Split(opts.vklink, "join/") link = parts[len(parts)-1] dialer := dnsdialer.New( @@ -1843,17 +2031,17 @@ func main() { getCreds = func(ctx context.Context, s string, streamID int) (string, string, string, error) { return getVkCredsCached(ctx, s, streamID, dialer) } - if *n <= 0 { - *n = 10 + if opts.n <= 0 { + opts.n = 10 } } else { - parts := strings.Split(*yalink, "j/") + parts := strings.Split(opts.yalink, "j/") link = parts[len(parts)-1] getCreds = func(ctx context.Context, s string, streamID int) (string, string, string, error) { return getYandexCreds(s) } - if *n <= 0 { - *n = 1 + if opts.n <= 0 { + opts.n = 1 } } if idx := strings.IndexAny(link, "/?#"); idx != -1 { @@ -1861,19 +2049,19 @@ func main() { } params := &turnParams{ - host: *host, - port: *port, + host: opts.host, + port: opts.port, link: link, - udp: *udp, + udp: opts.udp, getCreds: getCreds, } - if *vlessMode { - runVLESSMode(ctx, params, peer, *listen, *n) + if opts.vlessMode { + runVLESSMode(ctx, params, peer, opts.listen, opts.n) return } - listenConn, err := net.ListenPacket("udp", *listen) + listenConn, err := net.ListenPacket("udp", opts.listen) if err != nil { log.Panicf("Failed to listen: %s", err) } @@ -1883,10 +2071,11 @@ func main() { } }) - numStreams := *n + numStreams := opts.n if numStreams <= 0 { numStreams = 1 } + configuredStreams.Store(int32(numStreams)) // Shared Worker Pool Queue for Aggregation inboundChan := make(chan *UDPPacket, 2000) @@ -1930,21 +2119,24 @@ func main() { wg1 := sync.WaitGroup{} t := time.Tick(200 * time.Millisecond) - if *direct { + if opts.direct { log.Panicf("Direct mode not supported with dispatcher") } okchan := make(chan struct{}) connchan := make(chan net.PacketConn) + streamCtx, streamCancel := context.WithCancel(ctx) wg1.Add(1) go func() { defer wg1.Done() - oneDtlsConnectionLoop(ctx, peer, listenConn, inboundChan, connchan, okchan, 1) + defer streamCancel() + oneDtlsConnectionLoop(streamCtx, peer, listenConn, inboundChan, connchan, okchan, 1) }() wg1.Add(1) go func() { defer wg1.Done() - oneTurnConnectionLoop(ctx, params, peer, connchan, t, 1) + defer streamCancel() + oneTurnConnectionLoop(streamCtx, params, peer, connchan, t, 1) }() select { @@ -1954,15 +2146,18 @@ func main() { for i := 1; i < numStreams; i++ { cchan := make(chan net.PacketConn) + streamCtx, streamCancel := context.WithCancel(ctx) wg1.Add(1) go func(streamID int) { defer wg1.Done() - oneDtlsConnectionLoop(ctx, peer, listenConn, inboundChan, cchan, nil, streamID) + defer streamCancel() + oneDtlsConnectionLoop(streamCtx, peer, listenConn, inboundChan, cchan, nil, streamID) }(i) wg1.Add(1) go func(streamID int) { defer wg1.Done() - oneTurnConnectionLoop(ctx, params, peer, cchan, t, streamID) + defer streamCancel() + oneTurnConnectionLoop(streamCtx, params, peer, cchan, t, streamID) }(i) } @@ -2246,16 +2441,19 @@ func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, i cleanup() return nil, nil, fmt.Errorf("DTLS handshake: %w", err) } - cleanupFns = append(cleanupFns, func() { _ = dtlsConn.Close() }) log.Printf("DTLS connection established") // 5. Create KCP session over DTLS - kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, false) + kcpSess, cleanupKCP, err := tcputil.NewKCPOverDTLS(dtlsConn, false) if err != nil { cleanup() return nil, nil, fmt.Errorf("KCP session: %w", err) } - cleanupFns = append(cleanupFns, func() { _ = kcpSess.Close() }) + cleanupFns = append(cleanupFns, func() { + if err := cleanupKCP(); err != nil { + log.Printf("KCP cleanup error: %v", err) + } + }) log.Printf("KCP session established") // 6. Create smux client session over KCP @@ -2326,3 +2524,10 @@ func pipe(ctx context.Context, c1, c2 net.Conn) { log.Printf("pipe: failed to reset deadline c2: %v", err) } } + +func capitalizeFirstLetter(s string) string { + for i, v := range s { + return string(unicode.ToUpper(v)) + s[i+len(string(v)):] + } + return "" +} diff --git a/client/main_test.go b/client/main_test.go index e0f2d779..9eb72cd4 100644 --- a/client/main_test.go +++ b/client/main_test.go @@ -1,6 +1,95 @@ package main -import "testing" +import ( + "bytes" + "context" + "fmt" + "strings" + "sync" + "testing" + "time" + + "github.com/bschaatsbergen/dnsdialer" + "github.com/cacggghp/vk-turn-proxy/internal/cliutil" +) + +func TestParseClientOptionsShowsUsageWithoutArgs(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + var stderr bytes.Buffer + + _, exitCode := parseClientOptions(nil, "client", &stdout, &stderr) + if exitCode != 0 { + t.Fatalf("parseClientOptions() exitCode = %d, want 0", exitCode) + } + if stderr.Len() != 0 { + t.Fatalf("expected no stderr output, got %q", stderr.String()) + } + if got := stdout.String(); !strings.Contains(got, "Usage:\n client -peer -vk-link [flags]") { + t.Fatalf("usage output missing client help text: %q", got) + } +} + +func TestParseClientOptionsShowsHelpFlagUsage(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + var stderr bytes.Buffer + + _, exitCode := parseClientOptions([]string{"-help"}, "client", &stdout, &stderr) + if exitCode != 0 { + t.Fatalf("parseClientOptions() exitCode = %d, want 0", exitCode) + } + if stderr.Len() != 0 { + t.Fatalf("expected no stderr output, got %q", stderr.String()) + } + if got := stdout.String(); !strings.Contains(got, "Examples:") { + t.Fatalf("expected help examples in output, got %q", got) + } +} + +func TestParseClientOptionsRequiresPeer(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + var stderr bytes.Buffer + + _, exitCode := parseClientOptions([]string{"-vk-link", "https://vk.com/call/join/test"}, "client", &stdout, &stderr) + if exitCode != 2 { + t.Fatalf("parseClientOptions() exitCode = %d, want 2", exitCode) + } + if stdout.Len() != 0 { + t.Fatalf("expected no stdout output, got %q", stdout.String()) + } + if got := stderr.String(); !strings.Contains(got, "error: -peer is required") { + t.Fatalf("expected missing peer error, got %q", got) + } +} + +func TestParseClientOptionsParsesValidVKArgs(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + var stderr bytes.Buffer + + opts, exitCode := parseClientOptions([]string{"-peer", "127.0.0.1:56000", "-vk-link", "https://vk.com/call/join/test", "-listen", "127.0.0.1:9001"}, "client", &stdout, &stderr) + if exitCode != cliutil.ContinueExecution { + t.Fatalf("parseClientOptions() exitCode = %d, want %d", exitCode, cliutil.ContinueExecution) + } + if stderr.Len() != 0 { + t.Fatalf("expected no stderr output, got %q", stderr.String()) + } + if opts.peerAddr != "127.0.0.1:56000" { + t.Fatalf("peerAddr = %q, want 127.0.0.1:56000", opts.peerAddr) + } + if opts.vklink != "https://vk.com/call/join/test" { + t.Fatalf("vklink = %q, want VK link", opts.vklink) + } + if opts.listen != "127.0.0.1:9001" { + t.Fatalf("listen = %q, want 127.0.0.1:9001", opts.listen) + } +} func TestCaptchaSolveModeForAttempt(t *testing.T) { t.Parallel() @@ -59,3 +148,105 @@ func TestCaptchaSolveModeForAttempt(t *testing.T) { } }) } + +func TestCaptchaLogAttempt(t *testing.T) { + t.Parallel() + + if got := captchaLogAttempt(context.Background(), captchaSolveModeManual, 0); got != 1 { + t.Fatalf("captchaLogAttempt() = %d, want 1 for first manual attempt", got) + } + + ctx := withCaptchaFailureCount(context.Background(), 1) + if got := captchaLogAttempt(ctx, captchaSolveModeManual, 0); got != 2 { + t.Fatalf("captchaLogAttempt() = %d, want 2 for second bucket manual attempt", got) + } + + if got := captchaLogAttempt(ctx, captchaSolveModeManual, 2); got != 3 { + t.Fatalf("captchaLogAttempt() = %d, want 3 when in-request attempt count is already higher", got) + } + + if got := captchaLogAttempt(ctx, captchaSolveModeAuto, 0); got != 1 { + t.Fatalf("captchaLogAttempt() = %d, want 1 for auto attempt count", got) + } +} + +func TestGetVkCredsCachedDisablesBucketAfterTwoCaptchaFailures(t *testing.T) { + credentialsStore.mu.Lock() + previousCaches := credentialsStore.caches + credentialsStore.caches = make(map[int]*StreamCredentialsCache) + credentialsStore.mu.Unlock() + defer func() { + credentialsStore.mu.Lock() + credentialsStore.caches = previousCaches + credentialsStore.mu.Unlock() + }() + + previousFetch := fetchVkCredsSerializedFunc + defer func() { + fetchVkCredsSerializedFunc = previousFetch + }() + previousConfiguredStreams := configuredStreams.Load() + configuredStreams.Store(20) + defer configuredStreams.Store(previousConfiguredStreams) + + var ( + mu sync.Mutex + callCount int + ) + fetchVkCredsSerializedFunc = func(ctx context.Context, link string, streamID int, dialer *dnsdialer.Dialer) (string, string, string, error) { + mu.Lock() + defer mu.Unlock() + callCount++ + switch callCount { + case 1: + return "", "", "", fmt.Errorf("CAPTCHA_WAIT_REQUIRED") + case 2: + return "", "", "", fmt.Errorf("CAPTCHA_WAIT_REQUIRED") + default: + return "", "", "", fmt.Errorf("unexpected extra fetch for stream %d", streamID) + } + } + + _, _, _, err := getVkCredsCached(context.Background(), "link", 10, nil) + if err == nil || err.Error() != "CAPTCHA_WAIT_REQUIRED" { + t.Fatalf("first getVkCredsCached() error = %v, want CAPTCHA_WAIT_REQUIRED", err) + } + + _, _, _, err = getVkCredsCached(context.Background(), "link", 17, nil) + if err == nil || err.Error() != "CAPTCHA_WAIT_REQUIRED" { + t.Fatalf("second getVkCredsCached() error = %v, want shared CAPTCHA_WAIT_REQUIRED", err) + } + + mu.Lock() + if callCount != 1 { + mu.Unlock() + t.Fatalf("expected one fetch attempt during shared captcha cooldown, got %d", callCount) + } + mu.Unlock() + + cache := getStreamCache(10) + cache.mutex.Lock() + cache.retryAfter = time.Now().Add(-time.Second) + cache.mutex.Unlock() + + _, _, _, err = getVkCredsCached(context.Background(), "link", 19, nil) + if err == nil || !strings.Contains(err.Error(), errCaptchaBucketDisabled.Error()) { + t.Fatalf("third getVkCredsCached() error = %v, want bucket disabled", err) + } + + _, _, _, err = getVkCredsCached(context.Background(), "link", 11, nil) + if err == nil || !strings.Contains(err.Error(), errCaptchaBucketDisabled.Error()) { + t.Fatalf("fourth getVkCredsCached() error = %v, want bucket disabled without new fetch", err) + } + + mu.Lock() + if callCount != 2 { + mu.Unlock() + t.Fatalf("expected exactly two fetch attempts before bucket disable, got %d", callCount) + } + mu.Unlock() + + if got := activeConfiguredStreamCount(); got != 10 { + t.Fatalf("activeConfiguredStreamCount() = %d, want 10 after disabling second bucket", got) + } +} diff --git a/client/manual_captcha.go b/client/manual_captcha.go index 27f958d3..5fa5ae50 100644 --- a/client/manual_captcha.go +++ b/client/manual_captcha.go @@ -23,6 +23,8 @@ import ( const captchaListenPort = "8765" +var openBrowserFunc = openBrowser + type browserCommand struct { name string args []string @@ -374,13 +376,20 @@ func startCaptchaServer(srv *http.Server, logPrefix string) error { return fmt.Errorf("captcha listeners failed: %s", strings.Join(listenErrs, "; ")) } -// runCaptchaServerAndWait triggers the browser, and waiting gracefully for the solution token. -func runCaptchaServerAndWait(handler http.Handler, captchaURL string, keyCh <-chan string, logPrefix string) (string, error) { +// runCaptchaServerAndWait triggers the browser and waits for either a solution token or cancellation. +func runCaptchaServerAndWait(ctx context.Context, handler http.Handler, captchaURL string, keyCh <-chan string, logPrefix string) (string, error) { srv := &http.Server{Handler: handler} if err := startCaptchaServer(srv, logPrefix); err != nil { return "", err } + defer func() { + shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Printf("%s shutdown error: %v", logPrefix, err) + } + }() fmt.Println("\n==============================================") fmt.Println("ACTION REQUIRED: MANUAL CAPTCHA SOLVING NEEDED") @@ -388,17 +397,14 @@ func runCaptchaServerAndWait(handler http.Handler, captchaURL string, keyCh <-ch fmt.Println("==============================================") fmt.Println() - openBrowser(captchaURL) - - key := <-keyCh + openBrowserFunc(captchaURL) - ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) - defer cancel() - if err := srv.Shutdown(ctx); err != nil { - return "", err + select { + case key := <-keyCh: + return key, nil + case <-ctx.Done(): + return "", ctx.Err() } - - return key, nil } // notifyKey pushes the key string to the given channel without blocking @@ -411,7 +417,7 @@ func notifyKey(keyCh chan<- string, key string) { } } -func solveCaptchaViaHTTP(captchaImg string) (string, error) { +func solveCaptchaViaHTTP(ctx context.Context, captchaImg string) (string, error) { keyCh := make(chan string, 1) mux := http.NewServeMux() @@ -439,10 +445,10 @@ button{font-size:24px;padding:12px 32px;margin-top:12px;cursor:pointer} _, _ = fmt.Fprint(w, `

Done!

`) }) - return runCaptchaServerAndWait(mux, localCaptchaOrigin(), keyCh, "captcha HTTP server error") + return runCaptchaServerAndWait(ctx, mux, localCaptchaOrigin(), keyCh, "captcha HTTP server error") } -func solveCaptchaViaProxy(redirectURI string, dialer *dnsdialer.Dialer) (string, error) { +func solveCaptchaViaProxy(ctx context.Context, redirectURI string, dialer *dnsdialer.Dialer) (string, error) { keyCh := make(chan string, 1) targetURL, err := neturl.Parse(redirectURI) @@ -566,7 +572,7 @@ func solveCaptchaViaProxy(redirectURI string, dialer *dnsdialer.Dialer) (string, proxy.ServeHTTP(w, r) }) - return runCaptchaServerAndWait(mux, localCaptchaURLForTarget(targetURL), keyCh, "proxy HTTP server error") + return runCaptchaServerAndWait(ctx, mux, localCaptchaURLForTarget(targetURL), keyCh, "proxy HTTP server error") } func openBrowser(url string) { diff --git a/client/manual_captcha_test.go b/client/manual_captcha_test.go index 8afbafdb..9294d731 100644 --- a/client/manual_captcha_test.go +++ b/client/manual_captcha_test.go @@ -1,8 +1,13 @@ package main import ( + "context" + "errors" + "net" + "net/http" "net/url" "testing" + "time" ) func TestRewriteProxyRedirectLocation(t *testing.T) { @@ -63,3 +68,75 @@ func TestRewriteProxyRedirectLocation(t *testing.T) { }) } } + +func TestRunCaptchaServerAndWaitStopsOnContextCancel(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:"+captchaListenPort) + if err != nil { + t.Skipf("captcha listener test requires a free localhost port: %v", err) + } + if err := listener.Close(); err != nil { + t.Fatalf("failed to release preflight listener: %v", err) + } + + previousOpenBrowser := openBrowserFunc + openBrowserFunc = func(string) {} + defer func() { + openBrowserFunc = previousOpenBrowser + }() + + ctx, cancel := context.WithCancel(context.Background()) + errCh := make(chan error, 1) + + go func() { + _, err := runCaptchaServerAndWait( + ctx, + http.NewServeMux(), + localCaptchaOrigin(), + make(chan string), + "test captcha server", + ) + errCh <- err + }() + + deadline := time.Now().Add(2 * time.Second) + for { + conn, dialErr := net.DialTimeout("tcp", "127.0.0.1:"+captchaListenPort, 50*time.Millisecond) + if dialErr == nil { + if err := conn.Close(); err != nil { + t.Fatalf("failed to close probe connection: %v", err) + } + break + } + if time.Now().After(deadline) { + cancel() + t.Fatalf("captcha server did not start listening: %v", dialErr) + } + time.Sleep(20 * time.Millisecond) + } + + cancel() + + select { + case err := <-errCh: + if !errors.Is(err, context.Canceled) { + t.Fatalf("runCaptchaServerAndWait() error = %v, want context canceled", err) + } + case <-time.After(3 * time.Second): + t.Fatal("runCaptchaServerAndWait() did not return after context cancellation") + } + + deadline = time.Now().Add(2 * time.Second) + for { + listener, err = net.Listen("tcp", "127.0.0.1:"+captchaListenPort) + if err == nil { + if closeErr := listener.Close(); closeErr != nil { + t.Fatalf("failed to close verification listener: %v", closeErr) + } + return + } + if time.Now().After(deadline) { + t.Fatalf("captcha listener was not released after cancellation: %v", err) + } + time.Sleep(20 * time.Millisecond) + } +} diff --git a/client/slider_captcha.go b/client/slider_captcha.go index 166a6abd..1cadd515 100644 --- a/client/slider_captcha.go +++ b/client/slider_captcha.go @@ -144,7 +144,7 @@ func (s *captchaNotRobotSession) requestComponentDone() error { respObj, ok := resp["response"].(map[string]interface{}) if ok { - if status, _ := respObj["status"].(string); status != "" && status != "OK" { + if status, ok := respObj["status"].(string); ok && status != "" && status != "OK" { return fmt.Errorf("componentDone status: %s", status) } } @@ -316,7 +316,9 @@ func parseCaptchaSettingsResponse(resp map[string]interface{}) (*captchaSettings settings := &captchaSettingsResponse{ SettingsByType: make(map[string]string), } - settings.ShowCaptchaType, _ = respObj["show_captcha_type"].(string) + if showCaptchaType, ok := respObj["show_captcha_type"].(string); ok { + settings.ShowCaptchaType = showCaptchaType + } rawSettings, ok := expandCaptchaSettings(respObj["captcha_settings"]) if !ok { @@ -329,8 +331,8 @@ func parseCaptchaSettingsResponse(resp map[string]interface{}) (*captchaSettings continue } - captchaType, _ := item["type"].(string) - if captchaType == "" { + captchaType, ok := item["type"].(string) + if !ok || captchaType == "" { continue } @@ -495,9 +497,15 @@ func parseCaptchaCheckResult(resp map[string]interface{}) (*captchaCheckResult, } result := &captchaCheckResult{} - result.Status, _ = respObj["status"].(string) - result.SuccessToken, _ = respObj["success_token"].(string) - result.ShowCaptchaType, _ = respObj["show_captcha_type"].(string) + if status, ok := respObj["status"].(string); ok { + result.Status = status + } + if successToken, ok := respObj["success_token"].(string); ok { + result.SuccessToken = successToken + } + if showCaptchaType, ok := respObj["show_captcha_type"].(string); ok { + result.ShowCaptchaType = showCaptchaType + } if result.Status == "" { return nil, fmt.Errorf("check status missing: %v", resp) } @@ -511,19 +519,22 @@ func parseSliderCaptchaContentResponse(resp map[string]interface{}) (*sliderCapt return nil, fmt.Errorf("invalid slider content response: %v", resp) } - status, _ := respObj["status"].(string) - if status != "OK" { + status, ok := respObj["status"].(string) + if !ok || status != "OK" { return nil, fmt.Errorf("slider getContent status: %s", status) } - extension, _ := respObj["extension"].(string) + extension, ok := respObj["extension"].(string) + if !ok { + return nil, fmt.Errorf("unsupported slider image format: %v", respObj["extension"]) + } extension = strings.ToLower(extension) if extension != "jpeg" && extension != "jpg" { return nil, fmt.Errorf("unsupported slider image format: %s", extension) } - rawImage, _ := respObj["image"].(string) - if rawImage == "" { + rawImage, ok := respObj["image"].(string) + if !ok || rawImage == "" { return nil, fmt.Errorf("slider image missing") } diff --git a/internal/cliutil/cliutil.go b/internal/cliutil/cliutil.go new file mode 100644 index 00000000..bb5282be --- /dev/null +++ b/internal/cliutil/cliutil.go @@ -0,0 +1,71 @@ +package cliutil + +import ( + "errors" + "flag" + "fmt" + "io" +) + +const ContinueExecution = -1 + +type Builder[T any] func(program string, output io.Writer) (*flag.FlagSet, *T) + +type Validator[T any] func(*T) error + +func Parse[T any]( + args []string, + program string, + stdout io.Writer, + stderr io.Writer, + build Builder[T], + validate Validator[T], +) (T, int) { + var zero T + + if len(args) == 0 { + fs, _ := build(program, stdout) + fs.Usage() + return zero, 0 + } + + output := stderr + if hasHelpFlag(args) { + output = stdout + } + + fs, opts := build(program, output) + if err := fs.Parse(args); err != nil { + if errors.Is(err, flag.ErrHelp) { + return zero, 0 + } + return zero, 2 + } + + if validate != nil { + if err := validate(opts); err != nil { + Fprintln(stderr, "error:", err) + fs.Usage() + return zero, 2 + } + } + + return *opts, ContinueExecution +} + +func Fprintf(w io.Writer, format string, args ...any) { + _, _ = fmt.Fprintf(w, format, args...) +} + +func Fprintln(w io.Writer, args ...any) { + _, _ = fmt.Fprintln(w, args...) +} + +func hasHelpFlag(args []string) bool { + for _, arg := range args { + if arg == "-h" || arg == "-help" || arg == "--help" { + return true + } + } + return false +} diff --git a/server/main.go b/server/main.go index cdf91d43..e02c3a42 100644 --- a/server/main.go +++ b/server/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "errors" "flag" "fmt" "io" @@ -9,21 +10,58 @@ import ( "net" "os" "os/signal" + "path/filepath" "sync" "syscall" "time" + "github.com/cacggghp/vk-turn-proxy/internal/cliutil" "github.com/cacggghp/vk-turn-proxy/tcputil" "github.com/pion/dtls/v3" "github.com/pion/dtls/v3/pkg/crypto/selfsign" "github.com/xtaci/smux" ) +type serverOptions struct { + listen string + connect string + vlessMode bool +} + +func newServerFlagSet(program string, output io.Writer) (*flag.FlagSet, *serverOptions) { + fs := flag.NewFlagSet(program, flag.ContinueOnError) + fs.SetOutput(output) + + opts := &serverOptions{} + fs.StringVar(&opts.listen, "listen", "0.0.0.0:56000", "listen on ip:port") + fs.StringVar(&opts.connect, "connect", "", "connect to ip:port") + fs.BoolVar(&opts.vlessMode, "vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets") + fs.Usage = func() { + cliutil.Fprintf(fs.Output(), "Usage:\n %s -connect [flags]\n\n", program) + cliutil.Fprintln(fs.Output(), "Examples:") + cliutil.Fprintf(fs.Output(), " %s -connect 127.0.0.1:51820\n", program) + cliutil.Fprintf(fs.Output(), " %s -listen 0.0.0.0:56000 -connect 127.0.0.1:51820 -vless\n\n", program) + cliutil.Fprintln(fs.Output(), "Flags:") + fs.PrintDefaults() + } + + return fs, opts +} + +func parseServerOptions(args []string, program string, stdout, stderr io.Writer) (serverOptions, int) { + return cliutil.Parse(args, program, stdout, stderr, newServerFlagSet, func(opts *serverOptions) error { + if opts.connect == "" { + return fmt.Errorf("-connect is required") + } + return nil + }) +} + func main() { - listen := flag.String("listen", "0.0.0.0:56000", "listen on ip:port") - connect := flag.String("connect", "", "connect to ip:port") - vlessMode := flag.Bool("vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets") - flag.Parse() + opts, exitCode := parseServerOptions(os.Args[1:], filepath.Base(os.Args[0]), os.Stdout, os.Stderr) + if exitCode != cliutil.ContinueExecution { + os.Exit(exitCode) + } ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -37,23 +75,16 @@ func main() { log.Fatalf("Exit...\n") }() - addr, err := net.ResolveUDPAddr("udp", *listen) + addr, err := net.ResolveUDPAddr("udp", opts.listen) if err != nil { panic(err) } - if len(*connect) == 0 { - log.Panicf("server address is required") - } // Generate a certificate and private key to secure the connection certificate, genErr := selfsign.GenerateSelfSigned() if genErr != nil { panic(genErr) } - // - // Everything below is the pion-DTLS API! Thanks for using it ❤️. - // - // Connect to a DTLS server listener, err := dtls.ListenWithOptions( "udp", @@ -72,7 +103,7 @@ func main() { } }) - fmt.Println("Listening") + cliutil.Fprintln(os.Stdout, "Listening") wg1 := sync.WaitGroup{} for { @@ -91,7 +122,11 @@ func main() { wg1.Add(1) go func(conn net.Conn) { defer wg1.Done() + closeConn := true defer func() { + if !closeConn { + return + } if closeErr := conn.Close(); closeErr != nil { log.Printf("failed to close incoming connection: %s", closeErr) } @@ -114,10 +149,11 @@ func main() { } log.Println("Handshake done") - if *vlessMode { - handleVLESSConnection(ctx, dtlsConn, *connect) + if opts.vlessMode { + closeConn = false + handleVLESSConnection(ctx, dtlsConn, opts.connect) } else { - handleUDPConnection(ctx, conn, *connect) + handleUDPConnection(ctx, conn, opts.connect) } log.Printf("Connection closed: %s\n", conn.RemoteAddr()) @@ -149,83 +185,24 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string) log.Printf("failed to set outgoing deadline: %s", err) } }) - go func() { - defer wg.Done() - defer cancel2() - buf := make([]byte, 1600) - for { - select { - case <-ctx2.Done(): - return - default: - } - if err1 := conn.SetReadDeadline(time.Now().Add(time.Minute * 30)); err1 != nil { - log.Printf("Failed: %s", err1) - return - } - n, err1 := conn.Read(buf) - if err1 != nil { - log.Printf("Failed: %s", err1) - return - } - - if err1 = serverConn.SetWriteDeadline(time.Now().Add(time.Minute * 30)); err1 != nil { - log.Printf("Failed: %s", err1) - return - } - _, err1 = serverConn.Write(buf[:n]) - if err1 != nil { - log.Printf("Failed: %s", err1) - return - } - } - }() - go func() { - defer wg.Done() - defer cancel2() - buf := make([]byte, 1600) - for { - select { - case <-ctx2.Done(): - return - default: - } - if err1 := serverConn.SetReadDeadline(time.Now().Add(time.Minute * 30)); err1 != nil { - log.Printf("Failed: %s", err1) - return - } - n, err1 := serverConn.Read(buf) - if err1 != nil { - log.Printf("Failed: %s", err1) - return - } - - if err1 = conn.SetWriteDeadline(time.Now().Add(time.Minute * 30)); err1 != nil { - log.Printf("Failed: %s", err1) - return - } - _, err1 = conn.Write(buf[:n]) - if err1 != nil { - log.Printf("Failed: %s", err1) - return - } - } - }() + startPacketForwarder(ctx2, &wg, cancel2, conn, serverConn) + startPacketForwarder(ctx2, &wg, cancel2, serverConn, conn) wg.Wait() } // handleVLESSConnection creates a KCP+smux session over DTLS and forwards // each smux stream as a TCP connection to the backend (Xray/VLESS). +// It takes ownership of dtlsConn and closes it through the KCP cleanup path. func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr string) { // 1. Create KCP session over DTLS - kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, true) + kcpSess, cleanupKCP, err := tcputil.NewKCPOverDTLS(dtlsConn, true) if err != nil { log.Printf("KCP session error: %s", err) return } defer func() { - if err := kcpSess.Close(); err != nil { - log.Printf("failed to close KCP session: %v", err) + if err := cleanupKCP(); err != nil { + log.Printf("failed to close KCP-over-DTLS transport: %v", err) } }() log.Printf("KCP session established (server)") @@ -261,7 +238,7 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s defer wg.Done() defer func() { - if err := s.Close(); err != nil && err != smux.ErrGoAway { + if err := s.Close(); err != nil && !errors.Is(err, smux.ErrGoAway) { log.Printf("failed to close smux stream: %v", err) } }() @@ -301,24 +278,58 @@ func pipeConn(ctx context.Context, c1, c2 net.Conn) { var wg sync.WaitGroup wg.Add(2) + startStreamCopy(&wg, cancel, c1, c2, "pipeConn: c1<-c2") + startStreamCopy(&wg, cancel, c2, c1, "pipeConn: c2<-c1") + + wg.Wait() + // Reset deadlines + _ = c1.SetDeadline(time.Time{}) + _ = c2.SetDeadline(time.Time{}) +} + +func startPacketForwarder(ctx context.Context, wg *sync.WaitGroup, cancel context.CancelFunc, src, dst net.Conn) { go func() { defer wg.Done() - if _, err := io.Copy(c1, c2); err != nil { - log.Printf("pipeConn: c1<-c2 copy error: %v", err) + defer cancel() + + buf := make([]byte, 1600) + for { + select { + case <-ctx.Done(): + return + default: + } + + if err := src.SetReadDeadline(time.Now().Add(30 * time.Minute)); err != nil { + log.Printf("Failed: %s", err) + return + } + n, err := src.Read(buf) + if err != nil { + log.Printf("Failed: %s", err) + return + } + + if err = dst.SetWriteDeadline(time.Now().Add(30 * time.Minute)); err != nil { + log.Printf("Failed: %s", err) + return + } + if _, err = dst.Write(buf[:n]); err != nil { + log.Printf("Failed: %s", err) + return + } } }() +} +func startStreamCopy(wg *sync.WaitGroup, cancel context.CancelFunc, dst, src net.Conn, label string) { go func() { defer wg.Done() - if _, err := io.Copy(c2, c1); err != nil { - log.Printf("pipeConn: c2<-c1 copy error: %v", err) + defer cancel() + + if _, err := io.Copy(dst, src); err != nil { + log.Printf("%s copy error: %v", label, err) } }() - - wg.Wait() - - // Reset deadlines - _ = c1.SetDeadline(time.Time{}) - _ = c2.SetDeadline(time.Time{}) } diff --git a/server/main_test.go b/server/main_test.go new file mode 100644 index 00000000..01e87074 --- /dev/null +++ b/server/main_test.go @@ -0,0 +1,87 @@ +package main + +import ( + "bytes" + "strings" + "testing" + + "github.com/cacggghp/vk-turn-proxy/internal/cliutil" +) + +func TestParseServerOptionsShowsUsageWithoutArgs(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + var stderr bytes.Buffer + + _, exitCode := parseServerOptions(nil, "server", &stdout, &stderr) + if exitCode != 0 { + t.Fatalf("parseServerOptions() exitCode = %d, want 0", exitCode) + } + if stderr.Len() != 0 { + t.Fatalf("expected no stderr output, got %q", stderr.String()) + } + if got := stdout.String(); !strings.Contains(got, "Usage:\n server -connect [flags]") { + t.Fatalf("usage output missing server help text: %q", got) + } +} + +func TestParseServerOptionsShowsHelpFlagUsage(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + var stderr bytes.Buffer + + _, exitCode := parseServerOptions([]string{"-help"}, "server", &stdout, &stderr) + if exitCode != 0 { + t.Fatalf("parseServerOptions() exitCode = %d, want 0", exitCode) + } + if stderr.Len() != 0 { + t.Fatalf("expected no stderr output, got %q", stderr.String()) + } + if got := stdout.String(); !strings.Contains(got, "Examples:") { + t.Fatalf("expected help examples in output, got %q", got) + } +} + +func TestParseServerOptionsRequiresConnect(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + var stderr bytes.Buffer + + _, exitCode := parseServerOptions([]string{"-listen", "0.0.0.0:56000"}, "server", &stdout, &stderr) + if exitCode != 2 { + t.Fatalf("parseServerOptions() exitCode = %d, want 2", exitCode) + } + if stdout.Len() != 0 { + t.Fatalf("expected no stdout output, got %q", stdout.String()) + } + if got := stderr.String(); !strings.Contains(got, "error: -connect is required") { + t.Fatalf("expected missing connect error, got %q", got) + } +} + +func TestParseServerOptionsParsesValidArgs(t *testing.T) { + t.Parallel() + + var stdout bytes.Buffer + var stderr bytes.Buffer + + opts, exitCode := parseServerOptions([]string{"-connect", "127.0.0.1:51820", "-listen", "0.0.0.0:56000", "-vless"}, "server", &stdout, &stderr) + if exitCode != cliutil.ContinueExecution { + t.Fatalf("parseServerOptions() exitCode = %d, want %d", exitCode, cliutil.ContinueExecution) + } + if stderr.Len() != 0 { + t.Fatalf("expected no stderr output, got %q", stderr.String()) + } + if opts.connect != "127.0.0.1:51820" { + t.Fatalf("connect = %q, want 127.0.0.1:51820", opts.connect) + } + if opts.listen != "0.0.0.0:56000" { + t.Fatalf("listen = %q, want 0.0.0.0:56000", opts.listen) + } + if !opts.vlessMode { + t.Fatal("vlessMode = false, want true") + } +} diff --git a/tcputil/tcputil.go b/tcputil/tcputil.go index 4e37ea33..40d0684d 100644 --- a/tcputil/tcputil.go +++ b/tcputil/tcputil.go @@ -1,83 +1,117 @@ package tcputil import ( + "errors" + "io" "net" + "sync" "time" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" ) -// DtlsPacketConn wraps a net.Conn (DTLS) as a net.PacketConn for KCP. -// Each DTLS Read/Write preserves message boundaries (datagram semantics). -type DtlsPacketConn struct { +// dtlsPacketConn adapts a DTLS net.Conn to net.PacketConn for KCP. +// It does not own the underlying transport; callers must close dtlsConn +// through the cleanup function returned by NewKCPOverDTLS. +type dtlsPacketConn struct { conn net.Conn } -func NewDtlsPacketConn(conn net.Conn) *DtlsPacketConn { - return &DtlsPacketConn{conn: conn} -} - -func (d *DtlsPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { +func (d *dtlsPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { n, err := d.conn.Read(b) return n, d.conn.RemoteAddr(), err } -func (d *DtlsPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) { +func (d *dtlsPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) { return d.conn.Write(b) } -func (d *DtlsPacketConn) Close() error { - return d.conn.Close() +func (d *dtlsPacketConn) Close() error { + return nil } -func (d *DtlsPacketConn) LocalAddr() net.Addr { +func (d *dtlsPacketConn) LocalAddr() net.Addr { return d.conn.LocalAddr() } -func (d *DtlsPacketConn) SetDeadline(t time.Time) error { +func (d *dtlsPacketConn) SetDeadline(t time.Time) error { return d.conn.SetDeadline(t) } -func (d *DtlsPacketConn) SetReadDeadline(t time.Time) error { +func (d *dtlsPacketConn) SetReadDeadline(t time.Time) error { return d.conn.SetReadDeadline(t) } -func (d *DtlsPacketConn) SetWriteDeadline(t time.Time) error { +func (d *dtlsPacketConn) SetWriteDeadline(t time.Time) error { return d.conn.SetWriteDeadline(t) } -// NewKCPOverDTLS creates a KCP session over a DTLS connection. +// NewKCPOverDTLS creates a KCP session over a DTLS connection and returns +// an idempotent cleanup function for the entire KCP-over-DTLS transport. +// After a successful call, the caller should use the returned cleanup instead +// of closing dtlsConn directly. +// // isServer: true for server-side (listener), false for client-side (dialer). -func NewKCPOverDTLS(dtlsConn net.Conn, isServer bool) (*kcp.UDPSession, error) { - pc := NewDtlsPacketConn(dtlsConn) +func NewKCPOverDTLS(dtlsConn net.Conn, isServer bool) (_ *kcp.UDPSession, cleanup func() error, err error) { + var ( + listener *kcp.Listener + sess *kcp.UDPSession + closeErr error + closeOnce sync.Once + ) + transportCleanup := func() error { + closeOnce.Do(func() { + var errs []error + if sess != nil { + if err := sess.Close(); err != nil && !errors.Is(err, io.ErrClosedPipe) { + errs = append(errs, err) + } + } + if listener != nil { + if err := listener.Close(); err != nil && !errors.Is(err, io.ErrClosedPipe) { + errs = append(errs, err) + } + } + if err := dtlsConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.ErrClosedPipe) { + errs = append(errs, err) + } + closeErr = errors.Join(errs...) + }) + return closeErr + } + defer func() { + if err == nil { + return + } + if cleanupErr := transportCleanup(); cleanupErr != nil { + err = errors.Join(err, cleanupErr) + } + }() block, err := kcp.NewNoneBlockCrypt(nil) // DTLS already encrypts if err != nil { - return nil, err + return nil, nil, err } - var sess *kcp.UDPSession - if isServer { // Server: listen on the PacketConn and accept one session - var listener *kcp.Listener - listener, err = kcp.ServeConn(block, 0, 0, pc) + listener, err = kcp.ServeConn(block, 0, 0, &dtlsPacketConn{conn: dtlsConn}) if err != nil { - return nil, err + return nil, nil, err } if err = listener.SetDeadline(time.Now().Add(30 * time.Second)); err != nil { - return nil, err + return nil, nil, err } sess, err = listener.AcceptKCP() if err != nil { - return nil, err + return nil, nil, err } } else { // Client: dial through the PacketConn - sess, err = kcp.NewConn2(dtlsConn.RemoteAddr(), block, 0, 0, pc) + sess, err = kcp.NewConn2(dtlsConn.RemoteAddr(), block, 0, 0, &dtlsPacketConn{conn: dtlsConn}) if err != nil { - return nil, err + return nil, nil, err } } @@ -89,7 +123,7 @@ func NewKCPOverDTLS(dtlsConn net.Conn, isServer bool) (*kcp.UDPSession, error) { sess.SetMtu(1200) // conservative MTU to fit inside DTLS+TURN sess.SetACKNoDelay(true) - return sess, nil + return sess, transportCleanup, nil } // DefaultSmuxConfig returns smux config tuned for TURN tunnel.