Skip to content

Commit 2cd727a

Browse files
committed
Support routing
1 parent e42cd0a commit 2cd727a

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

config.go

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
package awgproxy
22

33
import (
4+
"bufio"
45
"encoding/base64"
56
"encoding/hex"
67
"errors"
8+
"io"
9+
"log"
710
"net"
811
"os"
912
"strings"
1013

1114
"github.com/go-ini/ini"
1215

16+
"net/http"
1317
"net/netip"
1418
)
1519

@@ -68,6 +72,7 @@ type HTTPConfig struct {
6872
Password string
6973
CertFile string
7074
KeyFile string
75+
RouteHosts map[string]bool
7176
}
7277

7378
type Configuration struct {
@@ -531,6 +536,43 @@ func parseHTTPConfig(section *ini.Section) (RoutineSpawner, error) {
531536
keyFile, _ := parseString(section, "KeyFile")
532537
config.KeyFile = keyFile
533538

539+
config.RouteHosts = make(map[string]bool)
540+
541+
domainsUrl, err := parseString(section, "DomainsUrl")
542+
if err == nil && domainsUrl != "" {
543+
resp, err := http.Get(domainsUrl)
544+
if err != nil {
545+
return nil, err
546+
}
547+
548+
rd := bufio.NewReader(resp.Body)
549+
for {
550+
str, err := rd.ReadString('\n')
551+
if err == io.EOF {
552+
break
553+
} else if err != nil {
554+
log.Printf("erra %s", err)
555+
return nil, err
556+
}
557+
558+
config.RouteHosts[strings.Trim(str, " \n\r")] = true
559+
}
560+
}
561+
562+
includeStr, err := parseString(section, "RouteInclude")
563+
if err == nil {
564+
for _, str := range strings.Split(includeStr, ",") {
565+
config.RouteHosts[strings.Trim(str, " \n\r")] = true
566+
}
567+
}
568+
569+
excludeStr, err := parseString(section, "RouteExclude")
570+
if err == nil {
571+
for _, str := range strings.Split(excludeStr, ",") {
572+
delete(config.RouteHosts, strings.Trim(str, " \n\r"))
573+
}
574+
}
575+
534576
return config, nil
535577
}
536578

http.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"log"
1111
"net"
1212
"net/http"
13+
"regexp"
1314
"strings"
1415
)
1516

@@ -23,6 +24,25 @@ type HTTPServer struct {
2324

2425
authRequired bool
2526
tlsRequired bool
27+
routeRegex *regexp.Regexp
28+
}
29+
30+
func (s *HTTPServer) needRoute(req *http.Request) bool {
31+
hostname := req.URL.Hostname()
32+
if s.routeRegex == nil {
33+
s.routeRegex = regexp.MustCompile(`^([^\.]*)\.`)
34+
}
35+
36+
for strings.Count(hostname, ".") > 0 {
37+
if s.config.RouteHosts[hostname] {
38+
log.Printf("route %s", hostname)
39+
return true
40+
}
41+
42+
hostname = s.routeRegex.ReplaceAllString(hostname, "")
43+
}
44+
45+
return false
2646
}
2747

2848
func (s *HTTPServer) authenticate(req *http.Request) (int, error) {
@@ -57,7 +77,11 @@ func (s *HTTPServer) handleConn(req *http.Request, conn net.Conn) (peer net.Conn
5777
addr = net.JoinHostPort(addr, port)
5878
}
5979

60-
peer, err = s.dial("tcp", addr)
80+
if s.needRoute(req) {
81+
peer, err = s.dial("tcp", addr)
82+
} else {
83+
peer, err = net.Dial("tcp", addr)
84+
}
6185
if err != nil {
6286
return peer, fmt.Errorf("tun tcp dial failed: %w", err)
6387
}
@@ -78,7 +102,12 @@ func (s *HTTPServer) handle(req *http.Request) (peer net.Conn, err error) {
78102
addr = net.JoinHostPort(addr, port)
79103
}
80104

81-
peer, err = s.dial("tcp", addr)
105+
if s.needRoute(req) {
106+
peer, err = s.dial("tcp", addr)
107+
} else {
108+
peer, err = net.Dial("tcp", addr)
109+
}
110+
82111
if err != nil {
83112
return peer, fmt.Errorf("tun tcp dial failed: %w", err)
84113
}

0 commit comments

Comments
 (0)