@@ -592,6 +592,131 @@ func TestServerMissingFeedServerVersion(t *testing.T) {
592592 }
593593}
594594
595+ type accumulatingTransactionStreamer struct {
596+ mu sync.Mutex
597+ messages []* message.BroadcastFeedMessage
598+ }
599+
600+ func (ts * accumulatingTransactionStreamer ) AddBroadcastMessages (feedMessages []* message.BroadcastFeedMessage ) error {
601+ ts .mu .Lock ()
602+ defer ts .mu .Unlock ()
603+ ts .messages = append (ts .messages , feedMessages ... )
604+ return nil
605+ }
606+
607+ func (ts * accumulatingTransactionStreamer ) getMessages () []* message.BroadcastFeedMessage {
608+ ts .mu .Lock ()
609+ defer ts .mu .Unlock ()
610+ result := make ([]* message.BroadcastFeedMessage , len (ts .messages ))
611+ copy (result , ts .messages )
612+ return result
613+ }
614+
615+ // awaitCount waits until at least count messages have been received.
616+ // The timeout is a safety net to prevent the test from hanging.
617+ func (ts * accumulatingTransactionStreamer ) awaitCount (t * testing.T , count int , timeout time.Duration ) {
618+ t .Helper ()
619+ deadline := time .After (timeout )
620+ for {
621+ if len (ts .getMessages ()) >= count {
622+ return
623+ }
624+ select {
625+ case <- deadline :
626+ t .Fatalf ("timed out waiting for %d messages, got %d" , count , len (ts .getMessages ()))
627+ case <- time .After (10 * time .Millisecond ):
628+ }
629+ }
630+ }
631+
632+ func TestInvalidSignatureMessagesAreSkipped (t * testing.T ) {
633+ t .Parallel ()
634+ ctx , cancel := context .WithCancel (context .Background ())
635+ defer cancel ()
636+
637+ chainId := uint64 (9742 )
638+
639+ // Trusted key: broadcaster signs with this, client trusts this
640+ trustedKey , err := crypto .GenerateKey ()
641+ Require (t , err )
642+ trustedAddr := crypto .PubkeyToAddress (trustedKey .PublicKey )
643+ trustedSigner := signature .DataSignerFromPrivateKey (trustedKey )
644+
645+ // Untrusted key: used to create messages with invalid signatures
646+ untrustedKey , err := crypto .GenerateKey ()
647+ Require (t , err )
648+ untrustedSigner := signature .DataSignerFromPrivateKey (untrustedKey )
649+
650+ feedErrChan := make (chan error , 10 )
651+ trustedBroadcaster := broadcaster .NewBroadcaster (func () * wsbroadcastserver.BroadcasterConfig { return & wsbroadcastserver .DefaultTestBroadcasterConfig }, chainId , feedErrChan , trustedSigner )
652+
653+ Require (t , trustedBroadcaster .Initialize ())
654+ Require (t , trustedBroadcaster .Start (ctx ))
655+ defer trustedBroadcaster .StopAndWait ()
656+
657+ // Second broadcaster (not started) used only to create messages signed with the untrusted key
658+ untrustedBroadcaster := broadcaster .NewBroadcaster (func () * wsbroadcastserver.BroadcasterConfig { return & wsbroadcastserver .DefaultTestBroadcasterConfig }, chainId , make (chan error , 1 ), untrustedSigner )
659+
660+ ts := & accumulatingTransactionStreamer {}
661+
662+ clientFeedErrChan := make (chan error , 10 )
663+ broadcastClient , err := newTestBroadcastClient (
664+ DefaultTestConfig ,
665+ trustedBroadcaster .ListenerAddr (),
666+ chainId ,
667+ 0 ,
668+ ts ,
669+ nil ,
670+ clientFeedErrChan ,
671+ & trustedAddr ,
672+ t ,
673+ )
674+ Require (t , err )
675+ broadcastClient .Start (ctx )
676+ defer broadcastClient .StopAndWait ()
677+
678+ // Batch 1: valid messages (seq 0, 1) - should be delivered.
679+ Require (t , trustedBroadcaster .BroadcastFeedMessages (feedMessage (t , trustedBroadcaster , 0 )))
680+ Require (t , trustedBroadcaster .BroadcastFeedMessages (feedMessage (t , trustedBroadcaster , 1 )))
681+ ts .awaitCount (t , 2 , 10 * time .Second )
682+
683+ // Batch 2: invalid messages (seq 2, 3) signed with untrusted key - should be skipped.
684+ Require (t , trustedBroadcaster .BroadcastFeedMessages (feedMessage (t , untrustedBroadcaster , 2 )))
685+ Require (t , trustedBroadcaster .BroadcastFeedMessages (feedMessage (t , untrustedBroadcaster , 3 )))
686+
687+ // Sentinel (seq 2): a valid message that deterministically proves the client has
688+ // processed and skipped the invalid messages. WebSocket messages are ordered, so the
689+ // sentinel can only arrive after the invalid ones have been processed (and skipped).
690+ // Invalid messages don't advance nextSeqNum (stays at 2), so the sentinel is seq 2.
691+ Require (t , trustedBroadcaster .BroadcastFeedMessages (feedMessage (t , trustedBroadcaster , 2 )))
692+ ts .awaitCount (t , 3 , 10 * time .Second )
693+
694+ // Verify: only valid messages were delivered, and all have trusted signatures.
695+ got := ts .getMessages ()
696+ if len (got ) != 3 {
697+ t .Fatalf ("expected 3 messages, got %d" , len (got ))
698+ }
699+ for i , msg := range got {
700+ if msg .SequenceNumber != arbutil .MessageIndex (i ) { // nolint: gosec
701+ t .Fatalf ("message %d: unexpected seq number: %d" , i , msg .SequenceNumber )
702+ }
703+ hash := msg .SignatureHash (chainId )
704+ sigPub , err := crypto .SigToPub (hash .Bytes (), msg .Signature )
705+ Require (t , err )
706+ signerAddr := crypto .PubkeyToAddress (* sigPub )
707+ if signerAddr != trustedAddr {
708+ t .Fatalf ("message %d (seq %d): signed by %s, expected trusted signer %s" , i , msg .SequenceNumber , signerAddr , trustedAddr )
709+ }
710+ }
711+
712+ // Verify no fatal errors occurred (invalid signatures are non-fatal since NIT-4017)
713+ select {
714+ case err := <- clientFeedErrChan :
715+ t .Fatalf ("unexpected fatal feed error: %v" , err )
716+ default :
717+ }
718+ }
719+
595720func TestBroadcastClientReconnectsOnServerDisconnect (t * testing.T ) {
596721 t .Parallel ()
597722 ctx , cancel := context .WithCancel (context .Background ())
0 commit comments