Skip to content

Commit c9999ae

Browse files
committed
tests: support float approximation in roachtest query comparison utils
Before this change unoptimized query oracle tests would compare results using simple string comparison. However, due to floating point precision limitations, it's possible for results with floating point to diverge during the course of normal computation. This results in test failures that are difficult to reproduce or determine whether they are expected behavior. This change utilizes existing floating point comparison functions used by logic tests to match float values only to a specific precision. Like the logic tests, we also have special handling for floats and decimals under the s390x architecture (see cockroachdb#63244). In order to avoid costly comparisons, we only check floating point precision if the naiive string comparison approach fails and there are float or decimal types in the result. Epic: None Fixes: cockroachdb#95665 Release note: None
1 parent cdb8a43 commit c9999ae

File tree

6 files changed

+251
-18
lines changed

6 files changed

+251
-18
lines changed

pkg/cmd/roachtest/tests/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ go_test(
288288
srcs = [
289289
"blocklist_test.go",
290290
"drt_test.go",
291+
"query_comparison_util_test.go",
291292
"tpcc_test.go",
292293
"util_load_group_test.go",
293294
":mocks_drt", # keep

pkg/cmd/roachtest/tests/costfuzz.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,11 @@ func runCostFuzzQuery(qgen queryGenerator, rnd *rand.Rand, h queryComparisonHelp
114114
return nil
115115
}
116116

117-
if diff := unsortedMatricesDiff(controlRows, perturbRows); diff != "" {
117+
diff, err := unsortedMatricesDiffWithFloatComp(controlRows, perturbRows, h.colTypes)
118+
if err != nil {
119+
return err
120+
}
121+
if diff != "" {
118122
// We have a mismatch in the perturbed vs control query outputs.
119123
h.logStatements()
120124
h.logVerboseOutput()

pkg/cmd/roachtest/tests/query_comparison_util.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,7 @@ type queryComparisonHelper struct {
428428

429429
statements []string
430430
statementsAndExplains []sqlAndOutput
431+
colTypes []string
431432
}
432433

433434
// runQuery runs the given query and returns the output. If the stmt doesn't
@@ -452,6 +453,14 @@ func (h *queryComparisonHelper) runQuery(stmt string) ([][]string, error) {
452453
return nil, err
453454
}
454455
defer rows.Close()
456+
cts, err := rows.ColumnTypes()
457+
if err != nil {
458+
return nil, err
459+
}
460+
h.colTypes = make([]string, len(cts))
461+
for i, ct := range cts {
462+
h.colTypes[i] = ct.DatabaseTypeName()
463+
}
455464
return sqlutils.RowsToStrMatrix(rows)
456465
}
457466

@@ -509,6 +518,95 @@ func (h *queryComparisonHelper) makeError(err error, msg string) error {
509518
return errors.Wrapf(err, "%s. %d statements run", msg, h.stmtNo)
510519
}
511520

521+
func joinAndSortRows(rowMatrix1, rowMatrix2 [][]string, sep string) (rows1, rows2 []string) {
522+
for _, row := range rowMatrix1 {
523+
rows1 = append(rows1, strings.Join(row[:], sep))
524+
}
525+
for _, row := range rowMatrix2 {
526+
rows2 = append(rows2, strings.Join(row[:], sep))
527+
}
528+
sort.Strings(rows1)
529+
sort.Strings(rows2)
530+
return rows1, rows2
531+
}
532+
533+
// unsortedMatricesDiffWithFloatComp sorts and compares the rows in rowMatrix1
534+
// to rowMatrix2 and outputs a diff or message related to the comparison. If a
535+
// string comparison of the rows fails, and they contain floats or decimals, it
536+
// performs an approximate comparison of the values.
537+
func unsortedMatricesDiffWithFloatComp(
538+
rowMatrix1, rowMatrix2 [][]string, colTypes []string,
539+
) (string, error) {
540+
rows1, rows2 := joinAndSortRows(rowMatrix1, rowMatrix2, ",")
541+
result := cmp.Diff(rows1, rows2)
542+
if result == "" {
543+
return result, nil
544+
}
545+
if len(rows1) != len(rows2) || len(colTypes) != len(rowMatrix1[0]) || len(colTypes) != len(rowMatrix2[0]) {
546+
return result, nil
547+
}
548+
var needApproxMatch bool
549+
for i := range colTypes {
550+
// On s390x, check that values for both float and decimal coltypes are
551+
// approximately equal to take into account platform differences in floating
552+
// point calculations. On other architectures, check float values only.
553+
if (runtime.GOARCH == "s390x" && colTypes[i] == "DECIMAL") ||
554+
colTypes[i] == "FLOAT4" || colTypes[i] == "FLOAT8" {
555+
needApproxMatch = true
556+
break
557+
}
558+
}
559+
if !needApproxMatch {
560+
return result, nil
561+
}
562+
// Use an unlikely string as a separator so that we can make a comparison
563+
// using sorted rows. We don't use the rows sorted above because splitting
564+
// the rows could be ambiguous.
565+
sep := ",unsortedMatricesDiffWithFloatComp separator,"
566+
rows1, rows2 = joinAndSortRows(rowMatrix1, rowMatrix2, sep)
567+
for i := range rows1 {
568+
// Split the sorted rows.
569+
row1 := strings.Split(rows1[i], sep)
570+
row2 := strings.Split(rows2[i], sep)
571+
572+
for j := range row1 {
573+
if runtime.GOARCH == "s390x" && colTypes[j] == "DECIMAL" {
574+
// On s390x, check that values for both float and decimal coltypes are
575+
// approximately equal to take into account platform differences in floating
576+
// point calculations. On other architectures, check float values only.
577+
match, err := floatcmp.FloatsMatchApprox(row1[j], row2[j])
578+
if err != nil {
579+
return "", err
580+
}
581+
if !match {
582+
return result, nil
583+
}
584+
} else if colTypes[j] == "FLOAT4" || colTypes[j] == "FLOAT8" {
585+
// Check that float values are approximately equal.
586+
var err error
587+
var match bool
588+
if runtime.GOARCH == "s390x" {
589+
match, err = floatcmp.FloatsMatchApprox(row1[j], row2[j])
590+
} else {
591+
match, err = floatcmp.FloatsMatch(row1[j], row2[j])
592+
}
593+
if err != nil {
594+
return "", err
595+
}
596+
if !match {
597+
return result, nil
598+
}
599+
} else {
600+
// Check that other columns are equal with a string comparison.
601+
if row1[j] != row2[j] {
602+
return result, nil
603+
}
604+
}
605+
}
606+
}
607+
return "", nil
608+
}
609+
512610
// unsortedMatricesDiff sorts and compares rows of data.
513611
func unsortedMatricesDiff(rowMatrix1, rowMatrix2 [][]string) string {
514612
var rows1 []string
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright 2023 The Cockroach Authors.
2+
//
3+
// Use of this software is governed by the Business Source License
4+
// included in the file licenses/BSL.txt.
5+
//
6+
// As of the Change Date specified in that file, in accordance with
7+
// the Business Source License, use of this software will be governed
8+
// by the Apache License, Version 2.0, included in the file
9+
// licenses/APL.txt.
10+
11+
package tests
12+
13+
import (
14+
"testing"
15+
16+
"github.com/cockroachdb/cockroach/pkg/util/leaktest"
17+
)
18+
19+
// TestUnsortedMatricesDiff is a unit test for the
20+
// unsortedMatricesDiffWithFloatComp() and unsortedMatricesDiff() utility
21+
// functions.
22+
func TestUnsortedMatricesDiff(t *testing.T) {
23+
defer leaktest.AfterTest(t)()
24+
tcs := []struct {
25+
name string
26+
colTypes []string
27+
t1, t2 [][]string
28+
exactMatch bool
29+
approxMatch bool
30+
}{
31+
{
32+
name: "float exact match",
33+
colTypes: []string{"FLOAT8"},
34+
t1: [][]string{{"1.2345678901234567"}},
35+
t2: [][]string{{"1.2345678901234567"}},
36+
exactMatch: true,
37+
},
38+
{
39+
name: "float approx match",
40+
colTypes: []string{"FLOAT8"},
41+
t1: [][]string{{"1.2345678901234563"}},
42+
t2: [][]string{{"1.2345678901234564"}},
43+
exactMatch: false,
44+
approxMatch: true,
45+
},
46+
{
47+
name: "float no match",
48+
colTypes: []string{"FLOAT8"},
49+
t1: [][]string{{"1.234567890123"}},
50+
t2: [][]string{{"1.234567890124"}},
51+
exactMatch: false,
52+
approxMatch: false,
53+
},
54+
{
55+
name: "multi float approx match",
56+
colTypes: []string{"FLOAT8", "FLOAT8"},
57+
t1: [][]string{{"1.2345678901234567", "1.2345678901234567"}},
58+
t2: [][]string{{"1.2345678901234567", "1.2345678901234568"}},
59+
exactMatch: false,
60+
approxMatch: true,
61+
},
62+
{
63+
name: "string no match",
64+
colTypes: []string{"STRING"},
65+
t1: [][]string{{"hello"}},
66+
t2: [][]string{{"world"}},
67+
exactMatch: false,
68+
approxMatch: false,
69+
},
70+
{
71+
name: "mixed types match",
72+
colTypes: []string{"STRING", "FLOAT8"},
73+
t1: [][]string{{"hello", "1.2345678901234567"}},
74+
t2: [][]string{{"hello", "1.2345678901234567"}},
75+
exactMatch: true,
76+
},
77+
{
78+
name: "mixed types float approx match",
79+
colTypes: []string{"STRING", "FLOAT8"},
80+
t1: [][]string{{"hello", "1.23456789012345678"}},
81+
t2: [][]string{{"hello", "1.23456789012345679"}},
82+
exactMatch: false,
83+
approxMatch: true,
84+
},
85+
{
86+
name: "mixed types no match",
87+
colTypes: []string{"STRING", "FLOAT8"},
88+
t1: [][]string{{"hello", "1.2345678901234567"}},
89+
t2: [][]string{{"world", "1.2345678901234567"}},
90+
exactMatch: false,
91+
approxMatch: false,
92+
},
93+
{
94+
name: "different col count",
95+
colTypes: []string{"STRING"},
96+
t1: [][]string{{"hello", "1.2345678901234567"}},
97+
t2: [][]string{{"world", "1.2345678901234567"}},
98+
exactMatch: false,
99+
approxMatch: false,
100+
},
101+
{
102+
name: "different row count",
103+
colTypes: []string{"STRING", "FLOAT8"},
104+
t1: [][]string{{"hello", "1.2345678901234567"}, {"aloha", "2.345"}},
105+
t2: [][]string{{"world", "1.2345678901234567"}},
106+
exactMatch: false,
107+
approxMatch: false,
108+
},
109+
{
110+
name: "multi row unsorted",
111+
colTypes: []string{"STRING", "FLOAT8"},
112+
t1: [][]string{{"hello", "1.2345678901234567"}, {"world", "1.2345678901234560"}},
113+
t2: [][]string{{"world", "1.2345678901234560"}, {"hello", "1.2345678901234567"}},
114+
exactMatch: true,
115+
},
116+
}
117+
for _, tc := range tcs {
118+
t.Run(tc.name, func(t *testing.T) {
119+
match := unsortedMatricesDiff(tc.t1, tc.t2)
120+
if tc.exactMatch && match != "" {
121+
t.Fatalf("unsortedMatricesDiff: expected exact match, got diff: %s", match)
122+
} else if !tc.exactMatch && match == "" {
123+
t.Fatalf("unsortedMatricesDiff: expected no exact match, got no diff")
124+
}
125+
126+
var err error
127+
match, err = unsortedMatricesDiffWithFloatComp(tc.t1, tc.t2, tc.colTypes)
128+
if err != nil {
129+
t.Fatal(err)
130+
}
131+
if tc.exactMatch && match != "" {
132+
t.Fatalf("unsortedMatricesDiffWithFloatComp: expected exact match, got diff: %s", match)
133+
} else if !tc.exactMatch && tc.approxMatch && match != "" {
134+
t.Fatalf("unsortedMatricesDiffWithFloatComp: expected approx match, got diff: %s", match)
135+
} else if !tc.exactMatch && !tc.approxMatch && match == "" {
136+
t.Fatalf("unsortedMatricesDiffWithFloatComp: expected no approx match, got no diff")
137+
}
138+
})
139+
}
140+
}

pkg/cmd/roachtest/tests/unoptimized_query_oracle.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,11 @@ func runUnoptimizedQueryOracleImpl(
171171
//nolint:returnerrcheck
172172
return nil
173173
}
174-
if diff := unsortedMatricesDiff(unoptimizedRows, optimizedRows); diff != "" {
174+
diff, err := unsortedMatricesDiffWithFloatComp(unoptimizedRows, optimizedRows, h.colTypes)
175+
if err != nil {
176+
return err
177+
}
178+
if diff != "" {
175179
// We have a mismatch in the unoptimized vs optimized query outputs.
176180
verboseLogging = true
177181
return h.makeError(errors.Newf(

pkg/testutils/floatcmp/floatcmp.go

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -121,23 +121,9 @@ func FloatsMatch(expectedString, actualString string) (bool, error) {
121121
actual = math.Abs(actual)
122122
// Check that 15 significant digits match. We do so by normalizing the
123123
// numbers and then checking one digit at a time.
124-
//
125-
// normalize converts f to base * 10**power representation where base is in
126-
// [1.0, 10.0) range.
127-
normalize := func(f float64) (base float64, power int) {
128-
for f >= 10 {
129-
f = f / 10
130-
power++
131-
}
132-
for f < 1 {
133-
f *= 10
134-
power--
135-
}
136-
return f, power
137-
}
138124
var expPower, actPower int
139-
expected, expPower = normalize(expected)
140-
actual, actPower = normalize(actual)
125+
expected, expPower = math.Frexp(expected)
126+
actual, actPower = math.Frexp(actual)
141127
if expPower != actPower {
142128
return false, nil
143129
}

0 commit comments

Comments
 (0)