Skip to content

Commit 0b008d8

Browse files
authored
feat: refactor storage (#158)
Fixes #76
1 parent 02a282d commit 0b008d8

File tree

3 files changed

+100
-65
lines changed

3 files changed

+100
-65
lines changed

internal/dns/dns.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023 Blink Labs Software
1+
// Copyright 2024 Blink Labs Software
22
//
33
// Use of this source code is governed by an MIT-style
44
// license that can be found in the LICENSE file or at
@@ -294,15 +294,21 @@ func findNameserversForDomain(
294294
lookupDomainName := strings.Join(queryLabels[startLabelIdx:], ".")
295295
// Convert to canonical form for consistency
296296
lookupDomainName = dns.CanonicalName(lookupDomainName)
297-
nameservers, err := state.GetState().LookupDomain(lookupDomainName)
297+
nsRecords, err := state.GetState().LookupRecords([]string{"NS"}, lookupDomainName)
298298
if err != nil {
299299
return "", nil, err
300300
}
301-
if nameservers != nil {
301+
if len(nsRecords) > 0 {
302302
ret := map[string][]net.IP{}
303-
for k, v := range nameservers {
304-
k = dns.Fqdn(k)
305-
ret[k] = append(ret[k], net.ParseIP(v))
303+
for _, nsRecord := range nsRecords {
304+
// Get matching A/AAAA records for NS entry
305+
aRecords, err := state.GetState().LookupRecords([]string{"A", "AAAA"}, nsRecord.Rhs)
306+
if err != nil {
307+
return "", nil, err
308+
}
309+
for _, aRecord := range aRecords {
310+
ret[nsRecord.Rhs] = append(ret[nsRecord.Rhs], net.ParseIP(aRecord.Rhs))
311+
}
306312
}
307313
return dns.Fqdn(lookupDomainName), ret, nil
308314
}

internal/indexer/indexer.go

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023 Blink Labs Software
1+
// Copyright 2024 Blink Labs Software
22
//
33
// Use of this source code is governed by an MIT-style
44
// license that can be found in the LICENSE file or at
@@ -234,23 +234,20 @@ func (i *Indexer) handleEvent(evt event.Event) error {
234234
continue
235235
}
236236
}
237-
nameServers := map[string]string{}
237+
// Convert domain records into our storage format
238+
tmpRecords := []state.DomainRecord{}
238239
for _, record := range dnsDomain.Records {
239-
recordName := strings.Trim(
240-
string(record.Lhs),
241-
`.`,
242-
)
243-
// NOTE: we're losing information here, but we need to revamp the storage
244-
// format before we can use it. We're also making the assumption that all
245-
// records are for nameservers
246-
switch strings.ToUpper(string(record.Type)) {
247-
case "A", "AAAA":
248-
nameServers[recordName] = string(record.Rhs)
249-
default:
250-
continue
240+
tmpRecord := state.DomainRecord{
241+
Lhs: string(record.Lhs),
242+
Type: string(record.Type),
243+
Rhs: string(record.Rhs),
244+
}
245+
if record.Ttl.HasValue() {
246+
tmpRecord.Ttl = int(record.Ttl.Value)
251247
}
248+
tmpRecords = append(tmpRecords, tmpRecord)
252249
}
253-
if err := state.GetState().UpdateDomain(domainName, nameServers); err != nil {
250+
if err := state.GetState().UpdateDomain(domainName, tmpRecords); err != nil {
254251
return err
255252
}
256253
logger.Infof(

internal/state/state.go

Lines changed: 76 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright 2023 Blink Labs Software
1+
// Copyright 2024 Blink Labs Software
22
//
33
// Use of this source code is governed by an MIT-style
44
// license that can be found in the LICENSE file or at
@@ -7,8 +7,11 @@
77
package state
88

99
import (
10+
"bytes"
11+
"encoding/gob"
1012
"errors"
1113
"fmt"
14+
"slices"
1215
"strconv"
1316
"strings"
1417
"time"
@@ -28,6 +31,13 @@ type State struct {
2831
gcTimer *time.Ticker
2932
}
3033

34+
type DomainRecord struct {
35+
Lhs string
36+
Type string
37+
Ttl int
38+
Rhs string
39+
}
40+
3141
var globalState = &State{}
3242

3343
func (s *State) Load() error {
@@ -151,69 +161,91 @@ func (s *State) GetCursor() (uint64, string, error) {
151161

152162
func (s *State) UpdateDomain(
153163
domainName string,
154-
nameServers map[string]string,
164+
records []DomainRecord,
155165
) error {
156166
logger := logging.GetLogger()
157167
err := s.db.Update(func(txn *badger.Txn) error {
158-
// Delete old records for domain
159-
keyPrefix := []byte(fmt.Sprintf("domain_%s_", domainName))
160-
it := txn.NewIterator(badger.DefaultIteratorOptions)
161-
defer it.Close()
162-
for it.Seek(keyPrefix); it.ValidForPrefix(keyPrefix); it.Next() {
163-
item := it.Item()
164-
k := item.Key()
165-
if err := txn.Delete(k); err != nil {
166-
return err
167-
}
168-
logger.Debug(
169-
fmt.Sprintf(
170-
"deleted record for domain %s with key: %s",
171-
domainName,
172-
k,
173-
),
174-
)
175-
}
176168
// Add new records
177-
for nameServer, ipAddress := range nameServers {
169+
recordKeys := make([]string, 0)
170+
for recordIdx, record := range records {
178171
key := fmt.Sprintf(
179-
"domain_%s_nameserver_%s",
180-
domainName,
181-
nameServer,
172+
"r_%s_%s_%d",
173+
strings.ToUpper(record.Type),
174+
strings.Trim(record.Lhs, `.`),
175+
recordIdx,
182176
)
183-
if err := txn.Set([]byte(key), []byte(ipAddress)); err != nil {
177+
recordKeys = append(recordKeys, key)
178+
var gobBuf bytes.Buffer
179+
gobEnc := gob.NewEncoder(&gobBuf)
180+
if err := gobEnc.Encode(&record); err != nil {
181+
return err
182+
}
183+
recordVal := gobBuf.Bytes()[:]
184+
if err := txn.Set([]byte(key), recordVal); err != nil {
184185
return err
185186
}
186187
logger.Debug(
187188
fmt.Sprintf(
188-
"added record for domain %s: %s: %s",
189+
"added record for domain %s: %s: %s: %s",
189190
domainName,
190-
nameServer,
191-
ipAddress,
191+
record.Type,
192+
record.Lhs,
193+
record.Rhs,
192194
),
193195
)
194196
}
197+
// Delete old records in tracking key that are no longer present after this update
198+
domainRecordsKey := []byte(fmt.Sprintf("d_%s_records", domainName))
199+
domainRecordsItem, err := txn.Get(domainRecordsKey)
200+
if err != nil {
201+
if !errors.Is(err, badger.ErrKeyNotFound) {
202+
return err
203+
}
204+
} else {
205+
domainRecordsVal, err := domainRecordsItem.ValueCopy(nil)
206+
if err != nil {
207+
return err
208+
}
209+
domainRecordsSplit := strings.Split(string(domainRecordsVal), ",")
210+
for _, tmpRecordKey := range domainRecordsSplit {
211+
if !slices.Contains(recordKeys, tmpRecordKey) {
212+
if err := txn.Delete([]byte(tmpRecordKey)); err != nil {
213+
return err
214+
}
215+
}
216+
}
217+
}
218+
// Update tracking key with new record keys
219+
recordKeysJoin := strings.Join(recordKeys, ",")
220+
if err := txn.Set(domainRecordsKey, []byte(recordKeysJoin)); err != nil {
221+
return err
222+
}
195223
return nil
196224
})
197225
return err
198226
}
199227

200-
func (s *State) LookupDomain(domainName string) (map[string]string, error) {
201-
ret := map[string]string{}
202-
keyPrefix := []byte(fmt.Sprintf("domain_%s_nameserver_", domainName))
228+
func (s *State) LookupRecords(recordTypes []string, recordName string) ([]DomainRecord, error) {
229+
ret := []DomainRecord{}
230+
recordName = strings.Trim(recordName, `.`)
203231
err := s.db.View(func(txn *badger.Txn) error {
204-
it := txn.NewIterator(badger.DefaultIteratorOptions)
205-
defer it.Close()
206-
for it.Seek(keyPrefix); it.ValidForPrefix(keyPrefix); it.Next() {
207-
item := it.Item()
208-
k := item.Key()
209-
keyParts := strings.Split(string(k), "_")
210-
nameServer := keyParts[len(keyParts)-1]
211-
err := item.Value(func(v []byte) error {
212-
ret[nameServer] = string(v)
213-
return nil
214-
})
215-
if err != nil {
216-
return err
232+
for _, recordType := range recordTypes {
233+
keyPrefix := []byte(fmt.Sprintf("r_%s_%s_", strings.ToUpper(recordType), recordName))
234+
it := txn.NewIterator(badger.DefaultIteratorOptions)
235+
defer it.Close()
236+
for it.Seek(keyPrefix); it.ValidForPrefix(keyPrefix); it.Next() {
237+
item := it.Item()
238+
val, err := item.ValueCopy(nil)
239+
if err != nil {
240+
return err
241+
}
242+
gobBuf := bytes.NewReader(val)
243+
gobDec := gob.NewDecoder(gobBuf)
244+
var tmpRecord DomainRecord
245+
if err := gobDec.Decode(&tmpRecord); err != nil {
246+
return err
247+
}
248+
ret = append(ret, tmpRecord)
217249
}
218250
}
219251
return nil

0 commit comments

Comments
 (0)