Skip to content

Commit 1797994

Browse files
committed
add support for comparing values in assert.Asserter with IsEqual method
1 parent a24ea1d commit 1797994

File tree

3 files changed

+208
-34
lines changed

3 files changed

+208
-34
lines changed

assert/Asserter.go

Lines changed: 76 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ package assert
33
import (
44
"errors"
55
"fmt"
6-
"github.com/adamluzsi/testcase/internal"
76
"reflect"
87
"strings"
98
"testing"
109

10+
"github.com/adamluzsi/testcase/internal"
11+
1112
"github.com/adamluzsi/testcase/internal/fmterror"
1213
)
1314

@@ -149,6 +150,14 @@ func (a Asserter) NotPanic(blk func(), msg ...interface{}) {
149150
})
150151
}
151152

153+
type equalable[T any] interface {
154+
IsEqual(oth T) bool
155+
}
156+
157+
type equalableWithError[T any] interface {
158+
IsEqual(oth T) (bool, error)
159+
}
160+
152161
func (a Asserter) Equal(expected, actually interface{}, msg ...interface{}) {
153162
a.TB.Helper()
154163
if a.eq(expected, actually) {
@@ -193,9 +202,50 @@ func (a Asserter) NotEqual(v, oth interface{}, msg ...interface{}) {
193202
}
194203

195204
func (a Asserter) eq(exp, act interface{}) bool {
205+
if isEqual, ok := a.tryIsEqual(exp, act); ok {
206+
return isEqual
207+
}
208+
196209
return reflect.DeepEqual(exp, act)
197210
}
198211

212+
func (a Asserter) tryIsEqual(exp, act interface{}) (isEqual bool, ok bool) {
213+
defer func() { recover() }()
214+
expRV := reflect.ValueOf(exp)
215+
actRV := reflect.ValueOf(act)
216+
217+
if expRV.Type() != actRV.Type() {
218+
return false, false
219+
}
220+
221+
method := expRV.MethodByName("IsEqual")
222+
methodType := method.Type()
223+
224+
if methodType.NumIn() != 1 {
225+
return false, false
226+
}
227+
if numOut := methodType.NumOut(); !(numOut == 1 || numOut == 2) {
228+
return false, false
229+
}
230+
if methodType.In(0) != actRV.Type() {
231+
return false, false
232+
}
233+
234+
res := method.Call([]reflect.Value{actRV})
235+
236+
switch {
237+
case methodType.NumOut() == 1: // IsEqual(T) (bool)
238+
return res[0].Bool(), true
239+
240+
case methodType.NumOut() == 2: // IsEqual(T) (bool, error)
241+
Must(a.TB).Nil(res[1].Interface())
242+
return res[0].Bool(), true
243+
244+
default:
245+
return false, false
246+
}
247+
}
248+
199249
func (a Asserter) Contain(src, has interface{}, msg ...interface{}) {
200250
a.TB.Helper()
201251
rSrc := reflect.ValueOf(src)
@@ -575,55 +625,49 @@ func (a Asserter) AnyOf(blk func(a *AnyOf), msg ...interface{}) {
575625
blk(anyOf)
576626
}
577627

578-
// Empty gets whether the specified value is considered empty.
579-
func (a Asserter) Empty(v interface{}, msg ...interface{}) {
580-
a.TB.Helper()
581-
582-
fail := func() {
583-
a.Fn(fmterror.Message{
584-
Method: "Empty",
585-
Cause: "Value was expected to be empty.",
586-
Values: []fmterror.Value{
587-
{Label: "value", Value: v},
588-
},
589-
UserMessage: msg,
590-
})
591-
}
592-
628+
func (a Asserter) isEmpty(v any) bool {
593629
if v == nil {
594-
return
630+
return true
595631
}
596632
rv := reflect.ValueOf(v)
597633
switch rv.Kind() {
598634
case reflect.Chan, reflect.Map, reflect.Slice:
599-
if rv.Len() != 0 {
600-
fail()
601-
}
635+
return rv.Len() == 0
636+
602637
case reflect.Array:
603638
zero := reflect.New(rv.Type()).Elem().Interface()
604-
if !a.eq(zero, v) {
605-
fail()
606-
}
639+
return a.eq(zero, v)
607640

608641
case reflect.Ptr:
609642
if rv.IsNil() {
610-
return
611-
}
612-
if !a.try(func(a Asserter) { a.Empty(rv.Elem().Interface()) }) {
613-
fail()
643+
return true
614644
}
645+
return a.isEmpty(rv.Elem().Interface())
615646

616647
default:
617-
if !a.eq(reflect.Zero(rv.Type()).Interface(), v) {
618-
fail()
619-
}
648+
return a.eq(reflect.Zero(rv.Type()).Interface(), v)
620649
}
621650
}
622651

652+
// Empty gets whether the specified value is considered empty.
653+
func (a Asserter) Empty(v interface{}, msg ...interface{}) {
654+
a.TB.Helper()
655+
if a.isEmpty(v) {
656+
return
657+
}
658+
a.Fn(fmterror.Message{
659+
Method: "Empty",
660+
Cause: "Value was expected to be empty.",
661+
Values: []fmterror.Value{
662+
{Label: "value", Value: v},
663+
},
664+
UserMessage: msg,
665+
})
666+
}
667+
623668
// NotEmpty gets whether the specified value is considered empty.
624669
func (a Asserter) NotEmpty(v interface{}, msg ...interface{}) {
625670
a.TB.Helper()
626-
627671
if !a.try(func(a Asserter) { a.Empty(v) }) {
628672
return
629673
}
@@ -639,7 +683,7 @@ func (a Asserter) NotEmpty(v interface{}, msg ...interface{}) {
639683

640684
// ErrorIs allows you to assert an error value by an expectation.
641685
// if the implementation of the test subject later changes, and for example, it starts to use wrapping,
642-
// this should not be an issue as the err's error chain is also matched against the expectation.
686+
// this should not be an issue as the IsEqualErr's error chain is also matched against the expectation.
643687
func (a Asserter) ErrorIs(expected, actual error, msg ...interface{}) {
644688
a.TB.Helper()
645689

assert/Asserter_test.go

Lines changed: 80 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@ package assert_test
33
import (
44
"errors"
55
"fmt"
6-
"github.com/adamluzsi/testcase"
7-
"github.com/adamluzsi/testcase/internal"
86
"reflect"
97
"strings"
108
"testing"
119

10+
"github.com/adamluzsi/testcase"
11+
"github.com/adamluzsi/testcase/internal"
12+
1213
"github.com/adamluzsi/testcase/assert"
1314
"github.com/adamluzsi/testcase/random"
1415
)
@@ -314,6 +315,54 @@ func TestAsserter_Equal(t *testing.T) {
314315
Actual: []byte("foo"),
315316
IsFailed: true,
316317
},
318+
{
319+
Desc: "when value implements equalable and the two value is equal by IsEqual",
320+
Expected: ExampleEqualable{
321+
relevantUnexportedValue: 42,
322+
IrrelevantExportedField: 42,
323+
},
324+
Actual: ExampleEqualable{
325+
relevantUnexportedValue: 42,
326+
IrrelevantExportedField: 24,
327+
},
328+
IsFailed: false,
329+
},
330+
{
331+
Desc: "when value implements equalable and the two value is not equal by IsEqual",
332+
Expected: ExampleEqualable{
333+
relevantUnexportedValue: 24,
334+
IrrelevantExportedField: 42,
335+
},
336+
Actual: ExampleEqualable{
337+
relevantUnexportedValue: 42,
338+
IrrelevantExportedField: 42,
339+
},
340+
IsFailed: true,
341+
},
342+
{
343+
Desc: "when value implements equalableWithError and the two value is equal by IsEqual",
344+
Expected: ExampleEqualableWithError{
345+
relevantUnexportedValue: 42,
346+
IrrelevantExportedField: 42,
347+
},
348+
Actual: ExampleEqualableWithError{
349+
relevantUnexportedValue: 42,
350+
IrrelevantExportedField: 4242,
351+
},
352+
IsFailed: false,
353+
},
354+
{
355+
Desc: "when value implements equalableWithError and the two value is not equal by IsEqual",
356+
Expected: ExampleEqualableWithError{
357+
relevantUnexportedValue: 42,
358+
IrrelevantExportedField: 42,
359+
},
360+
Actual: ExampleEqualableWithError{
361+
relevantUnexportedValue: 4242,
362+
IrrelevantExportedField: 42,
363+
},
364+
IsFailed: true,
365+
},
317366
//{
318367
// Desc: "when equal function provided",
319368
// Expected: fn1,
@@ -352,6 +401,35 @@ func TestAsserter_Equal(t *testing.T) {
352401
}
353402
}
354403

404+
func TestAsserter_Equal_equalableWithError_ErrorReturned(t *testing.T) {
405+
t.Log("when value implements equalableWithError and IsEqual returns an error")
406+
407+
expected := ExampleEqualableWithError{
408+
relevantUnexportedValue: 42,
409+
IrrelevantExportedField: 42,
410+
IsEqualErr: errors.New("boom"),
411+
}
412+
413+
actual := ExampleEqualableWithError{
414+
relevantUnexportedValue: 42,
415+
IrrelevantExportedField: 42,
416+
}
417+
418+
stub := &testcase.StubTB{}
419+
420+
internal.Recover(func() {
421+
a := assert.Asserter{
422+
TB: stub,
423+
Fn: stub.Fatal,
424+
}
425+
426+
a.Equal(expected, actual)
427+
})
428+
if !stub.IsFailed {
429+
t.Fatal("expected that testing.TB is failed because the returned error")
430+
}
431+
}
432+
355433
func TestAsserter_NotEqual(t *testing.T) {
356434
type TestCase struct {
357435
Desc string

assert/example_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,55 @@ func ExampleAsserter_ErrorIs() {
248248
assert.Must(tb).ErrorIs(errors.New("boom"), actualErr) // passes for equality
249249
assert.Must(tb).ErrorIs(errors.New("boom"), fmt.Errorf("wrapped error: %w", actualErr)) // passes for wrapped errors
250250
}
251+
252+
type ExampleEqualable struct {
253+
IrrelevantExportedField int
254+
relevantUnexportedValue int
255+
}
256+
257+
func (es ExampleEqualable) IsEqual(oth ExampleEqualable) bool {
258+
return es.relevantUnexportedValue == oth.relevantUnexportedValue
259+
}
260+
261+
func ExampleAsserter_Equal_isEqualFunctionUsedForComparison() {
262+
var tb testing.TB
263+
264+
expected := ExampleEqualable{
265+
IrrelevantExportedField: 42,
266+
relevantUnexportedValue: 24,
267+
}
268+
269+
actual := ExampleEqualable{
270+
IrrelevantExportedField: 4242,
271+
relevantUnexportedValue: 24,
272+
}
273+
274+
assert.Must(tb).Equal(expected, actual) // passes as by IsEqual terms the two value is equal
275+
}
276+
277+
type ExampleEqualableWithError struct {
278+
IrrelevantExportedField int
279+
relevantUnexportedValue int
280+
IsEqualErr error
281+
}
282+
283+
func (es ExampleEqualableWithError) IsEqual(oth ExampleEqualableWithError) (bool, error) {
284+
return es.relevantUnexportedValue == oth.relevantUnexportedValue, es.IsEqualErr
285+
}
286+
287+
func ExampleAsserter_Equal_isEqualFunctionThatSupportsErrorReturning() {
288+
var tb testing.TB
289+
290+
expected := ExampleEqualableWithError{
291+
IrrelevantExportedField: 42,
292+
relevantUnexportedValue: 24,
293+
IsEqualErr: errors.New("sadly something went wrong"),
294+
}
295+
296+
actual := ExampleEqualableWithError{
297+
IrrelevantExportedField: 42,
298+
relevantUnexportedValue: 24,
299+
}
300+
301+
assert.Must(tb).Equal(expected, actual) // fails because the error returned from the IsEqual function.
302+
}

0 commit comments

Comments
 (0)