Skip to content

Commit c84a300

Browse files
authored
Merge pull request #879 from dolthub/fulghum/nextval_regclass
Allow `nextval()` to take a regclass instance
2 parents d536220 + f964380 commit c84a300

File tree

10 files changed

+209
-30
lines changed

10 files changed

+209
-30
lines changed

server/functions/current_schema.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/dolthub/go-mysql-server/sql"
1919

2020
"github.com/dolthub/doltgresql/server/functions/framework"
21+
"github.com/dolthub/doltgresql/server/settings"
2122
pgtypes "github.com/dolthub/doltgresql/server/types"
2223
)
2324

@@ -33,7 +34,7 @@ var current_schema = framework.Function0{
3334
IsNonDeterministic: true,
3435
Strict: true,
3536
Callable: func(ctx *sql.Context) (any, error) {
36-
schemas, err := GetCurrentSchemas(ctx)
37+
schemas, err := settings.GetCurrentSchemas(ctx)
3738
if err != nil {
3839
return nil, err
3940
}

server/functions/current_schemas.go

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
package functions
1616

1717
import (
18-
"strings"
19-
2018
"github.com/dolthub/go-mysql-server/sql"
2119

2220
"github.com/dolthub/doltgresql/postgres/parser/sessiondata"
2321
"github.com/dolthub/doltgresql/server/functions/framework"
22+
"github.com/dolthub/doltgresql/server/settings"
2423
pgtypes "github.com/dolthub/doltgresql/server/types"
2524
)
2625

@@ -41,7 +40,7 @@ var current_schemas = framework.Function1{
4140
if val.(bool) {
4241
schemas = append(schemas, sessiondata.PgCatalogName)
4342
}
44-
searchPaths, err := GetCurrentSchemas(ctx)
43+
searchPaths, err := settings.GetCurrentSchemas(ctx)
4544
if err != nil {
4645
return nil, err
4746
}
@@ -51,24 +50,3 @@ var current_schemas = framework.Function1{
5150
return schemas, nil
5251
},
5352
}
54-
55-
// GetCurrentSchemas returns all the schemas in the search_path setting, with elements like "$user" excluded
56-
func GetCurrentSchemas(ctx *sql.Context) ([]string, error) {
57-
searchPathVar, err := ctx.GetSessionVariable(ctx, "search_path")
58-
if err != nil {
59-
return nil, err
60-
}
61-
62-
pathElems := strings.Split(searchPathVar.(string), ",")
63-
var path []string
64-
65-
for _, schemaName := range pathElems {
66-
schemaName = strings.Trim(schemaName, " ")
67-
if schemaName == "\"$user\"" {
68-
continue
69-
}
70-
path = append(path, schemaName)
71-
}
72-
73-
return path, nil
74-
}

server/functions/nextval.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,15 @@ import (
2525
// initNextVal registers the functions to the catalog.
2626
func initNextVal() {
2727
framework.RegisterFunction(nextval_text)
28+
framework.RegisterFunction(nextval_regclass)
2829
}
2930

3031
// nextval_text represents the PostgreSQL function of the same name, taking the same parameters.
32+
//
33+
// TODO: Even though we can implicitly convert a text param to a regclass param, it's an expensive process
34+
// to convert it to a regclass, then convert the regclass back into the relation name, so we provide an overload
35+
// that takes a text param directly, in addition to the function form that takes a regclass. Once we can optimize
36+
// the regclass to text conversion, we can potentially remove this overload.
3137
var nextval_text = framework.Function1{
3238
Name: "nextval",
3339
Return: pgtypes.Int64,
@@ -47,3 +53,29 @@ var nextval_text = framework.Function1{
4753
return collection.NextVal(schema, sequence)
4854
},
4955
}
56+
57+
// nextval_regclass represents the PostgreSQL function of the same name, taking the same parameters.
58+
var nextval_regclass = framework.Function1{
59+
Name: "nextval",
60+
Return: pgtypes.Int64,
61+
Parameters: [1]pgtypes.DoltgresType{pgtypes.Regclass},
62+
IsNonDeterministic: true,
63+
Strict: true,
64+
Callable: func(ctx *sql.Context, _ [2]pgtypes.DoltgresType, val any) (any, error) {
65+
relationName, err := pgtypes.Regclass.IoOutput(ctx, val)
66+
if err != nil {
67+
return nil, err
68+
}
69+
70+
schema, sequence, err := parseRelationName(ctx, relationName)
71+
if err != nil {
72+
return nil, err
73+
}
74+
75+
collection, err := core.GetSequencesCollectionFromContext(ctx)
76+
if err != nil {
77+
return nil, err
78+
}
79+
return collection.NextVal(schema, sequence)
80+
},
81+
}

server/settings/current_schemas.go

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package settings
16+
17+
import (
18+
"strings"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
)
22+
23+
// GetCurrentSchemas returns all the schemas in the search_path setting, with elements like "$user" excluded
24+
func GetCurrentSchemas(ctx *sql.Context) ([]string, error) {
25+
searchPathVar, err := ctx.GetSessionVariable(ctx, "search_path")
26+
if err != nil {
27+
return nil, err
28+
}
29+
30+
pathElems := strings.Split(searchPathVar.(string), ",")
31+
var path []string
32+
33+
for _, schemaName := range pathElems {
34+
schemaName = strings.Trim(schemaName, " ")
35+
if schemaName == "\"$user\"" {
36+
continue
37+
}
38+
path = append(path, schemaName)
39+
}
40+
41+
return path, nil
42+
}
43+
44+
// GetCurrentSchemasAsMap returns the schemas from the search_path setting as a map for easy lookup. Any
45+
// elements like "$user" are excluded.
46+
func GetCurrentSchemasAsMap(ctx *sql.Context) (map[string]struct{}, error) {
47+
schemas, err := GetCurrentSchemas(ctx)
48+
if err != nil {
49+
return nil, err
50+
}
51+
schemaMap := make(map[string]struct{}, len(schemas))
52+
for _, schema := range schemas {
53+
schemaMap[schema] = struct{}{}
54+
}
55+
return schemaMap, nil
56+
}

server/settings/doc.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
// Package settings provides utility functions for working with settings, such as the search_path setting.
16+
//
17+
// This package is intended to be a leaf in the package dependency graph, and should not add dependencies to
18+
// other Doltgres packages.
19+
package settings

server/types/oid/regclass.go

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ import (
2020

2121
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
2222
"github.com/dolthub/go-mysql-server/sql"
23+
24+
"github.com/dolthub/doltgresql/server/settings"
2325
)
2426

2527
// regclass_IoInput is the implementation for IoInput that avoids circular dependencies by being declared in a separate
@@ -109,25 +111,56 @@ func regclass_IoInput(ctx *sql.Context, input string) (uint32, error) {
109111
// regclass_IoOutput is the implementation for IoOutput that avoids circular dependencies by being declared in a separate
110112
// package.
111113
func regclass_IoOutput(ctx *sql.Context, oid uint32) (string, error) {
114+
// Find all the schemas on the search path. If a schema is on the search path, then it is not included in the
115+
// output of relation name. If the relation's schema is not on the search path, then it is explicitly included.
116+
schemasMap, err := settings.GetCurrentSchemasAsMap(ctx)
117+
if err != nil {
118+
return "", err
119+
}
120+
121+
// The pg_catalog schema is always implicitly part of the search path
122+
// https://www.postgresql.org/docs/current/ddl-schemas.html#DDL-SCHEMAS-CATALOG
123+
schemasMap["pg_catalog"] = struct{}{}
124+
112125
output := strconv.FormatUint(uint64(oid), 10)
113-
err := RunCallback(ctx, oid, Callbacks{
126+
err = RunCallback(ctx, oid, Callbacks{
114127
Index: func(ctx *sql.Context, schema ItemSchema, table ItemTable, index ItemIndex) (cont bool, err error) {
115128
output = index.Item.ID()
116129
if output == "PRIMARY" {
117-
output = fmt.Sprintf("%s_pkey", index.Item.Table())
130+
schemaName := schema.Item.SchemaName()
131+
if _, ok := schemasMap[schemaName]; ok {
132+
output = fmt.Sprintf("%s_pkey", index.Item.Table())
133+
} else {
134+
output = fmt.Sprintf("%s.%s_pkey", schemaName, index.Item.Table())
135+
}
118136
}
119137
return false, nil
120138
},
121139
Sequence: func(ctx *sql.Context, schema ItemSchema, sequence ItemSequence) (cont bool, err error) {
122-
output = sequence.Item.Name
140+
schemaName := schema.Item.SchemaName()
141+
if _, ok := schemasMap[schemaName]; ok {
142+
output = sequence.Item.Name
143+
} else {
144+
output = fmt.Sprintf("%s.%s", schemaName, sequence.Item.Name)
145+
}
123146
return false, nil
124147
},
125148
Table: func(ctx *sql.Context, schema ItemSchema, table ItemTable) (cont bool, err error) {
126-
output = table.Item.Name()
149+
schemaName := schema.Item.SchemaName()
150+
if _, ok := schemasMap[schemaName]; ok {
151+
output = table.Item.Name()
152+
} else {
153+
output = fmt.Sprintf("%s.%s", schemaName, table.Item.Name())
154+
}
127155
return false, nil
128156
},
129157
View: func(ctx *sql.Context, schema ItemSchema, view ItemView) (cont bool, err error) {
130-
output = view.Item.Name
158+
schemaName := schema.Item.SchemaName()
159+
if _, ok := schemasMap[schemaName]; ok {
160+
output = view.Item.Name
161+
} else {
162+
output = fmt.Sprintf("%s.%s", schemaName, view.Item.Name)
163+
}
131164
return false, nil
132165
},
133166
})

server/types/text.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,10 @@ func (b TextType) Convert(val any) (any, sql.ConvertInRange, error) {
106106

107107
// Equals implements the DoltgresType interface.
108108
func (b TextType) Equals(otherType sql.Type) bool {
109+
if _, ok := otherType.(TextType); !ok {
110+
return false
111+
}
112+
109113
if otherExtendedType, ok := otherType.(types.ExtendedType); ok {
110114
return bytes.Equal(MustSerializeType(b), MustSerializeType(otherExtendedType))
111115
}

testing/go/functions_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,24 @@ func TestFunctionsOID(t *testing.T) {
202202
{nil},
203203
},
204204
},
205+
{
206+
// When the relation is from a schema on the search path, it is not qualified with the schema name
207+
Query: `SELECT to_regclass(('public.testing'::regclass)::text);`,
208+
Expected: []sql.Row{
209+
{"testing"},
210+
},
211+
},
212+
{
213+
// Clear out the current search_path setting to test fully qualified relation names
214+
Query: `SET search_path = '';`,
215+
Expected: []sql.Row{},
216+
},
217+
{
218+
Query: `SELECT to_regclass(('public.testing'::regclass)::text);`,
219+
Expected: []sql.Row{
220+
{"public.testing"},
221+
},
222+
},
205223
},
206224
},
207225
{

testing/go/sequences_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@ func TestSequences(t *testing.T) {
4141
Query: "SELECT nextval('test');",
4242
Expected: []sql.Row{{3}},
4343
},
44+
{
45+
Query: "SELECT nextval('test'::regclass);",
46+
Expected: []sql.Row{{4}},
47+
},
48+
{
49+
Query: "SELECT nextval('doesnotexist'::regclass);",
50+
ExpectedErr: "does not exist",
51+
},
4452
{
4553
Query: "DROP SEQUENCE test;",
4654
Expected: []sql.Row{},

testing/go/types_test.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1798,6 +1798,7 @@ var typesTests = []ScriptTest{
17981798
`CREATE TABLE testing (pk INT primary key, v1 INT UNIQUE);`,
17991799
`CREATE TABLE "Testing2" (pk INT primary key, v1 INT);`,
18001800
`CREATE VIEW testview AS SELECT * FROM testing LIMIT 1;`,
1801+
`CREATE SEQUENCE seq1;`,
18011802
},
18021803
Assertions: []ScriptTestAssertion{
18031804
{
@@ -1838,6 +1839,12 @@ var typesTests = []ScriptTest{
18381839
{"testing"},
18391840
},
18401841
},
1842+
{
1843+
Query: `SELECT 'seq1'::regclass;`,
1844+
Expected: []sql.Row{
1845+
{"seq1"},
1846+
},
1847+
},
18411848
{
18421849
Query: `SELECT 'Testing2'::regclass;`,
18431850
ExpectedErr: "does not exist",
@@ -1860,6 +1867,29 @@ var typesTests = []ScriptTest{
18601867
{"testing"},
18611868
},
18621869
},
1870+
{
1871+
// schema-qualified relation names are not returned if the schema is on the search path
1872+
Query: `SELECT 'public.testing'::regclass, 'public.seq1'::regclass, 'public.testview'::regclass, 'public.testing_pkey'::regclass;`,
1873+
Expected: []sql.Row{
1874+
{"testing", "seq1", "testview", "testing_pkey"},
1875+
},
1876+
},
1877+
{
1878+
// Clear out the current search_path setting to test schema-qualified relation names
1879+
Query: `SET search_path = '';`,
1880+
Expected: []sql.Row{},
1881+
},
1882+
{
1883+
// Without 'public' on search_path, we get a does not exist error
1884+
Query: `SELECT 'testing'::regclass;`,
1885+
ExpectedErr: "does not exist",
1886+
},
1887+
{
1888+
Query: `SELECT 'public.testing'::regclass, 'public.seq1'::regclass, 'public.testview'::regclass, 'public.testing_pkey'::regclass;`,
1889+
Expected: []sql.Row{
1890+
{"public.testing", "public.seq1", "public.testview", "public.testing_pkey"},
1891+
},
1892+
},
18631893
},
18641894
},
18651895
{

0 commit comments

Comments
 (0)