Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions app/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ import (
"sync"
"time"

router "github.com/xtls/xray-core/app/router"
"github.com/xtls/xray-core/common"
"github.com/xtls/xray-core/common/errors"
"github.com/xtls/xray-core/common/net"
"github.com/xtls/xray-core/common/platform/filesystem"
"github.com/xtls/xray-core/common/session"
"github.com/xtls/xray-core/common/strmatcher"
"github.com/xtls/xray-core/features/dns"
"google.golang.org/protobuf/proto"
)

// DNS is a DNS rely server.
Expand Down Expand Up @@ -97,6 +100,25 @@ func New(ctx context.Context, config *Config) (*DNS, error) {
}

for _, ns := range config.NameServer {
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
err := parseDomains(ns)
if err != nil {
return nil, errors.New("failed to parse dns domain rules: ").Base(err)
}

expectedGeoip, err := router.GetGeoIPList(ns.ExpectedGeoip)
if err != nil {
return nil, errors.New("failed to parse dns expectIPs rules: ").Base(err)
}
ns.ExpectedGeoip = expectedGeoip

unexpectedGeoip, err := router.GetGeoIPList(ns.UnexpectedGeoip)
if err != nil {
return nil, errors.New("failed to parse dns unexpectedGeoip rules: ").Base(err)
}
ns.UnexpectedGeoip = unexpectedGeoip

}
domainRuleCount += len(ns.PrioritizedDomain)
}

Expand Down Expand Up @@ -580,3 +602,76 @@ func detectGUIPlatform() bool {
}
return false
}

func parseDomains(ns *NameServer) error {
pureDomains := []*router.Domain{}

// convert to pure domain
for _, pd := range ns.PrioritizedDomain {
pureDomains = append(pureDomains, &router.Domain{
Type: router.Domain_Type(pd.Type),
Value: pd.Domain,
})
}

domainList := []*router.Domain{}
for _, domain := range pureDomains {
val := strings.Split(domain.Value, "_")
if len(val) >= 2 {

fileName := val[0]
code := val[1]

bs, err := filesystem.ReadAsset(fileName)
if err != nil {
return errors.New("failed to load file: ", fileName).Base(err)
}
bs = filesystem.Find(bs, []byte(code))
var geosite router.GeoSite

if err := proto.Unmarshal(bs, &geosite); err != nil {
return errors.New("failed Unmarshal :").Base(err)
}

// parse attr
if len(val) == 3 {
siteWithAttr := strings.Split(val[2], ",")
attrs := router.ParseAttrs(siteWithAttr)
if !attrs.IsEmpty() {
filteredDomains := make([]*router.Domain, 0, len(pureDomains))
for _, domain := range geosite.Domain {
if attrs.Match(domain) {
filteredDomains = append(filteredDomains, domain)
}
}
geosite.Domain = filteredDomains
}

}

domainList = append(domainList, geosite.Domain...)

// update ns.OriginalRules Size
ruleTag := strings.Join(val, ":")
for i, oRule := range ns.OriginalRules {
if oRule.Rule == strings.ToLower(ruleTag) {
ns.OriginalRules[i].Size = uint32(len(geosite.Domain))
}
}

} else {
domainList = append(domainList, domain)
}
}

// convert back to NameServer_PriorityDomain
ns.PrioritizedDomain = []*NameServer_PriorityDomain{}
for _, pd := range domainList {
ns.PrioritizedDomain = append(ns.PrioritizedDomain, &NameServer_PriorityDomain{
Type: ToDomainMatchingType(pd.Type),
Domain: pd.Value,
})
}

return nil
}
8 changes: 4 additions & 4 deletions app/dns/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -541,7 +541,7 @@ func TestIPMatch(t *testing.T) {
},
ExpectedGeoip: []*router.GeoIP{
{
CountryCode: "local",
// local
Cidr: []*router.CIDR{
{
// inner ip, will not match
Expand All @@ -565,7 +565,7 @@ func TestIPMatch(t *testing.T) {
},
ExpectedGeoip: []*router.GeoIP{
{
CountryCode: "test",
// test
Cidr: []*router.CIDR{
{
Ip: []byte{8, 8, 8, 8},
Expand All @@ -574,7 +574,7 @@ func TestIPMatch(t *testing.T) {
},
},
{
CountryCode: "test",
// test
Cidr: []*router.CIDR{
{
Ip: []byte{8, 8, 8, 4},
Expand Down Expand Up @@ -669,7 +669,7 @@ func TestLocalDomain(t *testing.T) {
},
ExpectedGeoip: []*router.GeoIP{
{ // Will match localhost, localhost-a and localhost-b,
CountryCode: "local",
// local
Cidr: []*router.CIDR{
{Ip: []byte{127, 0, 0, 2}, Prefix: 32},
{Ip: []byte{127, 0, 0, 3}, Prefix: 32},
Expand Down
15 changes: 15 additions & 0 deletions app/dns/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,18 @@ func ResolveIpOptionOverride(queryStrategy QueryStrategy, ipOption dns.IPOption)
return ipOption
}
}

func ToDomainMatchingType(t router.Domain_Type) DomainMatchingType {
switch t {
case router.Domain_Domain:
return DomainMatchingType_Subdomain
case router.Domain_Full:
return DomainMatchingType_Full
case router.Domain_Plain:
return DomainMatchingType_Keyword
case router.Domain_Regex:
return DomainMatchingType_Regex
default:
panic("unknown domain type")
}
}
4 changes: 2 additions & 2 deletions app/router/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) {
geoip := rr.Geoip
if runtime.GOOS != "windows" && runtime.GOOS != "wasm" {
var err error
geoip, err = getGeoIPList(rr.Geoip)
geoip, err = GetGeoIPList(rr.Geoip)
if err != nil {
return nil, errors.New("failed to build geoip from mmap").Base(err)
}
Expand Down Expand Up @@ -188,7 +188,7 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch
}
}

func getGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) {
geoipList := []*GeoIP{}
for _, ip := range ips {
if ip.CountryCode != "" {
Expand Down
17 changes: 1 addition & 16 deletions infra/conf/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,21 +80,6 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error {
return errors.New("failed to parse name server: ", string(data))
}

func toDomainMatchingType(t router.Domain_Type) dns.DomainMatchingType {
switch t {
case router.Domain_Domain:
return dns.DomainMatchingType_Subdomain
case router.Domain_Full:
return dns.DomainMatchingType_Full
case router.Domain_Plain:
return dns.DomainMatchingType_Keyword
case router.Domain_Regex:
return dns.DomainMatchingType_Regex
default:
panic("unknown domain type")
}
}

func (c *NameServerConfig) Build() (*dns.NameServer, error) {
if c.Address == nil {
return nil, errors.New("NameServer address is not specified.")
Expand All @@ -111,7 +96,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) {

for _, pd := range parsedDomain {
domains = append(domains, &dns.NameServer_PriorityDomain{
Type: toDomainMatchingType(pd.Type),
Type: dns.ToDomainMatchingType(pd.Type),
Domain: pd.Value,
})
}
Expand Down