diff --git a/addressfilter/address_checker.go b/addressfilter/address_checker.go new file mode 100644 index 0000000000..f5e6af68e1 --- /dev/null +++ b/addressfilter/address_checker.go @@ -0,0 +1,136 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package addressfilter + +import ( + "context" + "sync" + "sync/atomic" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/state" + + "github.com/offchainlabs/nitro/util/stopwaiter" +) + +// Default parameters for HashedAddressChecker, used in NewDefaultHashedAddressChecker +const ( + defaultRestrictedAddrWorkerCount = 4 + defaultRestrictedAddrQueueSize = 8192 +) + +// HashedAddressChecker is a global, shared address checker that filters +// transactions using a HashStore. Hashing and caching are delegated to +// the HashStore; this checker only manages async execution and per-tx +// aggregation. +type HashedAddressChecker struct { + stopwaiter.StopWaiter + store *HashStore + workChan chan workItem + workerCount int +} + +// HashedAddressCheckerState tracks address filtering for a single transaction. +// It aggregates asynchronous checks initiated by TouchAddress and blocks +// in IsFiltered until all submitted checks complete. +type HashedAddressCheckerState struct { + checker *HashedAddressChecker + filtered atomic.Bool + pending sync.WaitGroup +} + +type workItem struct { + addr common.Address + state *HashedAddressCheckerState +} + +// NewHashedAddressChecker constructs a new checker backed by a HashStore. +func NewHashedAddressChecker( + store *HashStore, + workerCount int, + queueSize int, +) *HashedAddressChecker { + if store == nil { + panic("HashStore cannot be nil") + } + + c := &HashedAddressChecker{ + store: store, + workChan: make(chan workItem, queueSize), + workerCount: workerCount, + } + + return c +} + +func (c *HashedAddressChecker) Start(ctx context.Context) { + c.StopWaiter.Start(ctx, c) + + for i := 0; i < c.workerCount; i++ { + c.LaunchThread(func(ctx context.Context) { + c.worker(ctx) + }) + } +} + +func NewDefaultHashedAddressChecker(store *HashStore) *HashedAddressChecker { + return NewHashedAddressChecker( + store, + defaultRestrictedAddrWorkerCount, + defaultRestrictedAddrQueueSize, + ) +} + +func (c *HashedAddressChecker) NewTxState() state.AddressCheckerState { + return &HashedAddressCheckerState{ + checker: c, + } +} + +func (c *HashedAddressChecker) processAddress(addr common.Address, state *HashedAddressCheckerState) { + restricted := c.store.IsRestricted(addr) + state.report(restricted) +} + +// worker runs for the lifetime of the checker; workChan is never closed. +func (c *HashedAddressChecker) worker(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case item := <-c.workChan: + c.processAddress(item.addr, item.state) + } + } +} + +func (s *HashedAddressCheckerState) TouchAddress(addr common.Address) { + s.pending.Add(1) + + // If the checker is stopped, process synchronously + if s.checker.Stopped() { + s.checker.processAddress(addr, s) + return + } + + select { + case s.checker.workChan <- workItem{addr: addr, state: s}: + // ok + default: + // queue full: process synchronously to avoid dropping + s.checker.processAddress(addr, s) + } +} + +func (s *HashedAddressCheckerState) report(filtered bool) { + if filtered { + s.filtered.Store(true) + } + s.pending.Done() +} + +func (s *HashedAddressCheckerState) IsFiltered() bool { + s.pending.Wait() + return s.filtered.Load() +} diff --git a/addressfilter/address_checker_test.go b/addressfilter/address_checker_test.go new file mode 100644 index 0000000000..032e5fa7aa --- /dev/null +++ b/addressfilter/address_checker_test.go @@ -0,0 +1,140 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package addressfilter + +import ( + "context" + "crypto/sha256" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/ethereum/go-ethereum/common" +) + +func mustState(t *testing.T, s any) *HashedAddressCheckerState { + t.Helper() + state, ok := s.(*HashedAddressCheckerState) + require.Truef(t, ok, "unexpected AddressCheckerState type %T", s) + return state +} + +func TestHashedAddressCheckerSimple(t *testing.T) { + salt := []byte("test-salt") + + addrFiltered := common.HexToAddress("0x000000000000000000000000000000000000dead") + addrAllowed := common.HexToAddress("0x000000000000000000000000000000000000beef") + + store := NewHashStore() + + hash := sha256.Sum256(append(salt, addrFiltered.Bytes()...)) + store.Load(salt, []common.Hash{hash}, "test") + + checker := NewDefaultHashedAddressChecker(store) + checker.Start(context.Background()) + + // Tx 1: filtered address + state1 := mustState(t, checker.NewTxState()) + state1.TouchAddress(addrFiltered) + assert.True(t, state1.IsFiltered(), "expected transaction to be filtered") + + // Tx 2: allowed address + state2 := mustState(t, checker.NewTxState()) + state2.TouchAddress(addrAllowed) + assert.False(t, state2.IsFiltered(), "expected transaction NOT to be filtered") + + // Tx 3: mixed addresses + state3 := mustState(t, checker.NewTxState()) + state3.TouchAddress(addrAllowed) + state3.TouchAddress(addrFiltered) + assert.True(t, state3.IsFiltered(), "expected transaction with mixed addresses to be filtered") + + // Tx 4: reuse HashStore cache across txs + state4 := mustState(t, checker.NewTxState()) + state4.TouchAddress(addrFiltered) + assert.True(t, state4.IsFiltered(), "expected cached filtered address to still be filtered") + + // Tx 5: queue overflow should not panic and must be conservative + overflowChecker := NewHashedAddressChecker( + store, + /* workerCount */ 1, + /* queueSize */ 0, + ) + + // Tx 5: synchronous call + overflowState := mustState(t, overflowChecker.NewTxState()) + overflowState.TouchAddress(addrFiltered) + + assert.True( + t, + overflowState.IsFiltered(), + "expected cached filtered address to still be filtered", + ) +} + +func TestHashedAddressCheckerHeavy(t *testing.T) { + salt := []byte("heavy-salt") + + const filteredCount = 500 + filteredAddrs := make([]common.Address, filteredCount) + filteredHashes := make([]common.Hash, filteredCount) + + for i := range filteredAddrs { + addr := common.BytesToAddress([]byte{byte(i + 1)}) + filteredAddrs[i] = addr + filteredHashes[i] = sha256.Sum256(append(salt, addr.Bytes()...)) + } + + store := NewHashStore() + store.Load(salt, filteredHashes, "heavy") + + checker := NewDefaultHashedAddressChecker(store) + checker.Start(context.Background()) + + const txCount = 100 + const touchesPerTx = 100 + + results := make(chan bool, txCount) + + var wg sync.WaitGroup + wg.Add(txCount) + + for tx := range txCount { + go func(tx int) { + defer wg.Done() + + state := mustState(t, checker.NewTxState()) + + for i := range touchesPerTx { + if i%10 == 0 { + state.TouchAddress(filteredAddrs[i%filteredCount]) + } else { + addr := common.BytesToAddress([]byte{byte(200 + i*tx)}) + state.TouchAddress(addr) + } + } + + results <- state.IsFiltered() + }(tx) + } + + wg.Wait() + close(results) + + filteredTxs := 0 + for r := range results { + if r { + filteredTxs++ + } + } + + assert.Greater( + t, + filteredTxs, + 0, + "expected at least some transactions to be filtered under load", + ) +} diff --git a/addressfilter/config.go b/addressfilter/config.go new file mode 100644 index 0000000000..2b89366635 --- /dev/null +++ b/addressfilter/config.go @@ -0,0 +1,46 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package addressfilter + +import ( + "errors" + "time" + + "github.com/spf13/pflag" + + "github.com/offchainlabs/nitro/util/s3syncer" +) + +type Config struct { + Enable bool `koanf:"enable"` + S3 s3syncer.Config `koanf:"s3"` + PollInterval time.Duration `koanf:"poll-interval"` +} + +var DefaultConfig = Config{ + Enable: false, + PollInterval: 5 * time.Minute, +} + +func ConfigAddOptions(prefix string, f *pflag.FlagSet) { + f.Bool(prefix+".enable", DefaultConfig.Enable, "enable restricted address synchronization service") + s3syncer.ConfigAddOptions(prefix+".s3", f) + f.Duration(prefix+".poll-interval", DefaultConfig.PollInterval, "interval between polling S3 for hash list updates") +} + +func (c *Config) Validate() error { + if !c.Enable { + return nil + } + + if err := c.S3.Validate(); err != nil { + return err + } + + if c.PollInterval <= 0 { + return errors.New("restricted-addr.poll-interval must be positive") + } + + return nil +} diff --git a/addressfilter/hash_store.go b/addressfilter/hash_store.go new file mode 100644 index 0000000000..9e5f1ecc80 --- /dev/null +++ b/addressfilter/hash_store.go @@ -0,0 +1,122 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package addressfilter + +import ( + "crypto/sha256" + "sync/atomic" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/lru" +) + +// hashData holds the immutable hash list data. +// Once created, this struct is never modified, making it safe for concurrent reads. +// The cache is included here so it gets swapped atomically with the hash data. +type hashData struct { + salt []byte + hashes map[common.Hash]struct{} + digest string + loadedAt time.Time + cache *lru.Cache[common.Address, bool] // LRU cache for address lookup results +} + +// HashStore provides thread-safe access to restricted address hashes. +// It uses atomic.Pointer for lock-free reads during updates, implementing +// a double-buffering strategy where new data is prepared in the background +// and then atomically swapped in. +type HashStore struct { + data atomic.Pointer[hashData] + cacheSize int +} + +const defaultCacheSize = 10000 + +func NewHashStore() *HashStore { + return NewHashStoreWithCacheSize(defaultCacheSize) +} + +func NewHashStoreWithCacheSize(cacheSize int) *HashStore { + h := &HashStore{ + cacheSize: cacheSize, + } + h.data.Store(&hashData{ + hashes: make(map[common.Hash]struct{}), + cache: lru.NewCache[common.Address, bool](cacheSize), + }) + return h +} + +// Load atomically swaps in a new hash list. +// This is called after a new hash list has been downloaded and parsed. +// A new LRU cache is created for the new data, ensuring atomic consistency. +func (h *HashStore) Load(salt []byte, hashes []common.Hash, digest string) { + newData := &hashData{ + salt: salt, + hashes: make(map[common.Hash]struct{}, len(hashes)), + digest: digest, + loadedAt: time.Now(), + cache: lru.NewCache[common.Address, bool](h.cacheSize), + } + for _, hash := range hashes { + newData.hashes[hash] = struct{}{} + } + h.data.Store(newData) // Atomic pointer swap +} + +// IsRestricted checks if an address is in the restricted list. +// Results are cached in the LRU cache for faster subsequent lookups. +// This method is safe to call concurrently. +func (h *HashStore) IsRestricted(addr common.Address) bool { + data := h.data.Load() // Atomic load - no lock needed + if len(data.salt) == 0 { + return false // Not initialized + } + + // Check cache first (cache is per-data snapshot) + if restricted, ok := data.cache.Get(addr); ok { + return restricted + } + + saltedAddr := make([]byte, len(data.salt)+common.AddressLength) + copy(saltedAddr, data.salt) + copy(saltedAddr[len(data.salt):], addr.Bytes()) + saltedHash := sha256.Sum256(saltedAddr) + + _, restricted := data.hashes[saltedHash] + + // Cache the result + data.cache.Add(addr, restricted) + return restricted +} + +// Digest Return the digest of the current loaded hashstore. +func (h *HashStore) Digest() string { + return h.data.Load().digest +} + +func (h *HashStore) Size() int { + return len(h.data.Load().hashes) +} + +func (h *HashStore) LoadedAt() time.Time { + return h.data.Load().loadedAt +} + +// Salt returns a copy of the current salt. +func (h *HashStore) Salt() []byte { + data := h.data.Load() + if len(data.salt) == 0 { + return nil + } + salt := make([]byte, len(data.salt)) + copy(salt, data.salt) + return salt +} + +// CacheLen returns the current number of entries in the LRU cache. +func (h *HashStore) CacheLen() int { + return h.data.Load().cache.Len() +} diff --git a/addressfilter/s3_sync.go b/addressfilter/s3_sync.go new file mode 100644 index 0000000000..b7fa8a927f --- /dev/null +++ b/addressfilter/s3_sync.go @@ -0,0 +1,95 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package addressfilter + +import ( + "context" + "encoding/hex" + "encoding/json" + "fmt" + + "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/log" + + "github.com/offchainlabs/nitro/util/s3syncer" +) + +// hashListPayload represents the JSON structure of the hash list file used for unmarshalling. +type hashListPayload struct { + Salt string `json:"salt"` + AddressHashes []struct { + Hash string `json:"hash"` + } `json:"address_hashes"` +} + +type S3SyncManager struct { + Syncer *s3syncer.Syncer + store *HashStore +} + +func NewS3SyncManager(ctx context.Context, config *Config, store *HashStore) (*S3SyncManager, error) { + s := &S3SyncManager{ + store: store, + } + syncer, err := s3syncer.NewSyncer( + ctx, + &config.S3, + s.handleHashListData, + // These are initial settings that can be tuned as needed. + s3syncer.WithDownloadConfig(s3syncer.DownloadConfig{ + PartSizeMB: 100, + Concurrency: 10, + PartBodyMaxRetries: 5, + })) + + if err != nil { + return nil, err + } + + s.Syncer = syncer + return s, nil +} + +// handleHashListData parses the downloaded JSON data and loads it into the store. +func (s *S3SyncManager) handleHashListData(data []byte, digest string) error { + salt, hashes, err := parseHashListJSON(data) + if err != nil { + return fmt.Errorf("failed to parse hash list: %w", err) + } + + s.store.Load(salt, hashes, digest) + log.Info("loaded restricted addr list", "hash_count", len(hashes), "etag", digest, "size_bytes", len(data)) + return nil +} + +// parseHashListJSON parses the JSON hash list file. +// Expected format: {"salt": "hex...", "address_hashes": [{"hash": "hex1"}, {"hash": "hex2"}, ...]} +func parseHashListJSON(data []byte) ([]byte, []common.Hash, error) { + var payload hashListPayload + if err := json.Unmarshal(data, &payload); err != nil { + return nil, nil, fmt.Errorf("JSON unmarshal failed: %w", err) + } + + salt, err := hex.DecodeString(payload.Salt) + if err != nil { + return nil, nil, fmt.Errorf("invalid salt hex: %w", err) + } + if len(salt) == 0 { + return nil, nil, fmt.Errorf("salt cannot be empty") + } + + hashes := make([]common.Hash, len(payload.AddressHashes)) + for i, h := range payload.AddressHashes { + hashBytes, err := hex.DecodeString(h.Hash) + if err != nil { + return nil, nil, fmt.Errorf("invalid hash hex at index %d: %w", i, err) + } + if len(hashBytes) != 32 { + return nil, nil, fmt.Errorf("invalid hash length at index %d: got %d, want 32", i, len(hashBytes)) + } + copy(hashes[i][:], hashBytes) + } + + return salt, hashes, nil +} diff --git a/addressfilter/service.go b/addressfilter/service.go new file mode 100644 index 0000000000..9c4f55ebbd --- /dev/null +++ b/addressfilter/service.go @@ -0,0 +1,116 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package addressfilter + +import ( + "context" + "fmt" + "time" + + "github.com/ethereum/go-ethereum/log" + + "github.com/offchainlabs/nitro/util/stopwaiter" +) + +// Service manages the address-filteress synchronization pipeline. +// It periodically polls S3 for hash list updates and maintains an in-memory +// copy for efficient address filtering. +type Service struct { + stopwaiter.StopWaiter + config *Config + store *HashStore + syncMgr *S3SyncManager +} + +// NewService creates a new address-filteress service. +// Returns nil if the service is not enabled in the configuration. +func NewService(ctx context.Context, config *Config) (*Service, error) { + if !config.Enable { + return nil, nil + } + + if err := config.Validate(); err != nil { + return nil, fmt.Errorf("invalid config: %w", err) + } + + store := NewHashStore() + syncMgr, err := NewS3SyncManager(ctx, config, store) + if err != nil { + return nil, fmt.Errorf("failed to create S3 syncer: %w", err) + } + + return &Service{ + config: config, + store: store, + syncMgr: syncMgr, + }, nil +} + +// Initialize downloads the initial hash list from S3. +// This method blocks until the hash list is successfully loaded. +// If this fails, the node should not start. +func (s *Service) Initialize(ctx context.Context) error { + log.Info("initializing address-filter service, downloading initial hash list", + "bucket", s.config.S3.Bucket, + "key", s.config.S3.ObjectKey, + ) + + // Force download (ignore ETag check for initial load) + if err := s.syncMgr.Syncer.DownloadAndLoad(ctx); err != nil { + return fmt.Errorf("failed to load initial hash list: %w", err) + } + + log.Info("address-filter service initialized", + "hash_count", s.store.Size(), + "etag-digest", s.store.Digest(), + ) + return nil +} + +// Start begins the background polling goroutine. +// This should be called after Initialize() succeeds. +func (s *Service) Start(ctx context.Context) { + s.StopWaiter.Start(ctx, s) + + // Start periodic polling goroutine + s.CallIteratively(func(ctx context.Context) time.Duration { + if err := s.syncMgr.Syncer.CheckAndSync(ctx); err != nil { + log.Error("failed to sync address-filter list", "err", err) + } + return s.config.PollInterval + }) + + log.Info("address-filter service started", + "poll_interval", s.config.PollInterval, + ) +} + +func (s *Service) GetHashCount() int { + if !s.config.Enable { + return 0 + } + return s.store.Size() +} + +// GetHashStoreDigest GetETag returns the S3 ETag Digest of the currently loaded hash list. +func (s *Service) GetHashStoreDigest() string { + if !s.config.Enable { + return "" + } + return s.store.Digest() +} + +func (s *Service) GetLoadedAt() time.Time { + if !s.config.Enable { + return time.Time{} + } + return s.store.LoadedAt() +} + +func (s *Service) GetHashStore() *HashStore { + if !s.config.Enable { + return nil + } + return s.store +} diff --git a/addressfilter/service_test.go b/addressfilter/service_test.go new file mode 100644 index 0000000000..fa9d6160ad --- /dev/null +++ b/addressfilter/service_test.go @@ -0,0 +1,343 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package addressfilter + +import ( + "crypto/sha256" + "encoding/hex" + "encoding/json" + "sync" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common" + + "github.com/offchainlabs/nitro/util/s3syncer" +) + +func TestHashStore_IsRestricted(t *testing.T) { + store := NewHashStore() + + // Test empty store + addr := common.HexToAddress("0x1234567890123456789012345678901234567890") + if store.IsRestricted(addr) { + t.Error("empty store should not restrict any address") + } + + // Create test data + salt := []byte("test-salt") + addresses := []common.Address{ + common.HexToAddress("0x1111111111111111111111111111111111111111"), + common.HexToAddress("0x2222222222222222222222222222222222222222"), + common.HexToAddress("0x3333333333333333333333333333333333333333"), + } + + // Pre-compute hashes + hashes := make([]common.Hash, 0, len(addresses)) + for _, addr := range addresses { + hash := sha256.Sum256(append(salt, addr.Bytes()...)) + hashes = append(hashes, hash) + } + + // Load the hashes + store.Load(salt, hashes, "test-etag") + + // Test restricted addresses + for _, addr := range addresses { + if !store.IsRestricted(addr) { + t.Errorf("address %s should be restricted", addr.Hex()) + } + } + + // Test non-restricted address + nonRestrictedAddr := common.HexToAddress("0x4444444444444444444444444444444444444444") + if store.IsRestricted(nonRestrictedAddr) { + t.Errorf("address %s should not be restricted", nonRestrictedAddr.Hex()) + } + + // Test metadata + if store.Digest() != "test-etag" { + t.Errorf("expected etag 'test-etag', got '%s'", store.Digest()) + } + if store.Size() != 3 { + t.Errorf("expected size 3, got %d", store.Size()) + } +} + +func TestHashStore_AtomicSwap(t *testing.T) { + store := NewHashStore() + + salt1 := []byte("salt1") + addr1 := common.HexToAddress("0x1111111111111111111111111111111111111111") + hash1 := sha256.Sum256(append(salt1, addr1.Bytes()...)) + + // Load first set + store.Load(salt1, []common.Hash{hash1}, "etag1") + if !store.IsRestricted(addr1) { + t.Error("addr1 should be restricted after first load") + } + + // Load second set with different salt (simulating hourly rotation) + salt2 := []byte("salt2") + addr2 := common.HexToAddress("0x2222222222222222222222222222222222222222") + hash2 := sha256.Sum256(append(salt2, addr2.Bytes()...)) + + store.Load(salt2, []common.Hash{hash2}, "etag2") + + // addr1 should no longer be restricted (different salt) + if store.IsRestricted(addr1) { + t.Error("addr1 should not be restricted after swap (salt changed)") + } + // addr2 should now be restricted + if !store.IsRestricted(addr2) { + t.Error("addr2 should be restricted after swap") + } + if store.Digest() != "etag2" { + t.Errorf("expected etag 'etag2', got '%s'", store.Digest()) + } +} + +func TestHashStore_ConcurrentAccess(t *testing.T) { + store := NewHashStore() + + salt1 := []byte("test-salt") + var addresses []common.Address + var hashes1 []common.Hash + for i := 0; i < 100; i++ { + addr := common.BigToAddress(common.Big1) + addr[18] = byte(i) + addresses = append(addresses, addr) + hash := sha256.Sum256(append(salt1, addr.Bytes()...)) + hashes1 = append(hashes1, hash) + } + store.Load(salt1, hashes1, "etag") + + // prepare second set for swapping + salt2 := []byte("new-salt") + var addresses2 []common.Address + var hashes2 []common.Hash + for i := 0; i < 100; i++ { + addr := common.BigToAddress(common.Big2) + addr[18] = byte(i) + addresses2 = append(addresses2, addr) + hash := sha256.Sum256(append(salt2, addr.Bytes()...)) + hashes2 = append(hashes2, hash) + } + + // Run concurrent reads + var wg sync.WaitGroup + for p := 0; p < 10; p++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + for i := 0; i < 100; i++ { + addr1 := addresses[i] + addr2 := addresses2[i] + + if store.isAllRestricted([]common.Address{addr1, addr2}) || + !store.isAnyRestricted([]common.Address{addr1, addr2}) { + // One should be restricted, the other not, atomic swap should ensure consistency + t.Log("addr1:", addr1.Hex(), "restricted:", store.IsRestricted(addr1)) + t.Log("addr2:", addr2.Hex(), "restricted:", store.IsRestricted(addr2)) + t.Error("concurrent access yielded inconsistent results") + } + } + } + }() + } + + // Run concurrent swap + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 10; i++ { + if i%2 == 0 { + store.Load(salt1, hashes1, "etag") + } else { + store.Load(salt2, hashes2, "new-etag") + } + time.Sleep(time.Millisecond) + } + }() + + wg.Wait() +} + +func TestParseHashListJSON(t *testing.T) { + hashed_addr1 := sha256.Sum256(common.BigToAddress(common.Big1).Bytes()) + hashed_addr2 := sha256.Sum256(common.BigToAddress(common.Big2).Bytes()) + // Test valid JSON + // should follow format: {"salt": "hex...", "address_hashes": [{"hash": "hex1", "max_risk_score_level":1}, {"hash": "hex2", "max_risk_score_level":3}, ...]} + validPayload := map[string]interface{}{ + "salt": hex.EncodeToString([]byte("test-salt")), + "address_hashes": []map[string]interface{}{ + { + "hash": hex.EncodeToString(hashed_addr1[:]), + "max_risk_score_level": 2, + }, + { + "hash": hex.EncodeToString(hashed_addr2[:]), + "max_risk_score_level": 3, + }, + }, + } + validJSON, _ := json.Marshal(validPayload) + + salt, hashes, err := parseHashListJSON(validJSON) + if err != nil { + t.Fatalf("failed to parse valid JSON: %v", err) + } + if string(salt) != "test-salt" { + t.Errorf("expected salt 'test-salt', got '%s'", string(salt)) + } + if len(hashes) != 2 { + t.Errorf("expected 2 hashes, got %d", len(hashes)) + } + + // Test invalid JSON + _, _, err = parseHashListJSON([]byte("not json")) + if err == nil { + t.Error("expected error for invalid JSON") + } + + // Test invalid salt hex + invalidSaltPayload := map[string]interface{}{ + "salt": "not-hex", + "address_hashes": []map[string]interface{}{{"hash": hex.EncodeToString(hashed_addr1[:])}}, + } + invalidSaltJSON, _ := json.Marshal(invalidSaltPayload) + _, _, err = parseHashListJSON(invalidSaltJSON) + if err == nil { + t.Error("expected error for invalid salt hex") + } + + // Test invalid hash hex + invalidHashPayload := map[string]interface{}{ + "salt": hex.EncodeToString([]byte("test-salt")), + "address_hashes": []map[string]interface{}{{"hash": "not-hex"}}, + } + invalidHashJSON, _ := json.Marshal(invalidHashPayload) + _, _, err = parseHashListJSON(invalidHashJSON) + if err == nil { + t.Error("expected error for invalid hash hex") + } + + // Test wrong hash length + wrongLenPayload := map[string]interface{}{ + "salt": hex.EncodeToString([]byte("test-salt")), + "address_hashes": []map[string]interface{}{{"hash": "0123456789abcdef"}}, + } + wrongLenJSON, _ := json.Marshal(wrongLenPayload) + _, _, err = parseHashListJSON(wrongLenJSON) + if err == nil { + t.Error("expected error for wrong hash length") + } +} + +func TestConfig_Validate(t *testing.T) { + // Test disabled config (should always be valid) + disabledConfig := Config{Enable: false} + if err := disabledConfig.Validate(); err != nil { + t.Errorf("disabled config should be valid: %v", err) + } + + // Test enabled config with missing fields + enabledConfig := Config{Enable: true} + if err := enabledConfig.Validate(); err == nil { + t.Error("enabled config with missing fields should be invalid") + } + + // Test valid enabled config + validConfig := Config{ + Enable: true, + S3: s3syncer.Config{ + Bucket: "test-bucket", + Region: "us-east-1", + ObjectKey: "hashlists/current.json", + }, + PollInterval: 5 * time.Minute, + } + if err := validConfig.Validate(); err != nil { + t.Errorf("valid config should pass validation: %v", err) + } + + // Test invalid poll interval + invalidPollConfig := validConfig + invalidPollConfig.PollInterval = 0 + if err := invalidPollConfig.Validate(); err == nil { + t.Error("config with zero poll interval should be invalid") + } +} + +func TestHashStore_LoadedAt(t *testing.T) { + store := NewHashStore() + + // Empty store should have zero time + if !store.LoadedAt().IsZero() { + t.Error("empty store should have zero LoadedAt") + } + + // After load, should have current time + before := time.Now() + store.Load([]byte("salt"), nil, "etag") + after := time.Now() + + loadedAt := store.LoadedAt() + if loadedAt.Before(before) || loadedAt.After(after) { + t.Errorf("LoadedAt should be between %v and %v, got %v", before, after, loadedAt) + } +} + +// IsAllRestricted checks if all provided addresses are in the restricted list +// from same hash-store snapshot. Results are cached in the LRU cache. +func (h *HashStore) isAllRestricted(addrs []common.Address) bool { + data := h.data.Load() // Atomic load - no lock needed + if len(data.salt) == 0 { + return false // Not initialized + } + for _, addr := range addrs { + // Check cache first (cache is per-data snapshot) + if restricted, ok := data.cache.Get(addr); ok { + if !restricted { + return false + } + continue + } + + hash := sha256.Sum256(append(data.salt, addr.Bytes()...)) + _, restricted := data.hashes[hash] + data.cache.Add(addr, restricted) + if !restricted { + return false + } + } + return true +} + +// IsAnyRestricted checks if any of the provided addresses are in the restricted list +// from same hash-store snapshot. Results are cached in the LRU cache. +func (h *HashStore) isAnyRestricted(addrs []common.Address) bool { + data := h.data.Load() // Atomic load - no lock needed + if len(data.salt) == 0 { + return false // Not initialized + } + for _, addr := range addrs { + // Check cache first (cache is per-data snapshot) + if restricted, ok := data.cache.Get(addr); ok { + if restricted { + return true + } + continue + } + + hash := sha256.Sum256(append(data.salt, addr.Bytes()...)) + _, restricted := data.hashes[hash] + data.cache.Add(addr, restricted) + if restricted { + return true + } + } + return false +} diff --git a/changelog/mnasr-compliance-chains-restricted-addr-NIT-4269.md b/changelog/mnasr-compliance-chains-restricted-addr-NIT-4269.md new file mode 100644 index 0000000000..fd4750027c --- /dev/null +++ b/changelog/mnasr-compliance-chains-restricted-addr-NIT-4269.md @@ -0,0 +1,17 @@ +### Added +- Add address filter service for compliance chains (`addressfilter` package). This feature enables sequencers to block transactions involving filtered addresses by polling a hashed address list from S3. Key capabilities include: + - S3-based hashed list synchronization with ETag change detection for efficient polling + - Lock-free HashStore using atomic pointer swaps for zero-blocking reads during updates + - LRU cache (10k entries) for high-performance address lookups + - Privacy-preserving design: addresses are never stored or transmitted in plaintext (SHA256 with salt) + - Configurable via `--execution.address-filter.*` flags (enable,s3.bucket, s3.region, s3.object-key, s3.AccessKey, s3.SecretKey, poll-interval) + +### Configuration + - Add `--execution.address-filter.enable` flag to enable/disable address filtering + - Add `--execution.address-filter.poll-interval` flag to set the polling interval for the s3 syncer , e.g. 5s + - Add `--execution.address-filter.s3.*` group of flags to configure S3 access: + - Add `--execution.address-filter.s3.bucket` flag to specify the S3 bucket name for the hashed address list + - Add `--execution.address-filter.s3.region` flag to specify the AWS region of + - Add `--execution.address-filter.s3.object-key` flag to specify the S3 object key for the hashed address list + - Add `--execution.address-filter.s3.access-key` flag to specify the AWS access + - Add `--execution.address-filter.s3.secret-key` flag to specify the AWS secret key diff --git a/changelog/mrogachev-nit-4299.md b/changelog/mrogachev-nit-4299.md new file mode 100644 index 0000000000..d17d2b35a5 --- /dev/null +++ b/changelog/mrogachev-nit-4299.md @@ -0,0 +1,2 @@ +### Added +- Hashed address filter implementation for address filter interfaces with shared LRU caching \ No newline at end of file diff --git a/execution/gethexec/node.go b/execution/gethexec/node.go index ba1818b271..7640717991 100644 --- a/execution/gethexec/node.go +++ b/execution/gethexec/node.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/params" "github.com/ethereum/go-ethereum/rpc" + "github.com/offchainlabs/nitro/addressfilter" "github.com/offchainlabs/nitro/arbos/arbostypes" "github.com/offchainlabs/nitro/arbos/programs" "github.com/offchainlabs/nitro/arbutil" @@ -137,6 +138,7 @@ type Config struct { ExposeMultiGas bool `koanf:"expose-multi-gas"` RPCServer rpcserver.Config `koanf:"rpc-server"` ConsensusRPCClient rpcclient.ClientConfig `koanf:"consensus-rpc-client" reload:"hot"` + AddressFilter addressfilter.Config `koanf:"address-filter" reload:"hot"` forwardingTarget string } @@ -168,6 +170,9 @@ func (c *Config) Validate() error { if err := c.ConsensusRPCClient.Validate(); err != nil { return fmt.Errorf("error validating ConsensusRPCClient config: %w", err) } + if err := c.AddressFilter.Validate(); err != nil { + return fmt.Errorf("error validating addressfilter config: %w", err) + } return nil } @@ -191,6 +196,7 @@ func ConfigAddOptions(prefix string, f *pflag.FlagSet) { LiveTracingConfigAddOptions(prefix+".vmtrace", f) rpcserver.ConfigAddOptions(prefix+".rpc-server", "execution", f) rpcclient.RPCClientAddOptions(prefix+".consensus-rpc-client", f, &ConfigDefault.ConsensusRPCClient) + addressfilter.ConfigAddOptions(prefix+".address-filter", f) } type LiveTracingConfig struct { @@ -237,6 +243,8 @@ var ConfigDefault = Config{ ArgLogLimit: 2048, WebsocketMessageSizeLimit: 256 * 1024 * 1024, }, + + AddressFilter: addressfilter.DefaultConfig, } type ConfigFetcher interface { @@ -262,6 +270,7 @@ type ExecutionNode struct { started atomic.Bool bulkBlockMetadataFetcher *BulkBlockMetadataFetcher consensusRPCClient *consensusrpcclient.ConsensusRPCClient + addressFilterService *addressfilter.Service } func CreateExecutionNode( @@ -356,6 +365,11 @@ func CreateExecutionNode( bulkBlockMetadataFetcher := NewBulkBlockMetadataFetcher(l2BlockChain, execEngine, config.BlockMetadataApiCacheSize, config.BlockMetadataApiBlocksLimit) + addressFilterService, err := addressfilter.NewService(ctx, &config.AddressFilter) + if err != nil { + return nil, fmt.Errorf("failed to create restricted addr service: %w", err) + } + execNode := &ExecutionNode{ ExecutionDB: executionDB, Backend: backend, @@ -371,6 +385,7 @@ func CreateExecutionNode( ParentChainReader: parentChainReader, ClassicOutbox: classicOutbox, bulkBlockMetadataFetcher: bulkBlockMetadataFetcher, + addressFilterService: addressFilterService, } if config.ConsensusRPCClient.URL != "" { @@ -463,6 +478,19 @@ func (n *ExecutionNode) Initialize(ctx context.Context) error { return fmt.Errorf("error setting sync backend: %w", err) } + if n.addressFilterService != nil { + if err = n.addressFilterService.Initialize(ctx); err != nil { + return fmt.Errorf("error initializing restricted addr service: %w", err) + } + + hashStore := n.addressFilterService.GetHashStore() + if hashStore != nil { + addressChecker := addressfilter.NewDefaultHashedAddressChecker(hashStore) + addressChecker.Start(ctx) + n.ExecEngine.SetAddressChecker(addressChecker) + } + } + return nil } @@ -496,6 +524,9 @@ func (n *ExecutionNode) Start(ctxIn context.Context) error { n.ParentChainReader.Start(ctx) } n.bulkBlockMetadataFetcher.Start(ctx) + if n.addressFilterService != nil { + n.addressFilterService.Start(ctx) + } return nil } @@ -525,6 +556,9 @@ func (n *ExecutionNode) StopAndWait() { // log.Error("error on stak close", "err", err) // } n.StopWaiter.StopAndWait() + if n.addressFilterService != nil { + n.addressFilterService.StopAndWait() + } } func (n *ExecutionNode) DigestMessage(num arbutil.MessageIndex, msg *arbostypes.MessageWithMetadata, msgForPrefetch *arbostypes.MessageWithMetadata) containers.PromiseInterface[*execution.MessageResult] { diff --git a/system_tests/tx_address_filter_test.go b/system_tests/tx_address_filter_test.go index 10c97b19b4..983ba4f018 100644 --- a/system_tests/tx_address_filter_test.go +++ b/system_tests/tx_address_filter_test.go @@ -5,6 +5,7 @@ package arbtest import ( "context" + "crypto/sha256" "math/big" "strings" "testing" @@ -12,8 +13,8 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" + "github.com/offchainlabs/nitro/addressfilter" "github.com/offchainlabs/nitro/solgen/go/localgen" - "github.com/offchainlabs/nitro/txfilter" ) func isFilteredError(err error) bool { @@ -23,6 +24,24 @@ func isFilteredError(err error) bool { return strings.Contains(err.Error(), "internal error") } +func newHashedChecker(addrs []common.Address) *addressfilter.HashedAddressChecker { + store := addressfilter.NewHashStore() + if len(addrs) > 0 { + salt := []byte("test-salt") + hashes := make([]common.Hash, len(addrs)) + for i, addr := range addrs { + salted := make([]byte, len(salt)+common.AddressLength) + copy(salted, salt) + copy(salted[len(salt):], addr.Bytes()) + hashes[i] = sha256.Sum256(salted) + } + store.Load(salt, hashes, "test") + } + checker := addressfilter.NewDefaultHashedAddressChecker(store) + checker.Start(context.Background()) + return checker +} + func TestAddressFilterDirectTransfer(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -42,7 +61,7 @@ func TestAddressFilterDirectTransfer(t *testing.T) { // Set up address filter to block FilteredUser filteredAddr := builder.L2Info.GetAddress("FilteredUser") - filter := txfilter.NewStaticAsyncChecker([]common.Address{filteredAddr}) + filter := newHashedChecker([]common.Address{filteredAddr}) builder.L2.ExecNode.ExecEngine.SetAddressChecker(filter) // Test 1: Transaction TO a filtered address should fail @@ -105,7 +124,7 @@ func TestAddressFilterCall(t *testing.T) { targetAddr, _ := deployAddressFilterTestContract(t, ctx, builder) // Set up filter to block the target contract - filter := txfilter.NewStaticAsyncChecker([]common.Address{targetAddr}) + filter := newHashedChecker([]common.Address{targetAddr}) builder.L2.ExecNode.ExecEngine.SetAddressChecker(filter) // Test: CALL to filtered address should fail @@ -143,7 +162,7 @@ func TestAddressFilterStaticCall(t *testing.T) { targetAddr, _ := deployAddressFilterTestContract(t, ctx, builder) // Set up filter to block the target contract - filter := txfilter.NewStaticAsyncChecker([]common.Address{targetAddr}) + filter := newHashedChecker([]common.Address{targetAddr}) builder.L2.ExecNode.ExecEngine.SetAddressChecker(filter) // Test: STATICCALL to filtered address within a transaction should fail @@ -180,7 +199,7 @@ func TestAddressFilterDisabled(t *testing.T) { builder.L2.TransferBalance(t, "Owner", "TestUser", big.NewInt(1e18), builder.L2Info) // Set up an empty filter (disabled) - filter := txfilter.NewStaticAsyncChecker([]common.Address{}) + filter := newHashedChecker([]common.Address{}) builder.L2.ExecNode.ExecEngine.SetAddressChecker(filter) // All transactions should succeed when filter is disabled @@ -215,7 +234,7 @@ func TestAddressFilterCreate2(t *testing.T) { Require(t, err) // Set up filter to block the computed address - filter := txfilter.NewStaticAsyncChecker([]common.Address{create2Addr}) + filter := newHashedChecker([]common.Address{create2Addr}) builder.L2.ExecNode.ExecEngine.SetAddressChecker(filter) // Test: CREATE2 to filtered address should fail @@ -257,7 +276,7 @@ func TestAddressFilterCreate(t *testing.T) { createAddr := crypto.CreateAddress(callerAddr, nonce) // Set up filter to block the computed address - filter := txfilter.NewStaticAsyncChecker([]common.Address{createAddr}) + filter := newHashedChecker([]common.Address{createAddr}) builder.L2.ExecNode.ExecEngine.SetAddressChecker(filter) // Test: CREATE to filtered address should fail @@ -272,7 +291,7 @@ func TestAddressFilterCreate(t *testing.T) { // Test: CREATE to non-filtered address (after nonce incremented) should succeed // Clear the filter to allow the next CREATE - emptyChecker := txfilter.NewStaticAsyncChecker([]common.Address{}) + emptyChecker := newHashedChecker([]common.Address{}) builder.L2.ExecNode.ExecEngine.SetAddressChecker(emptyChecker) auth = builder.L2Info.GetDefaultTransactOpts("Owner", ctx) @@ -299,7 +318,7 @@ func TestAddressFilterSelfdestruct(t *testing.T) { filteredAddr := builder.L2Info.GetAddress("FilteredBeneficiary") // Set up filter to block the beneficiary - filter := txfilter.NewStaticAsyncChecker([]common.Address{filteredAddr}) + filter := newHashedChecker([]common.Address{filteredAddr}) builder.L2.ExecNode.ExecEngine.SetAddressChecker(filter) // Test: SELFDESTRUCT to filtered beneficiary should fail diff --git a/txfilter/filter.go b/txfilter/filter.go deleted file mode 100644 index f039328771..0000000000 --- a/txfilter/filter.go +++ /dev/null @@ -1,62 +0,0 @@ -// Copyright 2025-2026, Offchain Labs, Inc. -// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md - -package txfilter - -import ( - "sync" - - "github.com/ethereum/go-ethereum/common" - "github.com/ethereum/go-ethereum/core/state" -) - -// NoopChecker is a stub that filters nothing. -type NoopChecker struct{} - -func (c *NoopChecker) NewTxState() state.AddressCheckerState { - return &noopState{} -} - -type noopState struct{} - -func (s *noopState) TouchAddress(addr common.Address) {} -func (s *noopState) IsFiltered() bool { return false } - -// StaticAsyncChecker filters a fixed set of addresses (for testing). -// Checks addresses asynchronously using goroutines to demonstrate the async pattern. -type StaticAsyncChecker struct { - addresses map[common.Address]struct{} -} - -func NewStaticAsyncChecker(addrs []common.Address) *StaticAsyncChecker { - m := make(map[common.Address]struct{}, len(addrs)) - for _, addr := range addrs { - m[addr] = struct{}{} - } - return &StaticAsyncChecker{addresses: m} -} - -func (c *StaticAsyncChecker) NewTxState() state.AddressCheckerState { - return &staticAsyncState{checker: c} -} - -type staticAsyncState struct { - checker *StaticAsyncChecker - filtered bool - pendingChecks sync.WaitGroup -} - -func (s *staticAsyncState) TouchAddress(addr common.Address) { - s.pendingChecks.Add(1) - go func() { - defer s.pendingChecks.Done() - if _, found := s.checker.addresses[addr]; found { - s.filtered = true - } - }() -} - -func (s *staticAsyncState) IsFiltered() bool { - s.pendingChecks.Wait() - return s.filtered -} diff --git a/util/s3syncer/config.go b/util/s3syncer/config.go new file mode 100644 index 0000000000..94f97896ab --- /dev/null +++ b/util/s3syncer/config.go @@ -0,0 +1,60 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package s3syncer + +import ( + "errors" + + "github.com/spf13/pflag" +) + +// Config holds the S3 configuration for syncing data. +type Config struct { + Bucket string `koanf:"bucket"` + Region string `koanf:"region"` + ObjectKey string `koanf:"object-key"` + AccessKey string `koanf:"access-key"` + SecretKey string `koanf:"secret-key"` +} + +// ConfigAddOptions adds S3 configuration flags to the given flag set. +func ConfigAddOptions(prefix string, f *pflag.FlagSet) { + f.String(prefix+".bucket", "", "S3 bucket name") + f.String(prefix+".region", "", "AWS region of the S3 bucket") + f.String(prefix+".access-key", "", "AWS access key for S3 (optional, uses default credentials if "+ + "not provided which check for credentials in specific order like env variables, shared credentials, etc.)") + f.String(prefix+".secret-key", "", "AWS secret key for S3 (optional, uses default credentials if "+ + "not provided which check for credentials in specific order like env variables, shared credentials, etc.)") + f.String(prefix+".object-key", "", "S3 object key (path) to the file") +} + +// Validate checks that required S3 configuration fields are set. +func (c *Config) Validate() error { + if c.Bucket == "" { + return errors.New("s3 bucket is required") + } + if c.Region == "" { + return errors.New("s3 region is required") + } + if c.ObjectKey == "" { + return errors.New("s3 object-key is required") + } + return nil +} + +// DownloadConfig holds configuration for S3 multipart downloads. +type DownloadConfig struct { + PartSizeMB int `koanf:"part-size-mb"` + PartBodyMaxRetries int `koanf:"part-body-max-retries"` + Concurrency int `koanf:"concurrency"` +} + +// DefaultDownloadConfig returns the default download configuration. +func DefaultDownloadConfig() DownloadConfig { + return DownloadConfig{ + PartSizeMB: 32, + PartBodyMaxRetries: 5, + Concurrency: 10, + } +} diff --git a/util/s3syncer/syncer.go b/util/s3syncer/syncer.go new file mode 100644 index 0000000000..e59c7fd1da --- /dev/null +++ b/util/s3syncer/syncer.go @@ -0,0 +1,153 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package s3syncer + +import ( + "context" + "fmt" + "sync" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" + + "github.com/ethereum/go-ethereum/log" + + "github.com/offchainlabs/nitro/util/s3client" +) + +// DataHandler processes downloaded data and the associated digest. +type DataHandler func(data []byte, digest string) error + +// Syncer handles S3 object syncing with ETag-based change detection. +type Syncer struct { + client s3client.FullClient + config *Config + downloadConfig DownloadConfig + handleData DataHandler + digestETag string + mutex sync.Mutex +} + +// Option configures a Syncer. +type Option func(*Syncer) + +const bytesInMB = 1024 * 1024 + +// WithDownloadConfig sets custom download configuration. +func WithDownloadConfig(dc DownloadConfig) Option { + return func(s *Syncer) { + s.downloadConfig = dc + } +} + +// WithS3Client sets a custom S3 client (useful for testing). +func WithS3Client(client s3client.FullClient) Option { + return func(s *Syncer) { + s.client = client + } +} + +// NewSyncer creates a new S3 syncer with the given callbacks. +func NewSyncer( + ctx context.Context, + config *Config, + dataHandler DataHandler, + opts ...Option, +) (*Syncer, error) { + s := &Syncer{ + config: config, + downloadConfig: DefaultDownloadConfig(), + handleData: dataHandler, + } + + for _, opt := range opts { + opt(s) + } + + // Create S3 client if not provided via option + if s.client == nil { + client, err := s3client.NewS3FullClient(ctx, config.AccessKey, config.SecretKey, config.Region) + if err != nil { + return nil, fmt.Errorf("failed to create S3 client: %w", err) + } + s.client = client + } + + return s, nil +} + +// CheckAndSync checks if the S3 object has changed (via ETag) and downloads it if so. +func (s *Syncer) CheckAndSync(ctx context.Context) error { + s.mutex.Lock() + defer s.mutex.Unlock() + headOutput, err := s.client.Client().HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(s.config.Bucket), + Key: aws.String(s.config.ObjectKey), + }) + if err != nil { + return fmt.Errorf("HeadObject failed for s3://%s/%s: %w", s.config.Bucket, s.config.ObjectKey, err) + } + + currentETag := aws.ToString(headOutput.ETag) + + // Compare with stored digest + if currentETag == s.digestETag { + log.Debug("S3 object unchanged", "etag", currentETag, "bucket", s.config.Bucket, "key", s.config.ObjectKey) + return nil + } + + log.Info("S3 object changed, downloading", + "old_etag", s.digestETag, + "new_etag", currentETag, + "bucket", s.config.Bucket, + "key", s.config.ObjectKey, + ) + objectSize := aws.ToInt64(headOutput.ContentLength) + return s.downloadAndHandle(ctx, currentETag, objectSize) +} + +// DownloadAndLoad downloads the S3 object and processes it with the data handler. +// This is used for initial load where we need to fetch metadata first. +func (s *Syncer) DownloadAndLoad(ctx context.Context) error { + headOutput, err := s.client.Client().HeadObject(ctx, &s3.HeadObjectInput{ + Bucket: aws.String(s.config.Bucket), + Key: aws.String(s.config.ObjectKey), + }) + if err != nil { + return fmt.Errorf("HeadObject failed for s3://%s/%s: %w", s.config.Bucket, s.config.ObjectKey, err) + } + + newETagDigest := aws.ToString(headOutput.ETag) + objectSize := aws.ToInt64(headOutput.ContentLength) + err = s.downloadAndHandle(ctx, newETagDigest, objectSize) + return err +} + +// downloadAndHandle downloads the S3 object to a temp file and calls the data handler. +func (s *Syncer) downloadAndHandle(ctx context.Context, etagDigest string, objectSize int64) error { + downloader := manager.NewDownloader(s.client.Client(), func(d *manager.Downloader) { + d.PartSize = int64(s.downloadConfig.PartSizeMB) * bytesInMB + d.PartBodyMaxRetries = s.downloadConfig.PartBodyMaxRetries + d.Concurrency = s.downloadConfig.Concurrency + }) + + // let's use an in-memory buffer to avoid file I/O + buffer := manager.NewWriteAtBuffer(make([]byte, 0, objectSize)) + + // Download - SDK handles chunking, concurrency, and retry + _, err := downloader.Download(ctx, buffer, &s3.GetObjectInput{ + Bucket: aws.String(s.config.Bucket), + Key: aws.String(s.config.ObjectKey), + }) + if err != nil { + return fmt.Errorf("download failed for s3://%s/%s: %w", s.config.Bucket, s.config.ObjectKey, err) + } + + err = s.handleData(buffer.Bytes(), etagDigest) + if err == nil { + s.digestETag = etagDigest + } + return err +} diff --git a/util/s3syncer/syncer_test.go b/util/s3syncer/syncer_test.go new file mode 100644 index 0000000000..86b6303eec --- /dev/null +++ b/util/s3syncer/syncer_test.go @@ -0,0 +1,84 @@ +// Copyright 2026, Offchain Labs, Inc. +// For license information, see https://github.com/OffchainLabs/nitro/blob/master/LICENSE.md + +package s3syncer + +import ( + "testing" +) + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config Config + wantErr bool + }{ + { + name: "valid config", + config: Config{ + Bucket: "test-bucket", + Region: "us-east-1", + ObjectKey: "path/to/file.json", + }, + wantErr: false, + }, + { + name: "missing bucket", + config: Config{ + Region: "us-east-1", + ObjectKey: "path/to/file.json", + }, + wantErr: true, + }, + { + name: "missing region", + config: Config{ + Bucket: "test-bucket", + ObjectKey: "path/to/file.json", + }, + wantErr: true, + }, + { + name: "missing object key", + config: Config{ + Bucket: "test-bucket", + Region: "us-east-1", + }, + wantErr: true, + }, + { + name: "valid config with credentials", + config: Config{ + Bucket: "test-bucket", + Region: "us-east-1", + ObjectKey: "path/to/file.json", + AccessKey: "AKIAIOSFODNN7EXAMPLE", + SecretKey: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestDefaultDownloadConfig(t *testing.T) { + dc := DefaultDownloadConfig() + + if dc.PartSizeMB != 32 { + t.Errorf("DefaultDownloadConfig().PartSizeMB = %d, want 32", dc.PartSizeMB) + } + if dc.PartBodyMaxRetries != 5 { + t.Errorf("DefaultDownloadConfig().PartBodyMaxRetries = %d, want 5", dc.PartBodyMaxRetries) + } + if dc.Concurrency != 10 { + t.Errorf("DefaultDownloadConfig().Concurrency = %d, want 10", dc.Concurrency) + } +}