1
- // Copyright 2023 Blink Labs Software
1
+ // Copyright 2024 Blink Labs Software
2
2
//
3
3
// Use of this source code is governed by an MIT-style
4
4
// license that can be found in the LICENSE file or at
7
7
package state
8
8
9
9
import (
10
+ "bytes"
11
+ "encoding/gob"
10
12
"errors"
11
13
"fmt"
14
+ "slices"
12
15
"strconv"
13
16
"strings"
14
17
"time"
@@ -28,6 +31,13 @@ type State struct {
28
31
gcTimer * time.Ticker
29
32
}
30
33
34
+ type DomainRecord struct {
35
+ Lhs string
36
+ Type string
37
+ Ttl int
38
+ Rhs string
39
+ }
40
+
31
41
var globalState = & State {}
32
42
33
43
func (s * State ) Load () error {
@@ -151,69 +161,91 @@ func (s *State) GetCursor() (uint64, string, error) {
151
161
152
162
func (s * State ) UpdateDomain (
153
163
domainName string ,
154
- nameServers map [ string ] string ,
164
+ records [] DomainRecord ,
155
165
) error {
156
166
logger := logging .GetLogger ()
157
167
err := s .db .Update (func (txn * badger.Txn ) error {
158
- // Delete old records for domain
159
- keyPrefix := []byte (fmt .Sprintf ("domain_%s_" , domainName ))
160
- it := txn .NewIterator (badger .DefaultIteratorOptions )
161
- defer it .Close ()
162
- for it .Seek (keyPrefix ); it .ValidForPrefix (keyPrefix ); it .Next () {
163
- item := it .Item ()
164
- k := item .Key ()
165
- if err := txn .Delete (k ); err != nil {
166
- return err
167
- }
168
- logger .Debug (
169
- fmt .Sprintf (
170
- "deleted record for domain %s with key: %s" ,
171
- domainName ,
172
- k ,
173
- ),
174
- )
175
- }
176
168
// Add new records
177
- for nameServer , ipAddress := range nameServers {
169
+ recordKeys := make ([]string , 0 )
170
+ for recordIdx , record := range records {
178
171
key := fmt .Sprintf (
179
- "domain_%s_nameserver_%s" ,
180
- domainName ,
181
- nameServer ,
172
+ "r_%s_%s_%d" ,
173
+ strings .ToUpper (record .Type ),
174
+ strings .Trim (record .Lhs , `.` ),
175
+ recordIdx ,
182
176
)
183
- if err := txn .Set ([]byte (key ), []byte (ipAddress )); err != nil {
177
+ recordKeys = append (recordKeys , key )
178
+ var gobBuf bytes.Buffer
179
+ gobEnc := gob .NewEncoder (& gobBuf )
180
+ if err := gobEnc .Encode (& record ); err != nil {
181
+ return err
182
+ }
183
+ recordVal := gobBuf .Bytes ()[:]
184
+ if err := txn .Set ([]byte (key ), recordVal ); err != nil {
184
185
return err
185
186
}
186
187
logger .Debug (
187
188
fmt .Sprintf (
188
- "added record for domain %s: %s: %s" ,
189
+ "added record for domain %s: %s: %s: %s " ,
189
190
domainName ,
190
- nameServer ,
191
- ipAddress ,
191
+ record .Type ,
192
+ record .Lhs ,
193
+ record .Rhs ,
192
194
),
193
195
)
194
196
}
197
+ // Delete old records in tracking key that are no longer present after this update
198
+ domainRecordsKey := []byte (fmt .Sprintf ("d_%s_records" , domainName ))
199
+ domainRecordsItem , err := txn .Get (domainRecordsKey )
200
+ if err != nil {
201
+ if ! errors .Is (err , badger .ErrKeyNotFound ) {
202
+ return err
203
+ }
204
+ } else {
205
+ domainRecordsVal , err := domainRecordsItem .ValueCopy (nil )
206
+ if err != nil {
207
+ return err
208
+ }
209
+ domainRecordsSplit := strings .Split (string (domainRecordsVal ), "," )
210
+ for _ , tmpRecordKey := range domainRecordsSplit {
211
+ if ! slices .Contains (recordKeys , tmpRecordKey ) {
212
+ if err := txn .Delete ([]byte (tmpRecordKey )); err != nil {
213
+ return err
214
+ }
215
+ }
216
+ }
217
+ }
218
+ // Update tracking key with new record keys
219
+ recordKeysJoin := strings .Join (recordKeys , "," )
220
+ if err := txn .Set (domainRecordsKey , []byte (recordKeysJoin )); err != nil {
221
+ return err
222
+ }
195
223
return nil
196
224
})
197
225
return err
198
226
}
199
227
200
- func (s * State ) LookupDomain ( domainName string ) (map [ string ] string , error ) {
201
- ret := map [ string ] string {}
202
- keyPrefix := [] byte ( fmt . Sprintf ( "domain_%s_nameserver_" , domainName ) )
228
+ func (s * State ) LookupRecords ( recordTypes [] string , recordName string ) ([] DomainRecord , error ) {
229
+ ret := [] DomainRecord {}
230
+ recordName = strings . Trim ( recordName , `.` )
203
231
err := s .db .View (func (txn * badger.Txn ) error {
204
- it := txn .NewIterator (badger .DefaultIteratorOptions )
205
- defer it .Close ()
206
- for it .Seek (keyPrefix ); it .ValidForPrefix (keyPrefix ); it .Next () {
207
- item := it .Item ()
208
- k := item .Key ()
209
- keyParts := strings .Split (string (k ), "_" )
210
- nameServer := keyParts [len (keyParts )- 1 ]
211
- err := item .Value (func (v []byte ) error {
212
- ret [nameServer ] = string (v )
213
- return nil
214
- })
215
- if err != nil {
216
- return err
232
+ for _ , recordType := range recordTypes {
233
+ keyPrefix := []byte (fmt .Sprintf ("r_%s_%s_" , strings .ToUpper (recordType ), recordName ))
234
+ it := txn .NewIterator (badger .DefaultIteratorOptions )
235
+ defer it .Close ()
236
+ for it .Seek (keyPrefix ); it .ValidForPrefix (keyPrefix ); it .Next () {
237
+ item := it .Item ()
238
+ val , err := item .ValueCopy (nil )
239
+ if err != nil {
240
+ return err
241
+ }
242
+ gobBuf := bytes .NewReader (val )
243
+ gobDec := gob .NewDecoder (gobBuf )
244
+ var tmpRecord DomainRecord
245
+ if err := gobDec .Decode (& tmpRecord ); err != nil {
246
+ return err
247
+ }
248
+ ret = append (ret , tmpRecord )
217
249
}
218
250
}
219
251
return nil
0 commit comments