diff --git a/app/dns/dns.go b/app/dns/dns.go index 603640f1549f..5d6154f9fcc2 100644 --- a/app/dns/dns.go +++ b/app/dns/dns.go @@ -12,12 +12,15 @@ import ( "sync" "time" + router "github.com/xtls/xray-core/app/router" "github.com/xtls/xray-core/common" "github.com/xtls/xray-core/common/errors" "github.com/xtls/xray-core/common/net" + "github.com/xtls/xray-core/common/platform/filesystem" "github.com/xtls/xray-core/common/session" "github.com/xtls/xray-core/common/strmatcher" "github.com/xtls/xray-core/features/dns" + "google.golang.org/protobuf/proto" ) // DNS is a DNS rely server. @@ -97,6 +100,25 @@ func New(ctx context.Context, config *Config) (*DNS, error) { } for _, ns := range config.NameServer { + if runtime.GOOS != "windows" && runtime.GOOS != "wasm" { + err := parseDomains(ns) + if err != nil { + return nil, errors.New("failed to parse dns domain rules: ").Base(err) + } + + expectedGeoip, err := router.GetGeoIPList(ns.ExpectedGeoip) + if err != nil { + return nil, errors.New("failed to parse dns expectIPs rules: ").Base(err) + } + ns.ExpectedGeoip = expectedGeoip + + unexpectedGeoip, err := router.GetGeoIPList(ns.UnexpectedGeoip) + if err != nil { + return nil, errors.New("failed to parse dns unexpectedGeoip rules: ").Base(err) + } + ns.UnexpectedGeoip = unexpectedGeoip + + } domainRuleCount += len(ns.PrioritizedDomain) } @@ -580,3 +602,76 @@ func detectGUIPlatform() bool { } return false } + +func parseDomains(ns *NameServer) error { + pureDomains := []*router.Domain{} + + // convert to pure domain + for _, pd := range ns.PrioritizedDomain { + pureDomains = append(pureDomains, &router.Domain{ + Type: router.Domain_Type(pd.Type), + Value: pd.Domain, + }) + } + + domainList := []*router.Domain{} + for _, domain := range pureDomains { + val := strings.Split(domain.Value, "_") + if len(val) >= 2 { + + fileName := val[0] + code := val[1] + + bs, err := filesystem.ReadAsset(fileName) + if err != nil { + return errors.New("failed to load file: ", fileName).Base(err) + } + bs = filesystem.Find(bs, []byte(code)) + var geosite router.GeoSite + + if err := proto.Unmarshal(bs, &geosite); err != nil { + return errors.New("failed Unmarshal :").Base(err) + } + + // parse attr + if len(val) == 3 { + siteWithAttr := strings.Split(val[2], ",") + attrs := router.ParseAttrs(siteWithAttr) + if !attrs.IsEmpty() { + filteredDomains := make([]*router.Domain, 0, len(pureDomains)) + for _, domain := range geosite.Domain { + if attrs.Match(domain) { + filteredDomains = append(filteredDomains, domain) + } + } + geosite.Domain = filteredDomains + } + + } + + domainList = append(domainList, geosite.Domain...) + + // update ns.OriginalRules Size + ruleTag := strings.Join(val, ":") + for i, oRule := range ns.OriginalRules { + if oRule.Rule == strings.ToLower(ruleTag) { + ns.OriginalRules[i].Size = uint32(len(geosite.Domain)) + } + } + + } else { + domainList = append(domainList, domain) + } + } + + // convert back to NameServer_PriorityDomain + ns.PrioritizedDomain = []*NameServer_PriorityDomain{} + for _, pd := range domainList { + ns.PrioritizedDomain = append(ns.PrioritizedDomain, &NameServer_PriorityDomain{ + Type: ToDomainMatchingType(pd.Type), + Domain: pd.Value, + }) + } + + return nil +} diff --git a/app/dns/dns_test.go b/app/dns/dns_test.go index cb70b0b35e9c..d103704c3f14 100644 --- a/app/dns/dns_test.go +++ b/app/dns/dns_test.go @@ -541,7 +541,7 @@ func TestIPMatch(t *testing.T) { }, ExpectedGeoip: []*router.GeoIP{ { - CountryCode: "local", + // local Cidr: []*router.CIDR{ { // inner ip, will not match @@ -565,7 +565,7 @@ func TestIPMatch(t *testing.T) { }, ExpectedGeoip: []*router.GeoIP{ { - CountryCode: "test", + // test Cidr: []*router.CIDR{ { Ip: []byte{8, 8, 8, 8}, @@ -574,7 +574,7 @@ func TestIPMatch(t *testing.T) { }, }, { - CountryCode: "test", + // test Cidr: []*router.CIDR{ { Ip: []byte{8, 8, 8, 4}, @@ -669,7 +669,7 @@ func TestLocalDomain(t *testing.T) { }, ExpectedGeoip: []*router.GeoIP{ { // Will match localhost, localhost-a and localhost-b, - CountryCode: "local", + // local Cidr: []*router.CIDR{ {Ip: []byte{127, 0, 0, 2}, Prefix: 32}, {Ip: []byte{127, 0, 0, 3}, Prefix: 32}, diff --git a/app/dns/nameserver.go b/app/dns/nameserver.go index bad9277c7602..e57b74bfaa98 100644 --- a/app/dns/nameserver.go +++ b/app/dns/nameserver.go @@ -297,3 +297,18 @@ func ResolveIpOptionOverride(queryStrategy QueryStrategy, ipOption dns.IPOption) return ipOption } } + +func ToDomainMatchingType(t router.Domain_Type) DomainMatchingType { + switch t { + case router.Domain_Domain: + return DomainMatchingType_Subdomain + case router.Domain_Full: + return DomainMatchingType_Full + case router.Domain_Plain: + return DomainMatchingType_Keyword + case router.Domain_Regex: + return DomainMatchingType_Regex + default: + panic("unknown domain type") + } +} diff --git a/app/router/config.go b/app/router/config.go index 26833f4691d2..0b2ebd9404c0 100644 --- a/app/router/config.go +++ b/app/router/config.go @@ -79,7 +79,7 @@ func (rr *RoutingRule) BuildCondition() (Condition, error) { geoip := rr.Geoip if runtime.GOOS != "windows" && runtime.GOOS != "wasm" { var err error - geoip, err = getGeoIPList(rr.Geoip) + geoip, err = GetGeoIPList(rr.Geoip) if err != nil { return nil, errors.New("failed to build geoip from mmap").Base(err) } @@ -188,7 +188,7 @@ func (br *BalancingRule) Build(ohm outbound.Manager, dispatcher routing.Dispatch } } -func getGeoIPList(ips []*GeoIP) ([]*GeoIP, error) { +func GetGeoIPList(ips []*GeoIP) ([]*GeoIP, error) { geoipList := []*GeoIP{} for _, ip := range ips { if ip.CountryCode != "" { diff --git a/infra/conf/dns.go b/infra/conf/dns.go index a65f0ee84a0f..6ec307c66d02 100644 --- a/infra/conf/dns.go +++ b/infra/conf/dns.go @@ -80,21 +80,6 @@ func (c *NameServerConfig) UnmarshalJSON(data []byte) error { return errors.New("failed to parse name server: ", string(data)) } -func toDomainMatchingType(t router.Domain_Type) dns.DomainMatchingType { - switch t { - case router.Domain_Domain: - return dns.DomainMatchingType_Subdomain - case router.Domain_Full: - return dns.DomainMatchingType_Full - case router.Domain_Plain: - return dns.DomainMatchingType_Keyword - case router.Domain_Regex: - return dns.DomainMatchingType_Regex - default: - panic("unknown domain type") - } -} - func (c *NameServerConfig) Build() (*dns.NameServer, error) { if c.Address == nil { return nil, errors.New("NameServer address is not specified.") @@ -111,7 +96,7 @@ func (c *NameServerConfig) Build() (*dns.NameServer, error) { for _, pd := range parsedDomain { domains = append(domains, &dns.NameServer_PriorityDomain{ - Type: toDomainMatchingType(pd.Type), + Type: dns.ToDomainMatchingType(pd.Type), Domain: pd.Value, }) }