Skip to content

Commit 3f2b480

Browse files
committed
set match domain on update
1 parent 4e8ab0e commit 3f2b480

File tree

4 files changed

+82
-2
lines changed

4 files changed

+82
-2
lines changed

controllers/dns.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,11 @@ func createNs(w http.ResponseWriter, r *http.Request) {
9090
if gNs, ok := logic.GlobalNsList[req.Name]; ok {
9191
req.Servers = gNs.IPs
9292
}
93+
if !servercfg.IsPro {
94+
req.Tags = datatypes.JSONMap{
95+
"*": struct{}{},
96+
}
97+
}
9398
ns := schema.Nameserver{
9499
ID: uuid.New().String(),
95100
Name: req.Name,
@@ -222,6 +227,7 @@ func updateNs(w http.ResponseWriter, r *http.Request) {
222227
}
223228
ns.Servers = updateNs.Servers
224229
ns.Tags = updateNs.Tags
230+
ns.MatchDomain = updateNs.MatchDomain
225231
ns.Description = updateNs.Description
226232
ns.Name = updateNs.Name
227233
ns.Status = updateNs.Status

logic/dns.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,44 @@ func ValidateUpdateNameserverReq(updateNs schema.Nameserver) error {
394394
return nil
395395
}
396396

397+
func GetNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
398+
ns := &schema.Nameserver{
399+
NetworkID: node.Network,
400+
}
401+
nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
402+
for _, nsI := range nsLi {
403+
if !nsI.Status {
404+
continue
405+
}
406+
_, all := nsI.Tags["*"]
407+
if all {
408+
returnNsLi = append(returnNsLi, models.Nameserver{
409+
IPs: nsI.Servers,
410+
MatchDomain: nsI.MatchDomain,
411+
})
412+
continue
413+
}
414+
for tagI := range node.Tags {
415+
if _, ok := nsI.Tags[tagI.String()]; ok {
416+
returnNsLi = append(returnNsLi, models.Nameserver{
417+
IPs: nsI.Servers,
418+
MatchDomain: nsI.MatchDomain,
419+
})
420+
}
421+
}
422+
}
423+
if node.IsInternetGateway {
424+
globalNs := models.Nameserver{
425+
MatchDomain: ".",
426+
}
427+
for _, nsI := range GlobalNsList {
428+
globalNs.IPs = append(globalNs.IPs, nsI.IPs...)
429+
}
430+
returnNsLi = append(returnNsLi, globalNs)
431+
}
432+
return
433+
}
434+
397435
func GetNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
398436
if h.DNS != "yes" {
399437
return

pro/controllers/users.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1328,7 +1328,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
13281328
Addresses: utils.NoEmptyStringToCsv(node.Address.String(), node.Address6.String()),
13291329
}
13301330
if !node.IsInternetGateway {
1331-
hNs := logic.GetNameserversForHost(host)
1331+
hNs := logic.GetNameserversForNode(&node)
13321332
for _, nsI := range hNs {
13331333
gw.MatchDomains = append(gw.MatchDomains, nsI.MatchDomain)
13341334
}
@@ -1379,7 +1379,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
13791379
Addresses: utils.NoEmptyStringToCsv(node.Address.String(), node.Address6.String()),
13801380
}
13811381
if !node.IsInternetGateway {
1382-
hNs := logic.GetNameserversForHost(host)
1382+
hNs := logic.GetNameserversForNode(&node)
13831383
for _, nsI := range hNs {
13841384
gw.MatchDomains = append(gw.MatchDomains, nsI.MatchDomain)
13851385
}

pro/logic/dns.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package logic
2+
3+
import (
4+
"errors"
5+
6+
"github.com/gravitl/netmaker/logic"
7+
"github.com/gravitl/netmaker/models"
8+
"github.com/gravitl/netmaker/schema"
9+
)
10+
11+
func ValidateNameserverReq(ns schema.Nameserver) error {
12+
if ns.Name == "" {
13+
return errors.New("name is required")
14+
}
15+
if ns.NetworkID == "" {
16+
return errors.New("network is required")
17+
}
18+
if len(ns.Servers) == 0 {
19+
return errors.New("atleast one nameserver should be specified")
20+
}
21+
if !logic.IsValidMatchDomain(ns.MatchDomain) {
22+
return errors.New("invalid match domain")
23+
}
24+
if len(ns.Tags) > 0 {
25+
for tagI := range ns.Tags {
26+
if tagI == "*" {
27+
continue
28+
}
29+
_, err := GetTag(models.TagID(tagI))
30+
if err != nil {
31+
return errors.New("invalid tag")
32+
}
33+
}
34+
}
35+
return nil
36+
}

0 commit comments

Comments
 (0)