diff --git a/.github/workflows/issues.yml b/.github/workflows/issues.yml
index 2eee1c7e..913a7df8 100644
--- a/.github/workflows/issues.yml
+++ b/.github/workflows/issues.yml
@@ -1,7 +1,7 @@
name: "Close Stale Issues"
on:
schedule:
- - cron: "0 0 * * 3"
+ - cron: "0 0 * * 0"
workflow_dispatch:
jobs:
@@ -20,7 +20,7 @@ jobs:
Спасибо!
close-issue-message: "Issue был закрыт из-за отсутствия активности."
- days-before-stale: 120
+ days-before-stale: 30
days-before-close: 7
operations-per-run: 1000
ascending: true
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 3013fe95..3e1bc29f 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -78,7 +78,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v6
with:
- go-version: "1.25.5"
+ go-version: "1.26.2"
# Install Android NDK only for android rows
- name: Setup Android NDK
diff --git a/.github/workflows/.golangci.yml b/.golangci.yml
similarity index 88%
rename from .github/workflows/.golangci.yml
rename to .golangci.yml
index db48597c..59006559 100644
--- a/.github/workflows/.golangci.yml
+++ b/.golangci.yml
@@ -1,11 +1,18 @@
version: "2"
run:
+ modules-download-mode: readonly
+ relative-path-mode: gomod
tests: true
linters:
enable:
- bodyclose
+ - errcheck
+ - govet
+ - ineffassign
- misspell
- revive
+ - staticcheck
+ - unused
settings:
errcheck:
check-type-assertions: true
@@ -13,12 +20,12 @@ linters:
exclude-functions:
- (net.PacketConn).WriteTo
- (net.Conn).Write
+ - (net.Conn).SetDeadline
- encoding/json.MarshalIndent
- (*github.com/pion/dtls/v3.Conn).SetDeadline
govet:
disable:
- fieldalignment
- enable-all: true
revive:
rules:
- name: blank-imports
@@ -52,10 +59,7 @@ linters:
issues:
max-issues-per-linter: 0
max-same-issues: 0
- exclude-rules:
- - linters:
- - errcheck
- source: "doRequest|packetPool\\.Get"
+
formatters:
exclusions:
generated: lax
diff --git a/client/main.go b/client/main.go
index 71a86251..b7ec2a95 100644
--- a/client/main.go
+++ b/client/main.go
@@ -11,6 +11,7 @@ import (
"encoding/base64"
"encoding/hex"
"encoding/json"
+ "errors"
"flag"
"fmt"
"io"
@@ -21,18 +22,21 @@ import (
neturl "net/url"
"os"
"os/signal"
+ "path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
+ "unicode"
fhttp "github.com/bogdanfinn/fhttp"
tlsclient "github.com/bogdanfinn/tls-client"
"github.com/bogdanfinn/tls-client/profiles"
"github.com/bschaatsbergen/dnsdialer"
+ "github.com/cacggghp/vk-turn-proxy/internal/cliutil"
"github.com/cacggghp/vk-turn-proxy/tcputil"
"github.com/cbeuw/connutil"
"github.com/google/uuid"
@@ -62,6 +66,7 @@ var (
activeLocalPeer atomic.Value
globalCaptchaLockout atomic.Int64
connectedStreams atomic.Int32
+ configuredStreams atomic.Int32
globalAppCancel context.CancelFunc
handshakeSem = make(chan struct{}, 3)
isDebug bool
@@ -70,6 +75,7 @@ var (
)
type captchaSolveMode int
+type captchaFailureCountContextKey struct{}
const (
captchaSolveModeAuto captchaSolveMode = iota
@@ -77,6 +83,62 @@ const (
captchaSolveModeManual
)
+type clientOptions struct {
+ host string
+ port string
+ listen string
+ vklink string
+ yalink string
+ peerAddr string
+ n int
+ udp bool
+ direct bool
+ vlessMode bool
+ debug bool
+ manualCaptcha bool
+}
+
+func newClientFlagSet(program string, output io.Writer) (*flag.FlagSet, *clientOptions) {
+ fs := flag.NewFlagSet(program, flag.ContinueOnError)
+ fs.SetOutput(output)
+
+ opts := &clientOptions{}
+ fs.StringVar(&opts.host, "turn", "", "override TURN server ip")
+ fs.StringVar(&opts.port, "port", "", "override TURN port")
+ fs.StringVar(&opts.listen, "listen", "127.0.0.1:9000", "listen on ip:port")
+ fs.StringVar(&opts.vklink, "vk-link", "", "VK calls invite link \"https://vk.com/call/join/...\"")
+ fs.StringVar(&opts.yalink, "yandex-link", "", "Yandex Telemost invite link \"https://telemost.yandex.ru/j/...\"")
+ fs.StringVar(&opts.peerAddr, "peer", "", "peer server address (host:port)")
+ fs.IntVar(&opts.n, "n", 0, "connections to TURN (default 10 for VK, 1 for Yandex)")
+ fs.BoolVar(&opts.udp, "udp", false, "connect to TURN with UDP")
+ fs.BoolVar(&opts.direct, "no-dtls", false, "connect without obfuscation. DO NOT USE")
+ fs.BoolVar(&opts.vlessMode, "vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets")
+ fs.BoolVar(&opts.debug, "debug", false, "enable debug logging")
+ fs.BoolVar(&opts.manualCaptcha, "manual-captcha", false, "skip auto captcha solving, use manual mode immediately")
+ fs.Usage = func() {
+ cliutil.Fprintf(fs.Output(), "Usage:\n %s -peer -vk-link [flags]\n %s -peer -yandex-link [flags]\n\n", program, program)
+ cliutil.Fprintln(fs.Output(), "Examples:")
+ cliutil.Fprintf(fs.Output(), " %s -listen 127.0.0.1:9000 -peer 203.0.113.10:56000 -vk-link https://vk.com/call/join/...\n", program)
+ cliutil.Fprintf(fs.Output(), " %s -udp -turn 5.255.211.241 -peer 203.0.113.10:56000 -yandex-link https://telemost.yandex.ru/j/... -listen 127.0.0.1:9000\n\n", program)
+ cliutil.Fprintln(fs.Output(), "Flags:")
+ fs.PrintDefaults()
+ }
+
+ return fs, opts
+}
+
+func parseClientOptions(args []string, program string, stdout, stderr io.Writer) (clientOptions, int) {
+ return cliutil.Parse(args, program, stdout, stderr, newClientFlagSet, func(opts *clientOptions) error {
+ if opts.peerAddr == "" {
+ return fmt.Errorf("-peer is required")
+ }
+ if (opts.vklink == "") == (opts.yalink == "") {
+ return fmt.Errorf("exactly one of -vk-link or -yandex-link is required")
+ }
+ return nil
+ })
+}
+
func captchaSolveModeForAttempt(attempt int, manualOnly bool, enableSliderPOC bool) (captchaSolveMode, bool) {
if manualOnly {
return captchaSolveModeManual, attempt == 0
@@ -112,6 +174,35 @@ func captchaSolveModeLabel(mode captchaSolveMode) string {
}
}
+func withCaptchaFailureCount(ctx context.Context, count int) context.Context {
+ if count <= 0 {
+ return ctx
+ }
+ return context.WithValue(ctx, captchaFailureCountContextKey{}, count)
+}
+
+func captchaFailureCountFromContext(ctx context.Context) int {
+ if ctx == nil {
+ return 0
+ }
+ count, ok := ctx.Value(captchaFailureCountContextKey{}).(int)
+ if !ok {
+ return 0
+ }
+ return count
+}
+
+func captchaLogAttempt(ctx context.Context, solveMode captchaSolveMode, attempt int) int {
+ displayAttempt := attempt + 1
+ if solveMode == captchaSolveModeManual {
+ bucketAttempt := captchaFailureCountFromContext(ctx) + 1
+ if bucketAttempt > displayAttempt {
+ displayAttempt = bucketAttempt
+ }
+ }
+ return displayAttempt
+}
+
type UDPPacket struct {
Data []byte
N int
@@ -611,24 +702,71 @@ type TurnCredentials struct {
}
type StreamCredentialsCache struct {
- creds TurnCredentials
- mutex sync.RWMutex
- errorCount atomic.Int32
- lastErrorTime atomic.Int64
+ creds TurnCredentials
+ mutex sync.RWMutex
+ errorCount atomic.Int32
+ lastErrorTime atomic.Int64
+ captchaFailures int
+ disabled bool
+ disableAnnounced bool
+ retryAfter time.Time
+ retryErr error
+ retryLink string
}
const (
credentialLifetime = 10 * time.Minute
cacheSafetyMargin = 60 * time.Second
maxCacheErrors = 3
+ maxCaptchaFailures = 2
errorWindow = 10 * time.Second
streamsPerCache = 10
+ cacheRetryDelay = 60 * time.Second
)
+var errCaptchaBucketDisabled = errors.New("CAPTCHA_BUCKET_DISABLED")
+
func getCacheID(streamID int) int {
return streamID / streamsPerCache
}
+func bucketStreamCount(cacheID int, numStreams int) int {
+ start := cacheID * streamsPerCache
+ if start >= numStreams {
+ return 0
+ }
+ end := start + streamsPerCache
+ if end > numStreams {
+ end = numStreams
+ }
+ return end - start
+}
+
+func activeConfiguredStreamCount() int {
+ numStreams := int(configuredStreams.Load())
+ if numStreams <= 0 {
+ return 0
+ }
+
+ disabledStreams := 0
+ credentialsStore.mu.RLock()
+ defer credentialsStore.mu.RUnlock()
+ for cacheID, cache := range credentialsStore.caches {
+ cache.mutex.RLock()
+ disabled := cache.disabled
+ cache.mutex.RUnlock()
+ if disabled {
+ disabledStreams += bucketStreamCount(cacheID, numStreams)
+ }
+ }
+
+ activeStreams := numStreams - disabledStreams
+ if activeStreams < 0 {
+ return 0
+ }
+ return activeStreams
+}
+
func vkDelayRandom(minMs, maxMs int) {
ms := minMs + rand.Intn(maxMs-minMs+1)
time.Sleep(time.Duration(ms) * time.Millisecond)
@@ -702,6 +840,12 @@ func handleAuthError(streamID int) bool {
func (c *StreamCredentialsCache) invalidate(streamID int) {
c.mutex.Lock()
c.creds = TurnCredentials{}
+ c.captchaFailures = 0
+ c.disabled = false
+ c.disableAnnounced = false
+ c.retryAfter = time.Time{}
+ c.retryErr = nil
+ c.retryLink = ""
c.mutex.Unlock()
c.errorCount.Store(0)
@@ -710,6 +854,33 @@ func (c *StreamCredentialsCache) invalidate(streamID int) {
log.Printf("[STREAM %d] [VK Auth] Credentials cache invalidated", streamID)
}
+func isCacheBucketDisabled(streamID int) bool {
+ cache := getStreamCache(streamID)
+ cache.mutex.RLock()
+ defer cache.mutex.RUnlock()
+ return cache.disabled
+}
+
+func announceDisabledBucket(streamID int) {
+ cache := getStreamCache(streamID)
+ cacheID := getCacheID(streamID)
+
+ cache.mutex.Lock()
+ if cache.disableAnnounced {
+ cache.mutex.Unlock()
+ return
+ }
+ cache.disableAnnounced = true
+ cache.mutex.Unlock()
+
+ log.Printf(
+ "[VK Auth] Cache bucket %d disabled after %d captcha failures. Continuing with %d streams.",
+ cacheID,
+ maxCaptchaFailures,
+ activeConfiguredStreamCount(),
+ )
+}
+
func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dnsdialer.Dialer) (string, string, string, error) {
cache := getStreamCache(streamID)
cacheID := getCacheID(streamID)
@@ -733,19 +904,48 @@ func getVkCredsCached(ctx context.Context, link string, streamID int, dialer *dn
if cache.creds.Link == link && time.Now().Before(cache.creds.ExpiresAt) {
return cache.creds.Username, cache.creds.Password, cache.creds.ServerAddr, nil
}
+ if cache.disabled {
+ return "", "", "", fmt.Errorf("%w: cache=%d", errCaptchaBucketDisabled, cacheID)
+ }
+ if cache.retryErr != nil && cache.retryLink == link && time.Now().Before(cache.retryAfter) {
+ if isDebug {
+ log.Printf("[STREAM %d] [VK Auth] Reusing cache=%d captcha cooldown (%v remaining)", streamID, cacheID, time.Until(cache.retryAfter).Truncate(time.Millisecond))
+ }
+ return "", "", "", cache.retryErr
+ }
- user, pass, addr, err := fetchVkCredsSerialized(ctx, link, streamID, dialer)
+ user, pass, addr, err := fetchVkCredsSerializedFunc(withCaptchaFailureCount(ctx, cache.captchaFailures), link, streamID, dialer)
if err != nil {
+ if strings.Contains(err.Error(), "CAPTCHA_WAIT_REQUIRED") {
+ cache.captchaFailures++
+ cache.retryAfter = time.Now().Add(cacheRetryDelay)
+ cache.retryErr = err
+ cache.retryLink = link
+ if cache.captchaFailures >= maxCaptchaFailures {
+ cache.disabled = true
+ cache.retryAfter = time.Time{}
+ cache.retryErr = nil
+ cache.retryLink = ""
+ return "", "", "", fmt.Errorf("%w: cache=%d", errCaptchaBucketDisabled, cacheID)
+ }
+ }
return "", "", "", err
}
+ cache.captchaFailures = 0
+ cache.disabled = false
+ cache.disableAnnounced = false
+ cache.retryAfter = time.Time{}
+ cache.retryErr = nil
+ cache.retryLink = ""
cache.creds = TurnCredentials{Username: user, Password: pass, ServerAddr: addr, ExpiresAt: time.Now().Add(credentialLifetime - cacheSafetyMargin), Link: link}
return user, pass, addr, nil
}
var (
- vkRequestMu sync.Mutex
- globalLastVkFetchTime time.Time
+ fetchVkCredsSerializedFunc = fetchVkCredsSerialized
+ vkRequestMu sync.Mutex
+ globalLastVkFetchTime time.Time
)
func fetchVkCredsSerialized(ctx context.Context, link string, streamID int, dialer *dnsdialer.Dialer) (string, string, string, error) {
@@ -958,40 +1158,23 @@ func getTokenChain(ctx context.Context, link string, streamID int, creds VKCrede
log.Printf("[STREAM %d] [Captcha] Triggering manual captcha fallback...", streamID)
manualCtx, manualCancel := context.WithTimeout(ctx, 60*time.Second)
- type manualRes struct {
- token string
- key string
- err error
+ if captchaErr.RedirectURI != "" {
+ successToken, solveErr = solveCaptchaViaProxy(manualCtx, captchaErr.RedirectURI, dialer)
+ } else if captchaErr.CaptchaImg != "" {
+ captchaKey, solveErr = solveCaptchaViaHTTP(manualCtx, captchaErr.CaptchaImg)
+ } else {
+ solveErr = fmt.Errorf("no redirect_uri or captcha_img")
}
- resCh := make(chan manualRes, 1)
-
- go func() {
- var t, k string
- var e error
- if captchaErr.RedirectURI != "" {
- t, e = solveCaptchaViaProxy(captchaErr.RedirectURI, dialer)
- } else if captchaErr.CaptchaImg != "" {
- k, e = solveCaptchaViaHTTP(captchaErr.CaptchaImg)
- } else {
- e = fmt.Errorf("no redirect_uri or captcha_img")
- }
- resCh <- manualRes{t, k, e}
- }()
-
- select {
- case res := <-resCh:
- successToken = res.token
- captchaKey = res.key
- solveErr = res.err
- case <-manualCtx.Done():
+ deadlineExceeded := errors.Is(manualCtx.Err(), context.DeadlineExceeded)
+ manualCancel()
+ if solveErr != nil && deadlineExceeded {
solveErr = fmt.Errorf("manual captcha timed out after 60s")
}
- manualCancel()
}
// If solving failed (auto or manual) or timed out
if solveErr != nil {
- log.Printf("[STREAM %d] [Captcha] %s failed (attempt %d): %v", streamID, captchaSolveModeLabel(solveMode), attempt+1, solveErr)
+ log.Printf("[STREAM %d] [Captcha] %s failed (attempt %d): %v", streamID, capitalizeFirstLetter(captchaSolveModeLabel(solveMode)), captchaLogAttempt(ctx, solveMode, attempt), solveErr)
nextSolveMode, hasNextSolveMode := captchaSolveModeForAttempt(attempt+1, manualCaptcha, autoCaptchaSliderPOC)
if hasNextSolveMode {
@@ -1424,6 +1607,7 @@ func dtlsFunc(ctx context.Context, conn net.PacketConn, peer *net.UDPAddr) (net.
}
if err := dtlsConn.HandshakeContext(ctx1); err != nil {
+ _ = dtlsConn.Close()
return nil, err
}
return dtlsConn, nil
@@ -1436,6 +1620,15 @@ func oneDtlsConnection(ctx context.Context, peer *net.UDPAddr, listenConn net.Pa
defer dtlscancel()
conn1, conn2 := connutil.AsyncPacketPipe()
+ pipeConn := conn1
+ defer func() {
+ if pipeConn == nil {
+ return
+ }
+ if closeErr := pipeConn.Close(); closeErr != nil {
+ log.Printf("[STREAM %d] Failed to close DTLS pipe: %s", streamID, closeErr)
+ }
+ }()
go func() {
for {
select {
@@ -1449,6 +1642,7 @@ func oneDtlsConnection(ctx context.Context, peer *net.UDPAddr, listenConn net.Pa
if err1 != nil {
return fmt.Errorf("failed to connect DTLS: %s", err1)
}
+ pipeConn = nil
defer func() {
if closeErr := dtlsConn.Close(); closeErr != nil {
log.Printf("[STREAM %d] failed to close DTLS connection: %s", streamID, closeErr)
@@ -1713,6 +1907,10 @@ func oneTurnConnection(ctx context.Context, turnParams *turnParams, peer *net.UD
func oneDtlsConnectionLoop(ctx context.Context, peer *net.UDPAddr, listenConn net.PacketConn, inboundChan <-chan *UDPPacket, connchan chan<- net.PacketConn, okchan chan<- struct{}, streamID int) {
for {
+ if isCacheBucketDisabled(streamID) {
+ announceDisabledBucket(streamID)
+ return
+ }
select {
case <-ctx.Done():
return
@@ -1747,6 +1945,10 @@ func oneTurnConnectionLoop(ctx context.Context, turnParams *turnParams, peer *ne
go oneTurnConnection(ctx, turnParams, peer, conn2, streamID, c)
if err := <-c; err != nil {
+ if strings.Contains(err.Error(), errCaptchaBucketDisabled.Error()) {
+ announceDisabledBucket(streamID)
+ return
+ }
if strings.Contains(err.Error(), "FATAL_CAPTCHA") {
log.Printf("[STREAM %d] Fatal manual captcha error. Shutting down application.", streamID)
if globalAppCancel != nil {
@@ -1784,6 +1986,11 @@ func oneTurnConnectionLoop(ctx context.Context, turnParams *turnParams, peer *ne
}
func main() {
+ opts, exitCode := parseClientOptions(os.Args[1:], filepath.Base(os.Args[0]), os.Stdout, os.Stderr)
+ if exitCode != cliutil.ContinueExecution {
+ os.Exit(exitCode)
+ }
+
ctx, cancel := context.WithCancel(context.Background())
globalAppCancel = cancel
defer cancel()
@@ -1800,38 +2007,19 @@ func main() {
log.Fatalf("Exit...\n")
}()
- host := flag.String("turn", "", "override TURN server ip")
- port := flag.String("port", "", "override TURN port")
- listen := flag.String("listen", "127.0.0.1:9000", "listen on ip:port")
- vklink := flag.String("vk-link", "", "VK calls invite link \"https://vk.com/call/join/...\"")
- yalink := flag.String("yandex-link", "", "Yandex telemost invite link \"https://telemost.yandex.ru/j/...\"")
- peerAddr := flag.String("peer", "", "peer server address (host:port)")
- n := flag.Int("n", 0, "connections to TURN (default 10 for VK, 1 for Yandex)")
- udp := flag.Bool("udp", false, "connect to TURN with UDP")
- direct := flag.Bool("no-dtls", false, "connect without obfuscation. DO NOT USE")
- vlessMode := flag.Bool("vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets")
- debugFlag := flag.Bool("debug", false, "enable debug logging")
- manualCaptchaFlag := flag.Bool("manual-captcha", false, "skip auto captcha solving, use manual mode immediately")
- flag.Parse()
- if *peerAddr == "" {
- log.Panicf("Need peer address!")
- }
- peer, err := net.ResolveUDPAddr("udp", *peerAddr)
+ peer, err := net.ResolveUDPAddr("udp", opts.peerAddr)
if err != nil {
panic(err)
}
- if (*vklink == "") == (*yalink == "") {
- log.Panicf("Need either vk-link or yandex-link!")
- }
- isDebug = *debugFlag
- manualCaptcha = *manualCaptchaFlag
+ isDebug = opts.debug
+ manualCaptcha = opts.manualCaptcha
autoCaptchaSliderPOC = !manualCaptcha
var link string
var getCreds getCredsFunc
- if *vklink != "" {
- parts := strings.Split(*vklink, "join/")
+ if opts.vklink != "" {
+ parts := strings.Split(opts.vklink, "join/")
link = parts[len(parts)-1]
dialer := dnsdialer.New(
@@ -1843,17 +2031,17 @@ func main() {
getCreds = func(ctx context.Context, s string, streamID int) (string, string, string, error) {
return getVkCredsCached(ctx, s, streamID, dialer)
}
- if *n <= 0 {
- *n = 10
+ if opts.n <= 0 {
+ opts.n = 10
}
} else {
- parts := strings.Split(*yalink, "j/")
+ parts := strings.Split(opts.yalink, "j/")
link = parts[len(parts)-1]
getCreds = func(ctx context.Context, s string, streamID int) (string, string, string, error) {
return getYandexCreds(s)
}
- if *n <= 0 {
- *n = 1
+ if opts.n <= 0 {
+ opts.n = 1
}
}
if idx := strings.IndexAny(link, "/?#"); idx != -1 {
@@ -1861,19 +2049,19 @@ func main() {
}
params := &turnParams{
- host: *host,
- port: *port,
+ host: opts.host,
+ port: opts.port,
link: link,
- udp: *udp,
+ udp: opts.udp,
getCreds: getCreds,
}
- if *vlessMode {
- runVLESSMode(ctx, params, peer, *listen, *n)
+ if opts.vlessMode {
+ runVLESSMode(ctx, params, peer, opts.listen, opts.n)
return
}
- listenConn, err := net.ListenPacket("udp", *listen)
+ listenConn, err := net.ListenPacket("udp", opts.listen)
if err != nil {
log.Panicf("Failed to listen: %s", err)
}
@@ -1883,10 +2071,11 @@ func main() {
}
})
- numStreams := *n
+ numStreams := opts.n
if numStreams <= 0 {
numStreams = 1
}
+ configuredStreams.Store(int32(numStreams))
// Shared Worker Pool Queue for Aggregation
inboundChan := make(chan *UDPPacket, 2000)
@@ -1930,21 +2119,24 @@ func main() {
wg1 := sync.WaitGroup{}
t := time.Tick(200 * time.Millisecond)
- if *direct {
+ if opts.direct {
log.Panicf("Direct mode not supported with dispatcher")
}
okchan := make(chan struct{})
connchan := make(chan net.PacketConn)
+ streamCtx, streamCancel := context.WithCancel(ctx)
wg1.Add(1)
go func() {
defer wg1.Done()
- oneDtlsConnectionLoop(ctx, peer, listenConn, inboundChan, connchan, okchan, 1)
+ defer streamCancel()
+ oneDtlsConnectionLoop(streamCtx, peer, listenConn, inboundChan, connchan, okchan, 1)
}()
wg1.Add(1)
go func() {
defer wg1.Done()
- oneTurnConnectionLoop(ctx, params, peer, connchan, t, 1)
+ defer streamCancel()
+ oneTurnConnectionLoop(streamCtx, params, peer, connchan, t, 1)
}()
select {
@@ -1954,15 +2146,18 @@ func main() {
for i := 1; i < numStreams; i++ {
cchan := make(chan net.PacketConn)
+ streamCtx, streamCancel := context.WithCancel(ctx)
wg1.Add(1)
go func(streamID int) {
defer wg1.Done()
- oneDtlsConnectionLoop(ctx, peer, listenConn, inboundChan, cchan, nil, streamID)
+ defer streamCancel()
+ oneDtlsConnectionLoop(streamCtx, peer, listenConn, inboundChan, cchan, nil, streamID)
}(i)
wg1.Add(1)
go func(streamID int) {
defer wg1.Done()
- oneTurnConnectionLoop(ctx, params, peer, cchan, t, streamID)
+ defer streamCancel()
+ oneTurnConnectionLoop(streamCtx, params, peer, cchan, t, streamID)
}(i)
}
@@ -2246,16 +2441,19 @@ func createSmuxSession(ctx context.Context, tp *turnParams, peer *net.UDPAddr, i
cleanup()
return nil, nil, fmt.Errorf("DTLS handshake: %w", err)
}
- cleanupFns = append(cleanupFns, func() { _ = dtlsConn.Close() })
log.Printf("DTLS connection established")
// 5. Create KCP session over DTLS
- kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, false)
+ kcpSess, cleanupKCP, err := tcputil.NewKCPOverDTLS(dtlsConn, false)
if err != nil {
cleanup()
return nil, nil, fmt.Errorf("KCP session: %w", err)
}
- cleanupFns = append(cleanupFns, func() { _ = kcpSess.Close() })
+ cleanupFns = append(cleanupFns, func() {
+ if err := cleanupKCP(); err != nil {
+ log.Printf("KCP cleanup error: %v", err)
+ }
+ })
log.Printf("KCP session established")
// 6. Create smux client session over KCP
@@ -2326,3 +2524,10 @@ func pipe(ctx context.Context, c1, c2 net.Conn) {
log.Printf("pipe: failed to reset deadline c2: %v", err)
}
}
+
+func capitalizeFirstLetter(s string) string {
+ for i, v := range s {
+ return string(unicode.ToUpper(v)) + s[i+len(string(v)):]
+ }
+ return ""
+}
diff --git a/client/main_test.go b/client/main_test.go
index e0f2d779..9eb72cd4 100644
--- a/client/main_test.go
+++ b/client/main_test.go
@@ -1,6 +1,95 @@
package main
-import "testing"
+import (
+ "bytes"
+ "context"
+ "fmt"
+ "strings"
+ "sync"
+ "testing"
+ "time"
+
+ "github.com/bschaatsbergen/dnsdialer"
+ "github.com/cacggghp/vk-turn-proxy/internal/cliutil"
+)
+
+func TestParseClientOptionsShowsUsageWithoutArgs(t *testing.T) {
+ t.Parallel()
+
+ var stdout bytes.Buffer
+ var stderr bytes.Buffer
+
+ _, exitCode := parseClientOptions(nil, "client", &stdout, &stderr)
+ if exitCode != 0 {
+ t.Fatalf("parseClientOptions() exitCode = %d, want 0", exitCode)
+ }
+ if stderr.Len() != 0 {
+ t.Fatalf("expected no stderr output, got %q", stderr.String())
+ }
+ if got := stdout.String(); !strings.Contains(got, "Usage:\n client -peer -vk-link [flags]") {
+ t.Fatalf("usage output missing client help text: %q", got)
+ }
+}
+
+func TestParseClientOptionsShowsHelpFlagUsage(t *testing.T) {
+ t.Parallel()
+
+ var stdout bytes.Buffer
+ var stderr bytes.Buffer
+
+ _, exitCode := parseClientOptions([]string{"-help"}, "client", &stdout, &stderr)
+ if exitCode != 0 {
+ t.Fatalf("parseClientOptions() exitCode = %d, want 0", exitCode)
+ }
+ if stderr.Len() != 0 {
+ t.Fatalf("expected no stderr output, got %q", stderr.String())
+ }
+ if got := stdout.String(); !strings.Contains(got, "Examples:") {
+ t.Fatalf("expected help examples in output, got %q", got)
+ }
+}
+
+func TestParseClientOptionsRequiresPeer(t *testing.T) {
+ t.Parallel()
+
+ var stdout bytes.Buffer
+ var stderr bytes.Buffer
+
+ _, exitCode := parseClientOptions([]string{"-vk-link", "https://vk.com/call/join/test"}, "client", &stdout, &stderr)
+ if exitCode != 2 {
+ t.Fatalf("parseClientOptions() exitCode = %d, want 2", exitCode)
+ }
+ if stdout.Len() != 0 {
+ t.Fatalf("expected no stdout output, got %q", stdout.String())
+ }
+ if got := stderr.String(); !strings.Contains(got, "error: -peer is required") {
+ t.Fatalf("expected missing peer error, got %q", got)
+ }
+}
+
+func TestParseClientOptionsParsesValidVKArgs(t *testing.T) {
+ t.Parallel()
+
+ var stdout bytes.Buffer
+ var stderr bytes.Buffer
+
+ opts, exitCode := parseClientOptions([]string{"-peer", "127.0.0.1:56000", "-vk-link", "https://vk.com/call/join/test", "-listen", "127.0.0.1:9001"}, "client", &stdout, &stderr)
+ if exitCode != cliutil.ContinueExecution {
+ t.Fatalf("parseClientOptions() exitCode = %d, want %d", exitCode, cliutil.ContinueExecution)
+ }
+ if stderr.Len() != 0 {
+ t.Fatalf("expected no stderr output, got %q", stderr.String())
+ }
+ if opts.peerAddr != "127.0.0.1:56000" {
+ t.Fatalf("peerAddr = %q, want 127.0.0.1:56000", opts.peerAddr)
+ }
+ if opts.vklink != "https://vk.com/call/join/test" {
+ t.Fatalf("vklink = %q, want VK link", opts.vklink)
+ }
+ if opts.listen != "127.0.0.1:9001" {
+ t.Fatalf("listen = %q, want 127.0.0.1:9001", opts.listen)
+ }
+}
func TestCaptchaSolveModeForAttempt(t *testing.T) {
t.Parallel()
@@ -59,3 +148,105 @@ func TestCaptchaSolveModeForAttempt(t *testing.T) {
}
})
}
+
+func TestCaptchaLogAttempt(t *testing.T) {
+ t.Parallel()
+
+ if got := captchaLogAttempt(context.Background(), captchaSolveModeManual, 0); got != 1 {
+ t.Fatalf("captchaLogAttempt() = %d, want 1 for first manual attempt", got)
+ }
+
+ ctx := withCaptchaFailureCount(context.Background(), 1)
+ if got := captchaLogAttempt(ctx, captchaSolveModeManual, 0); got != 2 {
+ t.Fatalf("captchaLogAttempt() = %d, want 2 for second bucket manual attempt", got)
+ }
+
+ if got := captchaLogAttempt(ctx, captchaSolveModeManual, 2); got != 3 {
+ t.Fatalf("captchaLogAttempt() = %d, want 3 when in-request attempt count is already higher", got)
+ }
+
+ if got := captchaLogAttempt(ctx, captchaSolveModeAuto, 0); got != 1 {
+ t.Fatalf("captchaLogAttempt() = %d, want 1 for auto attempt count", got)
+ }
+}
+
+func TestGetVkCredsCachedDisablesBucketAfterTwoCaptchaFailures(t *testing.T) {
+ credentialsStore.mu.Lock()
+ previousCaches := credentialsStore.caches
+ credentialsStore.caches = make(map[int]*StreamCredentialsCache)
+ credentialsStore.mu.Unlock()
+ defer func() {
+ credentialsStore.mu.Lock()
+ credentialsStore.caches = previousCaches
+ credentialsStore.mu.Unlock()
+ }()
+
+ previousFetch := fetchVkCredsSerializedFunc
+ defer func() {
+ fetchVkCredsSerializedFunc = previousFetch
+ }()
+ previousConfiguredStreams := configuredStreams.Load()
+ configuredStreams.Store(20)
+ defer configuredStreams.Store(previousConfiguredStreams)
+
+ var (
+ mu sync.Mutex
+ callCount int
+ )
+ fetchVkCredsSerializedFunc = func(ctx context.Context, link string, streamID int, dialer *dnsdialer.Dialer) (string, string, string, error) {
+ mu.Lock()
+ defer mu.Unlock()
+ callCount++
+ switch callCount {
+ case 1:
+ return "", "", "", fmt.Errorf("CAPTCHA_WAIT_REQUIRED")
+ case 2:
+ return "", "", "", fmt.Errorf("CAPTCHA_WAIT_REQUIRED")
+ default:
+ return "", "", "", fmt.Errorf("unexpected extra fetch for stream %d", streamID)
+ }
+ }
+
+ _, _, _, err := getVkCredsCached(context.Background(), "link", 10, nil)
+ if err == nil || err.Error() != "CAPTCHA_WAIT_REQUIRED" {
+ t.Fatalf("first getVkCredsCached() error = %v, want CAPTCHA_WAIT_REQUIRED", err)
+ }
+
+ _, _, _, err = getVkCredsCached(context.Background(), "link", 17, nil)
+ if err == nil || err.Error() != "CAPTCHA_WAIT_REQUIRED" {
+ t.Fatalf("second getVkCredsCached() error = %v, want shared CAPTCHA_WAIT_REQUIRED", err)
+ }
+
+ mu.Lock()
+ if callCount != 1 {
+ mu.Unlock()
+ t.Fatalf("expected one fetch attempt during shared captcha cooldown, got %d", callCount)
+ }
+ mu.Unlock()
+
+ cache := getStreamCache(10)
+ cache.mutex.Lock()
+ cache.retryAfter = time.Now().Add(-time.Second)
+ cache.mutex.Unlock()
+
+ _, _, _, err = getVkCredsCached(context.Background(), "link", 19, nil)
+ if err == nil || !strings.Contains(err.Error(), errCaptchaBucketDisabled.Error()) {
+ t.Fatalf("third getVkCredsCached() error = %v, want bucket disabled", err)
+ }
+
+ _, _, _, err = getVkCredsCached(context.Background(), "link", 11, nil)
+ if err == nil || !strings.Contains(err.Error(), errCaptchaBucketDisabled.Error()) {
+ t.Fatalf("fourth getVkCredsCached() error = %v, want bucket disabled without new fetch", err)
+ }
+
+ mu.Lock()
+ if callCount != 2 {
+ mu.Unlock()
+ t.Fatalf("expected exactly two fetch attempts before bucket disable, got %d", callCount)
+ }
+ mu.Unlock()
+
+ if got := activeConfiguredStreamCount(); got != 10 {
+ t.Fatalf("activeConfiguredStreamCount() = %d, want 10 after disabling second bucket", got)
+ }
+}
diff --git a/client/manual_captcha.go b/client/manual_captcha.go
index 27f958d3..5fa5ae50 100644
--- a/client/manual_captcha.go
+++ b/client/manual_captcha.go
@@ -23,6 +23,8 @@ import (
const captchaListenPort = "8765"
+var openBrowserFunc = openBrowser
+
type browserCommand struct {
name string
args []string
@@ -374,13 +376,20 @@ func startCaptchaServer(srv *http.Server, logPrefix string) error {
return fmt.Errorf("captcha listeners failed: %s", strings.Join(listenErrs, "; "))
}
-// runCaptchaServerAndWait triggers the browser, and waiting gracefully for the solution token.
-func runCaptchaServerAndWait(handler http.Handler, captchaURL string, keyCh <-chan string, logPrefix string) (string, error) {
+// runCaptchaServerAndWait triggers the browser and waits for either a solution token or cancellation.
+func runCaptchaServerAndWait(ctx context.Context, handler http.Handler, captchaURL string, keyCh <-chan string, logPrefix string) (string, error) {
srv := &http.Server{Handler: handler}
if err := startCaptchaServer(srv, logPrefix); err != nil {
return "", err
}
+ defer func() {
+ shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+ defer cancel()
+ if err := srv.Shutdown(shutdownCtx); err != nil && !errors.Is(err, http.ErrServerClosed) {
+ log.Printf("%s shutdown error: %v", logPrefix, err)
+ }
+ }()
fmt.Println("\n==============================================")
fmt.Println("ACTION REQUIRED: MANUAL CAPTCHA SOLVING NEEDED")
@@ -388,17 +397,14 @@ func runCaptchaServerAndWait(handler http.Handler, captchaURL string, keyCh <-ch
fmt.Println("==============================================")
fmt.Println()
- openBrowser(captchaURL)
-
- key := <-keyCh
+ openBrowserFunc(captchaURL)
- ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
- defer cancel()
- if err := srv.Shutdown(ctx); err != nil {
- return "", err
+ select {
+ case key := <-keyCh:
+ return key, nil
+ case <-ctx.Done():
+ return "", ctx.Err()
}
-
- return key, nil
}
// notifyKey pushes the key string to the given channel without blocking
@@ -411,7 +417,7 @@ func notifyKey(keyCh chan<- string, key string) {
}
}
-func solveCaptchaViaHTTP(captchaImg string) (string, error) {
+func solveCaptchaViaHTTP(ctx context.Context, captchaImg string) (string, error) {
keyCh := make(chan string, 1)
mux := http.NewServeMux()
@@ -439,10 +445,10 @@ button{font-size:24px;padding:12px 32px;margin-top:12px;cursor:pointer}
_, _ = fmt.Fprint(w, `Done!
`)
})
- return runCaptchaServerAndWait(mux, localCaptchaOrigin(), keyCh, "captcha HTTP server error")
+ return runCaptchaServerAndWait(ctx, mux, localCaptchaOrigin(), keyCh, "captcha HTTP server error")
}
-func solveCaptchaViaProxy(redirectURI string, dialer *dnsdialer.Dialer) (string, error) {
+func solveCaptchaViaProxy(ctx context.Context, redirectURI string, dialer *dnsdialer.Dialer) (string, error) {
keyCh := make(chan string, 1)
targetURL, err := neturl.Parse(redirectURI)
@@ -566,7 +572,7 @@ func solveCaptchaViaProxy(redirectURI string, dialer *dnsdialer.Dialer) (string,
proxy.ServeHTTP(w, r)
})
- return runCaptchaServerAndWait(mux, localCaptchaURLForTarget(targetURL), keyCh, "proxy HTTP server error")
+ return runCaptchaServerAndWait(ctx, mux, localCaptchaURLForTarget(targetURL), keyCh, "proxy HTTP server error")
}
func openBrowser(url string) {
diff --git a/client/manual_captcha_test.go b/client/manual_captcha_test.go
index 8afbafdb..9294d731 100644
--- a/client/manual_captcha_test.go
+++ b/client/manual_captcha_test.go
@@ -1,8 +1,13 @@
package main
import (
+ "context"
+ "errors"
+ "net"
+ "net/http"
"net/url"
"testing"
+ "time"
)
func TestRewriteProxyRedirectLocation(t *testing.T) {
@@ -63,3 +68,75 @@ func TestRewriteProxyRedirectLocation(t *testing.T) {
})
}
}
+
+func TestRunCaptchaServerAndWaitStopsOnContextCancel(t *testing.T) {
+ listener, err := net.Listen("tcp", "127.0.0.1:"+captchaListenPort)
+ if err != nil {
+ t.Skipf("captcha listener test requires a free localhost port: %v", err)
+ }
+ if err := listener.Close(); err != nil {
+ t.Fatalf("failed to release preflight listener: %v", err)
+ }
+
+ previousOpenBrowser := openBrowserFunc
+ openBrowserFunc = func(string) {}
+ defer func() {
+ openBrowserFunc = previousOpenBrowser
+ }()
+
+ ctx, cancel := context.WithCancel(context.Background())
+ errCh := make(chan error, 1)
+
+ go func() {
+ _, err := runCaptchaServerAndWait(
+ ctx,
+ http.NewServeMux(),
+ localCaptchaOrigin(),
+ make(chan string),
+ "test captcha server",
+ )
+ errCh <- err
+ }()
+
+ deadline := time.Now().Add(2 * time.Second)
+ for {
+ conn, dialErr := net.DialTimeout("tcp", "127.0.0.1:"+captchaListenPort, 50*time.Millisecond)
+ if dialErr == nil {
+ if err := conn.Close(); err != nil {
+ t.Fatalf("failed to close probe connection: %v", err)
+ }
+ break
+ }
+ if time.Now().After(deadline) {
+ cancel()
+ t.Fatalf("captcha server did not start listening: %v", dialErr)
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+
+ cancel()
+
+ select {
+ case err := <-errCh:
+ if !errors.Is(err, context.Canceled) {
+ t.Fatalf("runCaptchaServerAndWait() error = %v, want context canceled", err)
+ }
+ case <-time.After(3 * time.Second):
+ t.Fatal("runCaptchaServerAndWait() did not return after context cancellation")
+ }
+
+ deadline = time.Now().Add(2 * time.Second)
+ for {
+ listener, err = net.Listen("tcp", "127.0.0.1:"+captchaListenPort)
+ if err == nil {
+ if closeErr := listener.Close(); closeErr != nil {
+ t.Fatalf("failed to close verification listener: %v", closeErr)
+ }
+ return
+ }
+ if time.Now().After(deadline) {
+ t.Fatalf("captcha listener was not released after cancellation: %v", err)
+ }
+ time.Sleep(20 * time.Millisecond)
+ }
+}
diff --git a/client/slider_captcha.go b/client/slider_captcha.go
index 166a6abd..1cadd515 100644
--- a/client/slider_captcha.go
+++ b/client/slider_captcha.go
@@ -144,7 +144,7 @@ func (s *captchaNotRobotSession) requestComponentDone() error {
respObj, ok := resp["response"].(map[string]interface{})
if ok {
- if status, _ := respObj["status"].(string); status != "" && status != "OK" {
+ if status, ok := respObj["status"].(string); ok && status != "" && status != "OK" {
return fmt.Errorf("componentDone status: %s", status)
}
}
@@ -316,7 +316,9 @@ func parseCaptchaSettingsResponse(resp map[string]interface{}) (*captchaSettings
settings := &captchaSettingsResponse{
SettingsByType: make(map[string]string),
}
- settings.ShowCaptchaType, _ = respObj["show_captcha_type"].(string)
+ if showCaptchaType, ok := respObj["show_captcha_type"].(string); ok {
+ settings.ShowCaptchaType = showCaptchaType
+ }
rawSettings, ok := expandCaptchaSettings(respObj["captcha_settings"])
if !ok {
@@ -329,8 +331,8 @@ func parseCaptchaSettingsResponse(resp map[string]interface{}) (*captchaSettings
continue
}
- captchaType, _ := item["type"].(string)
- if captchaType == "" {
+ captchaType, ok := item["type"].(string)
+ if !ok || captchaType == "" {
continue
}
@@ -495,9 +497,15 @@ func parseCaptchaCheckResult(resp map[string]interface{}) (*captchaCheckResult,
}
result := &captchaCheckResult{}
- result.Status, _ = respObj["status"].(string)
- result.SuccessToken, _ = respObj["success_token"].(string)
- result.ShowCaptchaType, _ = respObj["show_captcha_type"].(string)
+ if status, ok := respObj["status"].(string); ok {
+ result.Status = status
+ }
+ if successToken, ok := respObj["success_token"].(string); ok {
+ result.SuccessToken = successToken
+ }
+ if showCaptchaType, ok := respObj["show_captcha_type"].(string); ok {
+ result.ShowCaptchaType = showCaptchaType
+ }
if result.Status == "" {
return nil, fmt.Errorf("check status missing: %v", resp)
}
@@ -511,19 +519,22 @@ func parseSliderCaptchaContentResponse(resp map[string]interface{}) (*sliderCapt
return nil, fmt.Errorf("invalid slider content response: %v", resp)
}
- status, _ := respObj["status"].(string)
- if status != "OK" {
+ status, ok := respObj["status"].(string)
+ if !ok || status != "OK" {
return nil, fmt.Errorf("slider getContent status: %s", status)
}
- extension, _ := respObj["extension"].(string)
+ extension, ok := respObj["extension"].(string)
+ if !ok {
+ return nil, fmt.Errorf("unsupported slider image format: %v", respObj["extension"])
+ }
extension = strings.ToLower(extension)
if extension != "jpeg" && extension != "jpg" {
return nil, fmt.Errorf("unsupported slider image format: %s", extension)
}
- rawImage, _ := respObj["image"].(string)
- if rawImage == "" {
+ rawImage, ok := respObj["image"].(string)
+ if !ok || rawImage == "" {
return nil, fmt.Errorf("slider image missing")
}
diff --git a/internal/cliutil/cliutil.go b/internal/cliutil/cliutil.go
new file mode 100644
index 00000000..bb5282be
--- /dev/null
+++ b/internal/cliutil/cliutil.go
@@ -0,0 +1,71 @@
+package cliutil
+
+import (
+ "errors"
+ "flag"
+ "fmt"
+ "io"
+)
+
+const ContinueExecution = -1
+
+type Builder[T any] func(program string, output io.Writer) (*flag.FlagSet, *T)
+
+type Validator[T any] func(*T) error
+
+func Parse[T any](
+ args []string,
+ program string,
+ stdout io.Writer,
+ stderr io.Writer,
+ build Builder[T],
+ validate Validator[T],
+) (T, int) {
+ var zero T
+
+ if len(args) == 0 {
+ fs, _ := build(program, stdout)
+ fs.Usage()
+ return zero, 0
+ }
+
+ output := stderr
+ if hasHelpFlag(args) {
+ output = stdout
+ }
+
+ fs, opts := build(program, output)
+ if err := fs.Parse(args); err != nil {
+ if errors.Is(err, flag.ErrHelp) {
+ return zero, 0
+ }
+ return zero, 2
+ }
+
+ if validate != nil {
+ if err := validate(opts); err != nil {
+ Fprintln(stderr, "error:", err)
+ fs.Usage()
+ return zero, 2
+ }
+ }
+
+ return *opts, ContinueExecution
+}
+
+func Fprintf(w io.Writer, format string, args ...any) {
+ _, _ = fmt.Fprintf(w, format, args...)
+}
+
+func Fprintln(w io.Writer, args ...any) {
+ _, _ = fmt.Fprintln(w, args...)
+}
+
+func hasHelpFlag(args []string) bool {
+ for _, arg := range args {
+ if arg == "-h" || arg == "-help" || arg == "--help" {
+ return true
+ }
+ }
+ return false
+}
diff --git a/server/main.go b/server/main.go
index cdf91d43..e02c3a42 100644
--- a/server/main.go
+++ b/server/main.go
@@ -2,6 +2,7 @@ package main
import (
"context"
+ "errors"
"flag"
"fmt"
"io"
@@ -9,21 +10,58 @@ import (
"net"
"os"
"os/signal"
+ "path/filepath"
"sync"
"syscall"
"time"
+ "github.com/cacggghp/vk-turn-proxy/internal/cliutil"
"github.com/cacggghp/vk-turn-proxy/tcputil"
"github.com/pion/dtls/v3"
"github.com/pion/dtls/v3/pkg/crypto/selfsign"
"github.com/xtaci/smux"
)
+type serverOptions struct {
+ listen string
+ connect string
+ vlessMode bool
+}
+
+func newServerFlagSet(program string, output io.Writer) (*flag.FlagSet, *serverOptions) {
+ fs := flag.NewFlagSet(program, flag.ContinueOnError)
+ fs.SetOutput(output)
+
+ opts := &serverOptions{}
+ fs.StringVar(&opts.listen, "listen", "0.0.0.0:56000", "listen on ip:port")
+ fs.StringVar(&opts.connect, "connect", "", "connect to ip:port")
+ fs.BoolVar(&opts.vlessMode, "vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets")
+ fs.Usage = func() {
+ cliutil.Fprintf(fs.Output(), "Usage:\n %s -connect [flags]\n\n", program)
+ cliutil.Fprintln(fs.Output(), "Examples:")
+ cliutil.Fprintf(fs.Output(), " %s -connect 127.0.0.1:51820\n", program)
+ cliutil.Fprintf(fs.Output(), " %s -listen 0.0.0.0:56000 -connect 127.0.0.1:51820 -vless\n\n", program)
+ cliutil.Fprintln(fs.Output(), "Flags:")
+ fs.PrintDefaults()
+ }
+
+ return fs, opts
+}
+
+func parseServerOptions(args []string, program string, stdout, stderr io.Writer) (serverOptions, int) {
+ return cliutil.Parse(args, program, stdout, stderr, newServerFlagSet, func(opts *serverOptions) error {
+ if opts.connect == "" {
+ return fmt.Errorf("-connect is required")
+ }
+ return nil
+ })
+}
+
func main() {
- listen := flag.String("listen", "0.0.0.0:56000", "listen on ip:port")
- connect := flag.String("connect", "", "connect to ip:port")
- vlessMode := flag.Bool("vless", false, "VLESS mode: forward TCP connections (for VLESS) instead of UDP packets")
- flag.Parse()
+ opts, exitCode := parseServerOptions(os.Args[1:], filepath.Base(os.Args[0]), os.Stdout, os.Stderr)
+ if exitCode != cliutil.ContinueExecution {
+ os.Exit(exitCode)
+ }
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
@@ -37,23 +75,16 @@ func main() {
log.Fatalf("Exit...\n")
}()
- addr, err := net.ResolveUDPAddr("udp", *listen)
+ addr, err := net.ResolveUDPAddr("udp", opts.listen)
if err != nil {
panic(err)
}
- if len(*connect) == 0 {
- log.Panicf("server address is required")
- }
// Generate a certificate and private key to secure the connection
certificate, genErr := selfsign.GenerateSelfSigned()
if genErr != nil {
panic(genErr)
}
- //
- // Everything below is the pion-DTLS API! Thanks for using it ❤️.
- //
-
// Connect to a DTLS server
listener, err := dtls.ListenWithOptions(
"udp",
@@ -72,7 +103,7 @@ func main() {
}
})
- fmt.Println("Listening")
+ cliutil.Fprintln(os.Stdout, "Listening")
wg1 := sync.WaitGroup{}
for {
@@ -91,7 +122,11 @@ func main() {
wg1.Add(1)
go func(conn net.Conn) {
defer wg1.Done()
+ closeConn := true
defer func() {
+ if !closeConn {
+ return
+ }
if closeErr := conn.Close(); closeErr != nil {
log.Printf("failed to close incoming connection: %s", closeErr)
}
@@ -114,10 +149,11 @@ func main() {
}
log.Println("Handshake done")
- if *vlessMode {
- handleVLESSConnection(ctx, dtlsConn, *connect)
+ if opts.vlessMode {
+ closeConn = false
+ handleVLESSConnection(ctx, dtlsConn, opts.connect)
} else {
- handleUDPConnection(ctx, conn, *connect)
+ handleUDPConnection(ctx, conn, opts.connect)
}
log.Printf("Connection closed: %s\n", conn.RemoteAddr())
@@ -149,83 +185,24 @@ func handleUDPConnection(ctx context.Context, conn net.Conn, connectAddr string)
log.Printf("failed to set outgoing deadline: %s", err)
}
})
- go func() {
- defer wg.Done()
- defer cancel2()
- buf := make([]byte, 1600)
- for {
- select {
- case <-ctx2.Done():
- return
- default:
- }
- if err1 := conn.SetReadDeadline(time.Now().Add(time.Minute * 30)); err1 != nil {
- log.Printf("Failed: %s", err1)
- return
- }
- n, err1 := conn.Read(buf)
- if err1 != nil {
- log.Printf("Failed: %s", err1)
- return
- }
-
- if err1 = serverConn.SetWriteDeadline(time.Now().Add(time.Minute * 30)); err1 != nil {
- log.Printf("Failed: %s", err1)
- return
- }
- _, err1 = serverConn.Write(buf[:n])
- if err1 != nil {
- log.Printf("Failed: %s", err1)
- return
- }
- }
- }()
- go func() {
- defer wg.Done()
- defer cancel2()
- buf := make([]byte, 1600)
- for {
- select {
- case <-ctx2.Done():
- return
- default:
- }
- if err1 := serverConn.SetReadDeadline(time.Now().Add(time.Minute * 30)); err1 != nil {
- log.Printf("Failed: %s", err1)
- return
- }
- n, err1 := serverConn.Read(buf)
- if err1 != nil {
- log.Printf("Failed: %s", err1)
- return
- }
-
- if err1 = conn.SetWriteDeadline(time.Now().Add(time.Minute * 30)); err1 != nil {
- log.Printf("Failed: %s", err1)
- return
- }
- _, err1 = conn.Write(buf[:n])
- if err1 != nil {
- log.Printf("Failed: %s", err1)
- return
- }
- }
- }()
+ startPacketForwarder(ctx2, &wg, cancel2, conn, serverConn)
+ startPacketForwarder(ctx2, &wg, cancel2, serverConn, conn)
wg.Wait()
}
// handleVLESSConnection creates a KCP+smux session over DTLS and forwards
// each smux stream as a TCP connection to the backend (Xray/VLESS).
+// It takes ownership of dtlsConn and closes it through the KCP cleanup path.
func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr string) {
// 1. Create KCP session over DTLS
- kcpSess, err := tcputil.NewKCPOverDTLS(dtlsConn, true)
+ kcpSess, cleanupKCP, err := tcputil.NewKCPOverDTLS(dtlsConn, true)
if err != nil {
log.Printf("KCP session error: %s", err)
return
}
defer func() {
- if err := kcpSess.Close(); err != nil {
- log.Printf("failed to close KCP session: %v", err)
+ if err := cleanupKCP(); err != nil {
+ log.Printf("failed to close KCP-over-DTLS transport: %v", err)
}
}()
log.Printf("KCP session established (server)")
@@ -261,7 +238,7 @@ func handleVLESSConnection(ctx context.Context, dtlsConn net.Conn, connectAddr s
defer wg.Done()
defer func() {
- if err := s.Close(); err != nil && err != smux.ErrGoAway {
+ if err := s.Close(); err != nil && !errors.Is(err, smux.ErrGoAway) {
log.Printf("failed to close smux stream: %v", err)
}
}()
@@ -301,24 +278,58 @@ func pipeConn(ctx context.Context, c1, c2 net.Conn) {
var wg sync.WaitGroup
wg.Add(2)
+ startStreamCopy(&wg, cancel, c1, c2, "pipeConn: c1<-c2")
+ startStreamCopy(&wg, cancel, c2, c1, "pipeConn: c2<-c1")
+
+ wg.Wait()
+ // Reset deadlines
+ _ = c1.SetDeadline(time.Time{})
+ _ = c2.SetDeadline(time.Time{})
+}
+
+func startPacketForwarder(ctx context.Context, wg *sync.WaitGroup, cancel context.CancelFunc, src, dst net.Conn) {
go func() {
defer wg.Done()
- if _, err := io.Copy(c1, c2); err != nil {
- log.Printf("pipeConn: c1<-c2 copy error: %v", err)
+ defer cancel()
+
+ buf := make([]byte, 1600)
+ for {
+ select {
+ case <-ctx.Done():
+ return
+ default:
+ }
+
+ if err := src.SetReadDeadline(time.Now().Add(30 * time.Minute)); err != nil {
+ log.Printf("Failed: %s", err)
+ return
+ }
+ n, err := src.Read(buf)
+ if err != nil {
+ log.Printf("Failed: %s", err)
+ return
+ }
+
+ if err = dst.SetWriteDeadline(time.Now().Add(30 * time.Minute)); err != nil {
+ log.Printf("Failed: %s", err)
+ return
+ }
+ if _, err = dst.Write(buf[:n]); err != nil {
+ log.Printf("Failed: %s", err)
+ return
+ }
}
}()
+}
+func startStreamCopy(wg *sync.WaitGroup, cancel context.CancelFunc, dst, src net.Conn, label string) {
go func() {
defer wg.Done()
- if _, err := io.Copy(c2, c1); err != nil {
- log.Printf("pipeConn: c2<-c1 copy error: %v", err)
+ defer cancel()
+
+ if _, err := io.Copy(dst, src); err != nil {
+ log.Printf("%s copy error: %v", label, err)
}
}()
-
- wg.Wait()
-
- // Reset deadlines
- _ = c1.SetDeadline(time.Time{})
- _ = c2.SetDeadline(time.Time{})
}
diff --git a/server/main_test.go b/server/main_test.go
new file mode 100644
index 00000000..01e87074
--- /dev/null
+++ b/server/main_test.go
@@ -0,0 +1,87 @@
+package main
+
+import (
+ "bytes"
+ "strings"
+ "testing"
+
+ "github.com/cacggghp/vk-turn-proxy/internal/cliutil"
+)
+
+func TestParseServerOptionsShowsUsageWithoutArgs(t *testing.T) {
+ t.Parallel()
+
+ var stdout bytes.Buffer
+ var stderr bytes.Buffer
+
+ _, exitCode := parseServerOptions(nil, "server", &stdout, &stderr)
+ if exitCode != 0 {
+ t.Fatalf("parseServerOptions() exitCode = %d, want 0", exitCode)
+ }
+ if stderr.Len() != 0 {
+ t.Fatalf("expected no stderr output, got %q", stderr.String())
+ }
+ if got := stdout.String(); !strings.Contains(got, "Usage:\n server -connect [flags]") {
+ t.Fatalf("usage output missing server help text: %q", got)
+ }
+}
+
+func TestParseServerOptionsShowsHelpFlagUsage(t *testing.T) {
+ t.Parallel()
+
+ var stdout bytes.Buffer
+ var stderr bytes.Buffer
+
+ _, exitCode := parseServerOptions([]string{"-help"}, "server", &stdout, &stderr)
+ if exitCode != 0 {
+ t.Fatalf("parseServerOptions() exitCode = %d, want 0", exitCode)
+ }
+ if stderr.Len() != 0 {
+ t.Fatalf("expected no stderr output, got %q", stderr.String())
+ }
+ if got := stdout.String(); !strings.Contains(got, "Examples:") {
+ t.Fatalf("expected help examples in output, got %q", got)
+ }
+}
+
+func TestParseServerOptionsRequiresConnect(t *testing.T) {
+ t.Parallel()
+
+ var stdout bytes.Buffer
+ var stderr bytes.Buffer
+
+ _, exitCode := parseServerOptions([]string{"-listen", "0.0.0.0:56000"}, "server", &stdout, &stderr)
+ if exitCode != 2 {
+ t.Fatalf("parseServerOptions() exitCode = %d, want 2", exitCode)
+ }
+ if stdout.Len() != 0 {
+ t.Fatalf("expected no stdout output, got %q", stdout.String())
+ }
+ if got := stderr.String(); !strings.Contains(got, "error: -connect is required") {
+ t.Fatalf("expected missing connect error, got %q", got)
+ }
+}
+
+func TestParseServerOptionsParsesValidArgs(t *testing.T) {
+ t.Parallel()
+
+ var stdout bytes.Buffer
+ var stderr bytes.Buffer
+
+ opts, exitCode := parseServerOptions([]string{"-connect", "127.0.0.1:51820", "-listen", "0.0.0.0:56000", "-vless"}, "server", &stdout, &stderr)
+ if exitCode != cliutil.ContinueExecution {
+ t.Fatalf("parseServerOptions() exitCode = %d, want %d", exitCode, cliutil.ContinueExecution)
+ }
+ if stderr.Len() != 0 {
+ t.Fatalf("expected no stderr output, got %q", stderr.String())
+ }
+ if opts.connect != "127.0.0.1:51820" {
+ t.Fatalf("connect = %q, want 127.0.0.1:51820", opts.connect)
+ }
+ if opts.listen != "0.0.0.0:56000" {
+ t.Fatalf("listen = %q, want 0.0.0.0:56000", opts.listen)
+ }
+ if !opts.vlessMode {
+ t.Fatal("vlessMode = false, want true")
+ }
+}
diff --git a/tcputil/tcputil.go b/tcputil/tcputil.go
index 4e37ea33..40d0684d 100644
--- a/tcputil/tcputil.go
+++ b/tcputil/tcputil.go
@@ -1,83 +1,117 @@
package tcputil
import (
+ "errors"
+ "io"
"net"
+ "sync"
"time"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
)
-// DtlsPacketConn wraps a net.Conn (DTLS) as a net.PacketConn for KCP.
-// Each DTLS Read/Write preserves message boundaries (datagram semantics).
-type DtlsPacketConn struct {
+// dtlsPacketConn adapts a DTLS net.Conn to net.PacketConn for KCP.
+// It does not own the underlying transport; callers must close dtlsConn
+// through the cleanup function returned by NewKCPOverDTLS.
+type dtlsPacketConn struct {
conn net.Conn
}
-func NewDtlsPacketConn(conn net.Conn) *DtlsPacketConn {
- return &DtlsPacketConn{conn: conn}
-}
-
-func (d *DtlsPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
+func (d *dtlsPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
n, err := d.conn.Read(b)
return n, d.conn.RemoteAddr(), err
}
-func (d *DtlsPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) {
+func (d *dtlsPacketConn) WriteTo(b []byte, _ net.Addr) (int, error) {
return d.conn.Write(b)
}
-func (d *DtlsPacketConn) Close() error {
- return d.conn.Close()
+func (d *dtlsPacketConn) Close() error {
+ return nil
}
-func (d *DtlsPacketConn) LocalAddr() net.Addr {
+func (d *dtlsPacketConn) LocalAddr() net.Addr {
return d.conn.LocalAddr()
}
-func (d *DtlsPacketConn) SetDeadline(t time.Time) error {
+func (d *dtlsPacketConn) SetDeadline(t time.Time) error {
return d.conn.SetDeadline(t)
}
-func (d *DtlsPacketConn) SetReadDeadline(t time.Time) error {
+func (d *dtlsPacketConn) SetReadDeadline(t time.Time) error {
return d.conn.SetReadDeadline(t)
}
-func (d *DtlsPacketConn) SetWriteDeadline(t time.Time) error {
+func (d *dtlsPacketConn) SetWriteDeadline(t time.Time) error {
return d.conn.SetWriteDeadline(t)
}
-// NewKCPOverDTLS creates a KCP session over a DTLS connection.
+// NewKCPOverDTLS creates a KCP session over a DTLS connection and returns
+// an idempotent cleanup function for the entire KCP-over-DTLS transport.
+// After a successful call, the caller should use the returned cleanup instead
+// of closing dtlsConn directly.
+//
// isServer: true for server-side (listener), false for client-side (dialer).
-func NewKCPOverDTLS(dtlsConn net.Conn, isServer bool) (*kcp.UDPSession, error) {
- pc := NewDtlsPacketConn(dtlsConn)
+func NewKCPOverDTLS(dtlsConn net.Conn, isServer bool) (_ *kcp.UDPSession, cleanup func() error, err error) {
+ var (
+ listener *kcp.Listener
+ sess *kcp.UDPSession
+ closeErr error
+ closeOnce sync.Once
+ )
+ transportCleanup := func() error {
+ closeOnce.Do(func() {
+ var errs []error
+ if sess != nil {
+ if err := sess.Close(); err != nil && !errors.Is(err, io.ErrClosedPipe) {
+ errs = append(errs, err)
+ }
+ }
+ if listener != nil {
+ if err := listener.Close(); err != nil && !errors.Is(err, io.ErrClosedPipe) {
+ errs = append(errs, err)
+ }
+ }
+ if err := dtlsConn.Close(); err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, io.ErrClosedPipe) {
+ errs = append(errs, err)
+ }
+ closeErr = errors.Join(errs...)
+ })
+ return closeErr
+ }
+ defer func() {
+ if err == nil {
+ return
+ }
+ if cleanupErr := transportCleanup(); cleanupErr != nil {
+ err = errors.Join(err, cleanupErr)
+ }
+ }()
block, err := kcp.NewNoneBlockCrypt(nil) // DTLS already encrypts
if err != nil {
- return nil, err
+ return nil, nil, err
}
- var sess *kcp.UDPSession
-
if isServer {
// Server: listen on the PacketConn and accept one session
- var listener *kcp.Listener
- listener, err = kcp.ServeConn(block, 0, 0, pc)
+ listener, err = kcp.ServeConn(block, 0, 0, &dtlsPacketConn{conn: dtlsConn})
if err != nil {
- return nil, err
+ return nil, nil, err
}
if err = listener.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
- return nil, err
+ return nil, nil, err
}
sess, err = listener.AcceptKCP()
if err != nil {
- return nil, err
+ return nil, nil, err
}
} else {
// Client: dial through the PacketConn
- sess, err = kcp.NewConn2(dtlsConn.RemoteAddr(), block, 0, 0, pc)
+ sess, err = kcp.NewConn2(dtlsConn.RemoteAddr(), block, 0, 0, &dtlsPacketConn{conn: dtlsConn})
if err != nil {
- return nil, err
+ return nil, nil, err
}
}
@@ -89,7 +123,7 @@ func NewKCPOverDTLS(dtlsConn net.Conn, isServer bool) (*kcp.UDPSession, error) {
sess.SetMtu(1200) // conservative MTU to fit inside DTLS+TURN
sess.SetACKNoDelay(true)
- return sess, nil
+ return sess, transportCleanup, nil
}
// DefaultSmuxConfig returns smux config tuned for TURN tunnel.