Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 26 additions & 55 deletions app/router/condition_geoip_test.go
Original file line number Diff line number Diff line change
@@ -1,40 +1,17 @@
package router_test

import (
"fmt"
"os"
"path/filepath"
"runtime"
"testing"

"github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform"
"github.com/xtls/xray-core/common/platform/filesystem"
"google.golang.org/protobuf/proto"
"github.com/xtls/xray-core/infra/conf"
)

func getAssetPath(file string) (string, error) {
path := platform.GetAssetLocation(file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
path := filepath.Join("..", "..", "resources", file)
_, err := os.Stat(path)
if os.IsNotExist(err) {
return "", fmt.Errorf("can't find %s in standard asset locations or {project_root}/resources", file)
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}
return path, nil
}
if err != nil {
return "", fmt.Errorf("can't stat %s: %v", path, err)
}

return path, nil
}

func TestGeoIPMatcher(t *testing.T) {
cidrList := []*router.CIDR{
{Ip: []byte{0, 0, 0, 0}, Prefix: 8},
Expand Down Expand Up @@ -182,12 +159,11 @@ func TestGeoIPReverseMatcher(t *testing.T) {
}

func TestGeoIPMatcher4CN(t *testing.T) {
ips, err := loadGeoIP("CN")
geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
common.Must(err)

matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
common.Must(err)

if matcher.Match([]byte{8, 8, 8, 8}) {
Expand All @@ -196,50 +172,46 @@ func TestGeoIPMatcher4CN(t *testing.T) {
}

func TestGeoIPMatcher6US(t *testing.T) {
ips, err := loadGeoIP("US")
geo := "geoip:us"
geoip, err := loadGeoIP(geo)
common.Must(err)

matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
common.Must(err)

if !matcher.Match(net.ParseAddress("2001:4860:4860::8888").IP()) {
t.Error("expect US geoip contain 2001:4860:4860::8888, but actually not")
}
}

func loadGeoIP(country string) ([]*router.CIDR, error) {
path, err := getAssetPath("geoip.dat")
if err != nil {
return nil, err
}
geoipBytes, err := filesystem.ReadFile(path)
func loadGeoIP(geo string) (*router.GeoIP, error) {
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))

geoip, err := conf.ToCidrList([]string{geo})
if err != nil {
return nil, err
}

var geoipList router.GeoIPList
if err := proto.Unmarshal(geoipBytes, &geoipList); err != nil {
return nil, err
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
geoip, err = router.GetGeoIPList(geoip)
if err != nil {
return nil, err
}
}

for _, geoip := range geoipList.Entry {
if geoip.CountryCode == country {
return geoip.Cidr, nil
}
if len(geoip) == 0 {
panic("country not found: " + geo)
}

panic("country not found: " + country)
return geoip[0], nil
}

func BenchmarkGeoIPMatcher4CN(b *testing.B) {
ips, err := loadGeoIP("CN")
geo := "geoip:cn"
geoip, err := loadGeoIP(geo)
common.Must(err)

matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
common.Must(err)

b.ResetTimer()
Expand All @@ -250,12 +222,11 @@ func BenchmarkGeoIPMatcher4CN(b *testing.B) {
}

func BenchmarkGeoIPMatcher6US(b *testing.B) {
ips, err := loadGeoIP("US")
geo := "geoip:us"
geoip, err := loadGeoIP(geo)
common.Must(err)

matcher, err := router.BuildOptimizedGeoIPMatcher(&router.GeoIP{
Cidr: ips,
})
matcher, err := router.BuildOptimizedGeoIPMatcher(geoip)
common.Must(err)

b.ResetTimer()
Expand Down
93 changes: 65 additions & 28 deletions app/router/condition_test.go
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
package router_test

import (
"os"
"path/filepath"
"runtime"
"strconv"
"testing"

"github.com/xtls/xray-core/app/router"
. "github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/protocol"
"github.com/xtls/xray-core/common/protocol/http"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/features/routing"
routing_session "github.com/xtls/xray-core/features/routing/session"
"google.golang.org/protobuf/proto"
"github.com/xtls/xray-core/infra/conf"
)

func withBackground() routing.Context {
Expand Down Expand Up @@ -300,32 +302,25 @@ func TestRoutingRule(t *testing.T) {
}
}

func loadGeoSite(country string) ([]*Domain, error) {
path, err := getAssetPath("geosite.dat")
if err != nil {
return nil, err
}
geositeBytes, err := filesystem.ReadFile(path)
if err != nil {
return nil, err
}
func loadGeoSiteDomains(geo string) ([]*Domain, error) {
os.Setenv("XRAY_LOCATION_ASSET", filepath.Join("..", "..", "resources"))

var geositeList GeoSiteList
if err := proto.Unmarshal(geositeBytes, &geositeList); err != nil {
domains, err := conf.ParseDomainRule(geo)
if err != nil {
return nil, err
}

for _, site := range geositeList.Entry {
if site.CountryCode == country {
return site.Domain, nil
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
domains, err = router.GetDomainList(domains)
if err != nil {
return nil, err
}
}

return nil, errors.New("country not found: " + country)
return domains, nil
}

func TestChinaSites(t *testing.T) {
domains, err := loadGeoSite("CN")
domains, err := loadGeoSiteDomains("geosite:cn")
common.Must(err)

acMatcher, err := NewMphMatcherGroup(domains)
Expand Down Expand Up @@ -366,8 +361,50 @@ func TestChinaSites(t *testing.T) {
}
}

func TestChinaSitesWithAttrs(t *testing.T) {
domains, err := loadGeoSiteDomains("geosite:google@cn")
common.Must(err)

acMatcher, err := NewMphMatcherGroup(domains)
common.Must(err)

type TestCase struct {
Domain string
Output bool
}
testCases := []TestCase{
{
Domain: "google.cn",
Output: true,
},
{
Domain: "recaptcha.net",
Output: true,
},
{
Domain: "164.com",
Output: false,
},
{
Domain: "164.com",
Output: false,
},
}

for i := 0; i < 1024; i++ {
testCases = append(testCases, TestCase{Domain: strconv.Itoa(i) + ".not-exists.com", Output: false})
}

for _, testCase := range testCases {
r := acMatcher.ApplyDomain(testCase.Domain)
if r != testCase.Output {
t.Error("ACDomainMatcher expected output ", testCase.Output, " for domain ", testCase.Domain, " but got ", r)
}
}
}

func BenchmarkMphDomainMatcher(b *testing.B) {
domains, err := loadGeoSite("CN")
domains, err := loadGeoSiteDomains("geosite:cn")
common.Must(err)

matcher, err := NewMphMatcherGroup(domains)
Expand Down Expand Up @@ -412,11 +449,11 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
var geoips []*GeoIP

{
ips, err := loadGeoIP("CN")
ips, err := loadGeoIP("geoip:cn")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "CN",
Cidr: ips,
Cidr: ips.Cidr,
})
}

Expand All @@ -425,25 +462,25 @@ func BenchmarkMultiGeoIPMatcher(b *testing.B) {
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "JP",
Cidr: ips,
Cidr: ips.Cidr,
})
}

{
ips, err := loadGeoIP("CA")
ips, err := loadGeoIP("geoip:ca")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "CA",
Cidr: ips,
Cidr: ips.Cidr,
})
}

{
ips, err := loadGeoIP("US")
ips, err := loadGeoIP("geoip:us")
common.Must(err)
geoips = append(geoips, &GeoIP{
CountryCode: "US",
Cidr: ips,
Cidr: ips.Cidr,
})
}

Expand Down
6 changes: 3 additions & 3 deletions app/router/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
domains := rr.Domain
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
var err error
domains, err = getDomainList(rr.Domain)
domains, err = GetDomainList(rr.Domain)
if err != nil {
return nil, errors.New("failed to build domains from mmap").Base(err)
}
Expand All @@ -122,7 +122,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
if err != nil {
return nil, errors.New("failed to build domain condition with MphDomainMatcher").Base(err)
}
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(rr.Domain), " domain rule(s)")
errors.LogDebug(context.Background(), "MphDomainMatcher is enabled for ", len(domains), " domain rule(s)")
conds.Add(matcher)
}

Expand Down Expand Up @@ -218,7 +218,7 @@ func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {

}

func getDomainList(domains []*Domain) ([]*Domain, error) {
func GetDomainList(domains []*Domain) ([]*Domain, error) {
domainList := []*Domain{}
for _, domain := range domains {
val := strings.Split(domain.Value, "_")
Expand Down
5 changes: 4 additions & 1 deletion common/platform/windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

package platform

import "path/filepath"
import (
"path/filepath"
)

func LineSeparator() string {
return "\r\n"
Expand All @@ -12,6 +14,7 @@ func LineSeparator() string {
// GetAssetLocation searches for `file` in the env dir and the executable dir
func GetAssetLocation(file string) string {
assetPath := NewEnvFlag(AssetLocation).GetValue(getExecutableDir)

return filepath.Join(assetPath, file)
}

Expand Down
2 changes: 1 addition & 1 deletion infra/conf/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) {
var originalRules []*dns.NameServer_OriginalRule

for _, rule := range c.Domains {
parsedDomain, err := parseDomainRule(rule)
parsedDomain, err := ParseDomainRule(rule)
if err != nil {
return nil, errors.New("invalid domain rule: ", rule).Base(err)
}
Expand Down
6 changes: 3 additions & 3 deletions infra/conf/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ func loadGeositeWithAttr(file string, siteWithAttr string) ([]*router.Domain, er
return filteredDomains, nil
}

func parseDomainRule(domain string) ([]*router.Domain, error) {
func ParseDomainRule(domain string) ([]*router.Domain, error) {
if strings.HasPrefix(domain, "geosite:") {
country := strings.ToUpper(domain[8:])
domains, err := loadGeositeWithAttr("geosite.dat", country)
Expand Down Expand Up @@ -489,7 +489,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {

if rawFieldRule.Domain != nil {
for _, domain := range *rawFieldRule.Domain {
rules, err := parseDomainRule(domain)
rules, err := ParseDomainRule(domain)
if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
}
Expand All @@ -499,7 +499,7 @@ func parseFieldRule(msg json.RawMessage) (*router.RoutingRule, error) {

if rawFieldRule.Domains != nil {
for _, domain := range *rawFieldRule.Domains {
rules, err := parseDomainRule(domain)
rules, err := ParseDomainRule(domain)
if err != nil {
return nil, errors.New("failed to parse domain rule: ", domain).Base(err)
}
Expand Down