Skip to content

Commit 27b16db

Browse files
committed
switched to RWMutex to ensure switch to fallback on the test stop
1 parent 02ceb76 commit 27b16db

File tree

4 files changed

+63
-22
lines changed

4 files changed

+63
-22
lines changed

internal/common/convert.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,3 +95,17 @@ func QueryResultTypePtr(t s.QueryResultType) *s.QueryResultType {
9595
func PtrOf[T any](v T) *T {
9696
return &v
9797
}
98+
99+
// ValueFromPtr returns the value from a pointer.
100+
func ValueFromPtr[T any](v *T) T {
101+
if v == nil {
102+
return Zero[T]()
103+
}
104+
return *v
105+
}
106+
107+
// Zero returns the zero value of a type by return type.
108+
func Zero[T any]() T {
109+
var zero T
110+
return zero
111+
}

internal/common/convert_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,19 @@ func TestCeilHelpers(t *testing.T) {
5555
assert.Equal(t, int32(2), Int32Ceil(1.1))
5656
assert.Equal(t, int64(2), Int64Ceil(1.1))
5757
}
58+
59+
func TestValueFromPtr(t *testing.T) {
60+
assert.Equal(t, "a", ValueFromPtr(PtrOf("a")))
61+
assert.Equal(t, 1, ValueFromPtr(PtrOf(1)))
62+
assert.Equal(t, int32(1), ValueFromPtr(PtrOf(int32(1))))
63+
assert.Equal(t, int64(1), ValueFromPtr(PtrOf(int64(1))))
64+
assert.Equal(t, 1.1, ValueFromPtr(PtrOf(1.1)))
65+
assert.Equal(t, true, ValueFromPtr(PtrOf(true)))
66+
assert.Equal(t, []string{"a"}, ValueFromPtr(PtrOf([]string{"a"})))
67+
}
68+
69+
func TestZero(t *testing.T) {
70+
assert.Equal(t, "", Zero[string]())
71+
assert.Equal(t, 0, Zero[int]())
72+
assert.Equal(t, (*int)(nil), Zero[*int]())
73+
}

internal/common/testlogger/testlogger.go

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,12 @@ package testlogger
2222

2323
import (
2424
"fmt"
25+
"go.uber.org/cadence/internal/common"
2526
"slices"
2627
"strings"
2728
"sync"
2829

2930
"github.com/stretchr/testify/require"
30-
"go.uber.org/atomic"
3131
"go.uber.org/zap"
3232
"go.uber.org/zap/zapcore"
3333
"go.uber.org/zap/zaptest"
@@ -58,10 +58,11 @@ func NewZap(t TestingT) *zap.Logger {
5858
logAfterComplete, err := zap.NewDevelopment()
5959
require.NoError(t, err, "could not build a fallback zap logger")
6060
replaced := &fallbackTestCore{
61+
mu: &sync.RWMutex{},
6162
t: t,
6263
fallback: logAfterComplete.Core(),
6364
testing: zaptest.NewLogger(t).Core(),
64-
completed: &atomic.Bool{},
65+
completed: common.PtrOf(false),
6566
}
6667

6768
t.Cleanup(replaced.UseFallback) // switch to fallback before ending the test
@@ -81,30 +82,38 @@ func NewObserved(t TestingT) (*zap.Logger, *observer.ObservedLogs) {
8182
}
8283

8384
type fallbackTestCore struct {
84-
sync.Mutex
85+
mu *sync.RWMutex
8586
t TestingT
8687
fallback zapcore.Core
8788
testing zapcore.Core
88-
completed *atomic.Bool
89+
completed *bool
8990
}
9091

9192
var _ zapcore.Core = (*fallbackTestCore)(nil)
9293

9394
func (f *fallbackTestCore) UseFallback() {
94-
f.completed.Store(true)
95+
f.mu.Lock()
96+
defer f.mu.Unlock()
97+
*f.completed = true
9598
}
9699

97100
func (f *fallbackTestCore) Enabled(level zapcore.Level) bool {
98-
if f.completed.Load() {
101+
f.mu.RLock()
102+
defer f.mu.RUnlock()
103+
if f.completed != nil && *f.completed {
99104
return f.fallback.Enabled(level)
100105
}
101106
return f.testing.Enabled(level)
102107
}
103108

104109
func (f *fallbackTestCore) With(fields []zapcore.Field) zapcore.Core {
110+
f.mu.Lock()
111+
defer f.mu.Unlock()
112+
105113
// need to copy and defer, else the returned core will be used at an
106114
// arbitrarily later point in time, possibly after the test has completed.
107115
return &fallbackTestCore{
116+
mu: f.mu,
108117
t: f.t,
109118
fallback: f.fallback.With(fields),
110119
testing: f.testing.With(fields),
@@ -113,6 +122,8 @@ func (f *fallbackTestCore) With(fields []zapcore.Field) zapcore.Core {
113122
}
114123

115124
func (f *fallbackTestCore) Check(entry zapcore.Entry, checked *zapcore.CheckedEntry) *zapcore.CheckedEntry {
125+
f.mu.RLock()
126+
defer f.mu.RUnlock()
116127
// see other Check impls, all look similar.
117128
// this defers the "where to log" decision to Write, as `f` is the core that will write.
118129
if f.fallback.Enabled(entry.Level) {
@@ -122,7 +133,10 @@ func (f *fallbackTestCore) Check(entry zapcore.Entry, checked *zapcore.CheckedEn
122133
}
123134

124135
func (f *fallbackTestCore) Write(entry zapcore.Entry, fields []zapcore.Field) error {
125-
if f.completed.Load() {
136+
f.mu.RLock()
137+
defer f.mu.RUnlock()
138+
139+
if common.ValueFromPtr(f.completed) {
126140
entry.Message = fmt.Sprintf("COULD FAIL TEST %q, logged too late: %v", f.t.Name(), entry.Message)
127141

128142
hasStack := slices.ContainsFunc(fields, func(field zapcore.Field) bool {
@@ -134,14 +148,14 @@ func (f *fallbackTestCore) Write(entry zapcore.Entry, fields []zapcore.Field) er
134148
}
135149
return f.fallback.Write(entry, fields)
136150
}
137-
// Ensure no concurrent writes to the test logger.
138-
f.Lock()
139-
defer f.Unlock()
140151
return f.testing.Write(entry, fields)
141152
}
142153

143154
func (f *fallbackTestCore) Sync() error {
144-
if f.completed.Load() {
155+
f.mu.RLock()
156+
defer f.mu.RUnlock()
157+
158+
if common.ValueFromPtr(f.completed) {
145159
return f.fallback.Sync()
146160
}
147161
return f.testing.Sync()

internal/common/testlogger/testlogger_test.go

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,16 @@ package testlogger
2222

2323
import (
2424
"fmt"
25+
"go.uber.org/cadence/internal/common"
2526
"os"
27+
"sync"
2628
"testing"
2729
"time"
2830

2931
"go.uber.org/zap/zaptest"
3032

3133
"github.com/stretchr/testify/assert"
3234
"github.com/stretchr/testify/require"
33-
"go.uber.org/atomic"
3435
"go.uber.org/zap"
3536
)
3637

@@ -47,7 +48,7 @@ func TestMain(m *testing.M) {
4748
select {
4849
case <-logged:
4950
os.Exit(code)
50-
case <-time.After(time.Second): // should be MUCH faster
51+
case <-time.After(time.Millisecond): // should be MUCH faster
5152
_, _ = fmt.Fprintln(os.Stderr, "timed out waiting for test to log")
5253
os.Exit(1)
5354
}
@@ -131,10 +132,11 @@ func TestFallbackTestCore_Enabled(t *testing.T) {
131132
require.NoError(t, err)
132133

133134
core := &fallbackTestCore{
135+
mu: &sync.RWMutex{},
134136
t: t,
135137
fallback: fallbackLogger.Core(),
136138
testing: zaptest.NewLogger(t).Core(),
137-
completed: &atomic.Bool{},
139+
completed: common.PtrOf(false),
138140
}
139141
// Debug is enabled in zaptest.Logger
140142
assert.True(t, core.Enabled(zap.DebugLevel))
@@ -144,16 +146,11 @@ func TestFallbackTestCore_Enabled(t *testing.T) {
144146
}
145147

146148
func TestFallbackTestCore_Sync(t *testing.T) {
147-
148-
core := &fallbackTestCore{
149-
t: t,
150-
fallback: zaptest.NewLogger(t).Core(),
151-
testing: zaptest.NewLogger(t).Core(),
152-
completed: &atomic.Bool{},
153-
}
149+
core := NewZap(t).Core().(*fallbackTestCore)
150+
core.fallback = zap.NewNop().Core()
154151
// Sync for testing logger must not fail.
155152
assert.NoError(t, core.Sync(), "normal sync must not fail")
156153
core.UseFallback()
157154
// Sync for fallback logger must not fail.
158-
assert.NoError(t, core.Sync(), "fallback sync must not fail")
155+
assert.NoError(t, core.Sync(), "fallback sync must not fail")
159156
}

0 commit comments

Comments
 (0)