Skip to content

Commit 1867e61

Browse files
committed
Allow tests that inspect the encoding of a wrapped value without unwrapping it.
1 parent ba2722f commit 1867e61

File tree

7 files changed

+78
-19
lines changed

7 files changed

+78
-19
lines changed

enginetest/enginetests.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ func TestBrokenQueries(t *testing.T, harness Harness) {
263263
// queries during debugging.
264264
func RunQueryTests(t *testing.T, harness Harness, queries []queries.QueryTest) {
265265
for _, tt := range queries {
266-
TestQuery(t, harness, tt.Query, tt.Expected, tt.ExpectedColumns, nil)
266+
testQuery(t, harness, tt.Query, tt.Expected, tt.ExpectedColumns, nil, !tt.DontUnwrap)
267267
}
268268
}
269269

enginetest/evaluation.go

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -326,6 +326,10 @@ func TestTransactionScriptWithEngine(t *testing.T, e QueryEngine, harness Harnes
326326
// TestQuery runs a query on the engine given and asserts that results are as expected.
327327
// TODO: this should take en engine
328328
func TestQuery(t *testing.T, harness Harness, q string, expected []sql.Row, expectedCols []*sql.Column, bindings map[string]sqlparser.Expr) {
329+
testQuery(t, harness, q, expected, expectedCols, bindings, true)
330+
}
331+
332+
func testQuery(t *testing.T, harness Harness, q string, expected []sql.Row, expectedCols []*sql.Column, bindings map[string]sqlparser.Expr, unwrapValues bool) {
329333
t.Run(q, func(t *testing.T) {
330334
if sh, ok := harness.(SkippingHarness); ok {
331335
if sh.SkipQueryTest(q) {
@@ -336,7 +340,7 @@ func TestQuery(t *testing.T, harness Harness, q string, expected []sql.Row, expe
336340
e := mustNewEngine(t, harness)
337341
defer e.Close()
338342
ctx := NewContext(harness)
339-
TestQueryWithContext(t, ctx, e, harness, q, expected, expectedCols, bindings, nil)
343+
testQueryWithContext(t, ctx, e, harness, q, expected, expectedCols, bindings, nil, unwrapValues)
340344
})
341345
}
342346

@@ -378,6 +382,21 @@ func TestQueryWithContext(
378382
expectedCols []*sql.Column,
379383
bindings map[string]sqlparser.Expr,
380384
qFlags *sql.QueryFlags,
385+
) {
386+
testQueryWithContext(t, ctx, e, harness, q, expected, expectedCols, bindings, qFlags, true)
387+
}
388+
389+
func testQueryWithContext(
390+
t *testing.T,
391+
ctx *sql.Context,
392+
e QueryEngine,
393+
harness Harness,
394+
q string,
395+
expected []sql.Row,
396+
expectedCols []*sql.Column,
397+
bindings map[string]sqlparser.Expr,
398+
qFlags *sql.QueryFlags,
399+
unwrapValues bool,
381400
) {
382401
ctx = ctx.WithQuery(q)
383402
require := require.New(t)
@@ -393,7 +412,7 @@ func TestQueryWithContext(
393412
require.NoError(err, "Unexpected error for query %s: %s", q, err)
394413

395414
if expected != nil {
396-
CheckResults(t, harness, expected, expectedCols, sch, rows, q, e)
415+
checkResults(t, harness, expected, expectedCols, sch, rows, q, e, unwrapValues)
397416
}
398417

399418
require.Equal(
@@ -502,7 +521,6 @@ func TestPreparedQueryWithContext(t *testing.T, ctx *sql.Context, e QueryEngine,
502521
validateEngine(t, ctx, h, e)
503522
}
504523

505-
// CheckResults compares the
506524
func CheckResults(
507525
t *testing.T,
508526
h Harness,
@@ -512,11 +530,25 @@ func CheckResults(
512530
rows []sql.Row,
513531
q string,
514532
e QueryEngine,
533+
) {
534+
checkResults(t, h, expected, expectedCols, sch, rows, q, e, true)
535+
}
536+
537+
func checkResults(
538+
t *testing.T,
539+
h Harness,
540+
expected []sql.Row,
541+
expectedCols []*sql.Column,
542+
sch sql.Schema,
543+
rows []sql.Row,
544+
q string,
545+
e QueryEngine,
546+
unwrapValues bool,
515547
) {
516548
if reh, ok := h.(ResultEvaluationHarness); ok {
517-
reh.EvaluateQueryResults(t, expected, expectedCols, sch, rows, q)
549+
reh.EvaluateQueryResults(t, expected, expectedCols, sch, rows, q, unwrapValues)
518550
} else {
519-
checkResults(t, expected, expectedCols, sch, rows, q, e)
551+
checkResultsDefault(t, expected, expectedCols, sch, rows, q, e, unwrapValues)
520552
}
521553
}
522554

@@ -669,16 +701,19 @@ func toSQL(c *sql.Column, expected any, isZeroTime bool) (any, error) {
669701
}
670702
}
671703

672-
// checkResults is the default implementation for checking the results of a test query assertion for harnesses that
704+
// checkResultsDefault is the default implementation for checking the results of a test query assertion for harnesses that
673705
// don't implement ResultEvaluationHarness. All numerical values are widened to their widest type before comparison.
674-
func checkResults(
706+
// Based on the value of |unwrapValues|, this either normalized wrapped values by unwrapping them, or replaces them
707+
// with their hash so the test caller can assert that the values are wrapped and have a certain hash.
708+
func checkResultsDefault(
675709
t *testing.T,
676710
expected []sql.Row,
677711
expectedCols []*sql.Column,
678712
sch sql.Schema,
679713
rows []sql.Row,
680714
q string,
681715
e QueryEngine,
716+
unwrapValues bool,
682717
) {
683718
widenedRows := WidenRows(t, sch, rows)
684719
widenedExpected := WidenRows(t, sch, expected)
@@ -715,6 +750,14 @@ func checkResults(
715750
widenedRow[i] = el
716751
}
717752
}
753+
case sql.AnyWrapper:
754+
if unwrapValues {
755+
var err error
756+
widenedRow[i], err = sql.UnwrapAny(context.Background(), v)
757+
require.NoError(t, err)
758+
} else {
759+
widenedRow[i] = v.Hash()
760+
}
718761
}
719762
}
720763
}
@@ -825,7 +868,6 @@ func WidenRow(t *testing.T, sch sql.Schema, row sql.Row) sql.Row {
825868

826869
// widenValue normalizes the input by widening it to its widest type and unwrapping any wrappers.
827870
func widenValue(t *testing.T, v interface{}) (vw interface{}) {
828-
var err error
829871
switch x := v.(type) {
830872
case int:
831873
vw = int64(x)
@@ -850,10 +892,6 @@ func widenValue(t *testing.T, v interface{}) (vw interface{}) {
850892
// The exact expected decimal type value cannot be defined in enginetests,
851893
// so convert the result to string format, which is the value we get on sql shell.
852894
vw = x.StringFixed(x.Exponent() * -1)
853-
case sql.AnyWrapper:
854-
vw, err = x.UnwrapAny(context.Background())
855-
vw = widenValue(t, vw)
856-
require.NoError(t, err)
857895
default:
858896
vw = v
859897
}
@@ -1160,12 +1198,12 @@ func RunWriteQueryTestWithEngine(t *testing.T, harness Harness, e QueryEngine, t
11601198
}
11611199

11621200
ctx := NewContext(harness)
1163-
TestQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil, nil)
1201+
testQueryWithContext(t, ctx, e, harness, tt.WriteQuery, tt.ExpectedWriteResult, nil, nil, nil, !tt.DontUnwrap)
11641202
expectedSelect := tt.ExpectedSelect
11651203
if IsServerEngine(e) && tt.SkipServerEngine {
11661204
expectedSelect = nil
11671205
}
1168-
TestQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil, nil)
1206+
testQueryWithContext(t, ctx, e, harness, tt.SelectQuery, expectedSelect, nil, nil, nil, !tt.DontUnwrap)
11691207
}
11701208

11711209
func supportedDialect(harness Harness, dialect string) bool {

enginetest/harness.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,7 @@ type ResultEvaluationHarness interface {
162162
expectdSch sql.Schema,
163163
actualRows []sql.Row,
164164
query string,
165+
unwrapValues bool,
165166
)
166167

167168
// EvaluateExpectedError compares expected error strings to actual errors and emits failed test assertions in the

enginetest/queries/queries.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ type QueryTest struct {
4747
// Dialect is the supported dialect for this query, which must match the dialect of the harness if specified.
4848
// The query is skipped if the dialect doesn't match.
4949
Dialect string
50+
// DontUnwrap indicates whether to skip normalizing the select results via unwrapping wrapped values.
51+
// Instead, the test engine will replace wrapped values with their hash (as determined by sql.AnyWrapped.Hash).
52+
// Set this to test the exact encodings being returned by the query.
53+
DontUnwrap bool
5054
}
5155

5256
type QueryPlanTest struct {
@@ -11382,6 +11386,10 @@ type WriteQueryTest struct {
1138211386
// Dialect is the supported dialect for this test, which must match the dialect of the harness if specified.
1138311387
// The script is skipped if the dialect doesn't match.
1138411388
Dialect string
11389+
// DontUnwrap indicates whether to skip normalizing the select results via unwrapping wrapped values.
11390+
// Instead, the test engine will replace wrapped values with their hash (as determined by sql.AnyWrapped.Hash).
11391+
// Set this to test the exact encodings being returned by the query.
11392+
DontUnwrap bool
1138511393
}
1138611394

1138711395
// GenericErrorQueryTest is a query test that is used to assert an error occurs for some query, without specifying what

enginetest/wrapper_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ func (w ErrorWrapper[T]) IsExactLength() bool {
6666
return w.isExactLength
6767
}
6868

69+
func (w ErrorWrapper[T]) Hash() interface{} {
70+
return nil
71+
}
72+
6973
func setupWrapperTests(t *testing.T) (*sql.Context, *memory.Database, *MemoryHarness, *sqle.Engine) {
7074
db := memory.NewDatabase("mydb")
7175
pro := memory.NewDBProvider(db)

memory/wrapper_test.go

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ func (w SimpleWrapper[T]) IsExactLength() bool {
9797
return w.isExactLength
9898
}
9999

100+
func (w SimpleWrapper[T]) Hash() interface{} {
101+
return nil
102+
}
103+
100104
// ErrorWrapper is a wrapped type that errors when unwrapped. This can be used to test that certain operations
101105
// won't trigger an unwrap.
102106
type ErrorWrapper[T any] struct {
@@ -111,10 +115,6 @@ func (w ErrorWrapper[T]) Compare(ctx context.Context, other interface{}) (cmp in
111115
var textErrorWrapper = ErrorWrapper[string]{maxByteLength: types.Text.MaxByteLength(), isExactLength: false}
112116
var longTextErrorWrapper = ErrorWrapper[string]{maxByteLength: types.LongText.MaxByteLength(), isExactLength: false}
113117

114-
func exactLengthErrorWrapper(maxByteLength int64) ErrorWrapper[string] {
115-
return ErrorWrapper[string]{maxByteLength: maxByteLength, isExactLength: true}
116-
}
117-
118118
func (w ErrorWrapper[T]) assertInterfaces() {
119119
var _ sql.Wrapper[T] = w
120120
}
@@ -135,6 +135,10 @@ func (w ErrorWrapper[T]) IsExactLength() bool {
135135
return w.isExactLength
136136
}
137137

138+
func (w ErrorWrapper[T]) Hash() interface{} {
139+
return nil
140+
}
141+
138142
// TestWrapperCompare tests that a wrapped value can be used in comparisons.
139143
func TestWrapperCompare(t *testing.T) {
140144
db := memory.NewDatabase("db")

sql/wrapper.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ type AnyWrapper interface {
4646
// Setting |comparable| to true means that the wrapper was able to compare them and store the result in |cmp|.
4747
// Setting |comparable| to false means that the wrapper wasn't able to compare them.
4848
Compare(ctx context.Context, other interface{}) (cmp int, comparable bool, err error)
49+
50+
// Hash is a value that can be compared to check if two wrapper values are equal. Equality of the hashes implies
51+
// equality of the wrappers.
52+
Hash() interface{}
4953
}
5054

5155
// Wrapper is an interface for types that encapsulate a SQL value of a specific type.

0 commit comments

Comments
 (0)