@@ -10,6 +10,7 @@ import (
1010 "math/big"
1111 "sync"
1212 "testing"
13+ "testing/synctest"
1314 "time"
1415
1516 "github.com/ethereum/go-ethereum/common"
@@ -95,66 +96,66 @@ func TestAgent(t *testing.T) {
9596
9697 for _ , tc := range tests {
9798 t .Run (tc .name , func (t * testing.T ) {
98- t .Parallel ()
99+ synctest .Test (t , func (t * testing.T ) {
100+ wait := make (chan struct {}, 1 )
101+ addr := swarm .RandAddress (t )
102+
103+ backend := & mockchainBackend {
104+ limit : tc .limit ,
105+ limitCallback : func () {
106+ wait <- struct {}{}
107+ },
108+ incrementBy : tc .incrementBy ,
109+ block : tc .blocksPerRound ,
110+ balance : tc .balance ,
111+ }
99112
100- wait := make (chan struct {})
101- addr := swarm .RandAddress (t )
113+ var radius uint8 = 8
102114
103- backend := & mockchainBackend {
104- limit : tc .limit ,
105- limitCallback : func () {
106- select {
107- case wait <- struct {}{}:
108- default :
109- }
110- },
111- incrementBy : tc .incrementBy ,
112- block : tc .blocksPerRound ,
113- balance : tc .balance ,
114- }
115+ contract := & mockContract {t : t , expectedRadius : radius + tc .doubling }
115116
116- var radius uint8 = 8
117+ service , _ := createService (t , addr , backend , contract , tc .blocksPerRound , tc .blocksPerPhase , radius , tc .doubling )
118+ testutil .CleanupCloser (t , service )
117119
118- contract := & mockContract { t : t , expectedRadius : radius + tc . doubling }
120+ <- wait
119121
120- service , _ := createService (t , addr , backend , contract , tc .blocksPerRound , tc .blocksPerPhase , radius , tc .doubling )
121- testutil .CleanupCloser (t , service )
122+ synctest .Wait ()
122123
123- <- wait
124+ calls := contract . getCalls ()
124125
125- if ! tc .expectedCalls {
126- if len (contract . callsList ) > 0 {
127- t .Fatal ("got unexpected calls" )
128- } else {
126+ if ! tc .expectedCalls {
127+ if len (calls ) > 0 {
128+ t .Fatal ("got unexpected calls" )
129+ }
129130 return
130131 }
131- }
132132
133- assertOrder := func (t * testing.T , want , got contractCall ) {
134- t .Helper ()
135- if want != got {
136- t .Fatalf ("expected call %s, got %s" , want , got )
133+ if len (calls ) == 0 {
134+ t .Fatal ("expected calls but got none" )
137135 }
138- }
139136
140- contract .mtx .Lock ()
141- defer contract .mtx .Unlock ()
137+ assertOrder := func (t * testing.T , want , got contractCall ) {
138+ t .Helper ()
139+ if want != got {
140+ t .Fatalf ("expected call %s, got %s" , want , got )
141+ }
142+ }
142143
143- prevCall := contract . callsList [0 ]
144+ prevCall := calls [0 ]
144145
145- for i := 1 ; i < len (contract .callsList ); i ++ {
146+ for i := 1 ; i < len (calls ); i ++ {
147+ switch calls [i ] {
148+ case isWinnerCall :
149+ assertOrder (t , revealCall , prevCall )
150+ case revealCall :
151+ assertOrder (t , commitCall , prevCall )
152+ case commitCall :
153+ assertOrder (t , isWinnerCall , prevCall )
154+ }
146155
147- switch contract .callsList [i ] {
148- case isWinnerCall :
149- assertOrder (t , revealCall , prevCall )
150- case revealCall :
151- assertOrder (t , commitCall , prevCall )
152- case commitCall :
153- assertOrder (t , isWinnerCall , prevCall )
156+ prevCall = calls [i ]
154157 }
155-
156- prevCall = contract .callsList [i ]
157- }
158+ })
158159 })
159160 }
160161}
@@ -276,6 +277,18 @@ type mockContract struct {
276277 t * testing.T
277278}
278279
280+ // getCalls returns a snapshot of the calls list
281+ // even after synctest.Wait() all goroutines are blocked, we still should use locking
282+ // for defensive programming.
283+ func (m * mockContract ) getCalls () []contractCall {
284+ m .mtx .Lock ()
285+ defer m .mtx .Unlock ()
286+ // return a copy to avoid external modifications
287+ calls := make ([]contractCall , len (m .callsList ))
288+ copy (calls , m .callsList )
289+ return calls
290+ }
291+
279292func (m * mockContract ) ReserveSalt (context.Context ) ([]byte , error ) {
280293 return nil , nil
281294}
0 commit comments