Skip to content

Commit faed871

Browse files
committed
Merge branch 'main' into zachmu/cross-schema-fks
2 parents 5b88f92 + b5b0193 commit faed871

31 files changed

+645
-99
lines changed

core/rootvalue.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,16 @@ func (root *RootValue) GetTableNames(ctx context.Context, schemaName string) ([]
303303
// HandlePostMerge implements the interface doltdb.RootValue.
304304
func (root *RootValue) HandlePostMerge(ctx context.Context, ourRoot, theirRoot, ancRoot doltdb.RootValue) (doltdb.RootValue, error) {
305305
// Handle sequences
306+
_, err := root.handlePostSequencesMerge(ctx, ourRoot, theirRoot, ancRoot)
307+
if err != nil {
308+
return nil, err
309+
}
310+
// Handle types
311+
return root.handlePostTypesMerge(ctx, ourRoot, theirRoot, ancRoot)
312+
}
313+
314+
// handlePostSequencesMerge merges sequences.
315+
func (root *RootValue) handlePostSequencesMerge(ctx context.Context, ourRoot, theirRoot, ancRoot doltdb.RootValue) (doltdb.RootValue, error) {
306316
ourSequence, err := ourRoot.(*RootValue).GetSequences(ctx)
307317
if err != nil {
308318
return nil, err
@@ -322,6 +332,27 @@ func (root *RootValue) HandlePostMerge(ctx context.Context, ourRoot, theirRoot,
322332
return root.PutSequences(ctx, mergedSequence)
323333
}
324334

335+
// handlePostTypesMerge merges types.
336+
func (root *RootValue) handlePostTypesMerge(ctx context.Context, ourRoot, theirRoot, ancRoot doltdb.RootValue) (doltdb.RootValue, error) {
337+
ourTypes, err := ourRoot.(*RootValue).GetTypes(ctx)
338+
if err != nil {
339+
return nil, err
340+
}
341+
theirTypes, err := theirRoot.(*RootValue).GetTypes(ctx)
342+
if err != nil {
343+
return nil, err
344+
}
345+
ancTypes, err := ancRoot.(*RootValue).GetTypes(ctx)
346+
if err != nil {
347+
return nil, err
348+
}
349+
mergedTypes, err := typecollection.Merge(ctx, ourTypes, theirTypes, ancTypes)
350+
if err != nil {
351+
return nil, err
352+
}
353+
return root.PutTypes(ctx, mergedTypes)
354+
}
355+
325356
// HashOf implements the interface doltdb.RootValue.
326357
func (root *RootValue) HashOf() (hash.Hash, error) {
327358
if root.hash.IsEmpty() {

core/typecollection/merge.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package typecollection
16+
17+
import (
18+
"context"
19+
"fmt"
20+
21+
"github.com/dolthub/doltgresql/server/types"
22+
)
23+
24+
// Merge handles merging sequences on our root and their root.
25+
func Merge(ctx context.Context, ourCollection, theirCollection, ancCollection *TypeCollection) (*TypeCollection, error) {
26+
mergedCollection := ourCollection.Clone()
27+
err := theirCollection.IterateTypes(func(schema string, theirType *types.Type) error {
28+
// If we don't have the type, then we simply add it
29+
mergedType, exists := mergedCollection.GetType(schema, theirType.Name)
30+
if !exists {
31+
newSeq := *theirType
32+
return mergedCollection.CreateType(schema, &newSeq)
33+
}
34+
35+
// Different types with the same name cannot be merged. (e.g.: 'domain' type and 'base' type with the same name)
36+
if mergedType.TypType != theirType.TypType {
37+
return fmt.Errorf(`cannot merge type "%s" because type types do not match: '%s' and '%s'"`, theirType.Name, mergedType.TypType, theirType.TypType)
38+
}
39+
40+
switch theirType.TypType {
41+
case types.TypeType_Domain:
42+
if mergedType.BaseTypeOID != theirType.BaseTypeOID {
43+
// TODO: we can extend on this in the future (e.g.: maybe uses preferred type?)
44+
return fmt.Errorf(`base types of domain type "%s" do not match`, theirType.Name)
45+
}
46+
if mergedType.Default == "" {
47+
mergedType.Default = theirType.Default
48+
} else if theirType.Default != "" && mergedType.Default != theirType.Default {
49+
return fmt.Errorf(`default values of domain type "%s" do not match`, theirType.Name)
50+
}
51+
// if either of types defined as NOT NULL, take NOT NULL
52+
if mergedType.NotNull || theirType.NotNull {
53+
mergedType.NotNull = true
54+
}
55+
if len(theirType.Checks) > 0 {
56+
// TODO: check for duplicate check constraints
57+
mergedType.Checks = append(mergedType.Checks, theirType.Checks...)
58+
}
59+
default:
60+
// TODO: support merge for other types. (base, range, etc.)
61+
}
62+
return nil
63+
})
64+
if err != nil {
65+
return nil, err
66+
}
67+
return mergedCollection, nil
68+
}

core/typecollection/serialization.go

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,13 @@ func (pgs *TypeCollection) Serialize(ctx context.Context) ([]byte, error) {
4646
writer.VariableUint(uint64(len(nameMapKeys)))
4747
for _, nameMapKey := range nameMapKeys {
4848
typ := nameMap[nameMapKey]
49+
writer.Uint32(typ.Oid)
4950
writer.String(typ.Name)
5051
writer.String(typ.Owner)
5152
writer.Int16(typ.Length)
5253
writer.Bool(typ.PassedByVal)
53-
writer.String(string(typ.Typ))
54-
writer.String(string(typ.Category))
54+
writer.String(string(typ.TypType))
55+
writer.String(string(typ.TypCategory))
5556
writer.Bool(typ.IsPreferred)
5657
writer.Bool(typ.IsDefined)
5758
writer.String(typ.Delimiter)
@@ -110,13 +111,14 @@ func Deserialize(ctx context.Context, data []byte) (*TypeCollection, error) {
110111
numOfTypes := reader.VariableUint()
111112
nameMap := make(map[string]*types.Type)
112113
for j := uint64(0); j < numOfTypes; j++ {
113-
typ := &types.Type{}
114+
typ := &types.Type{Schema: schemaName}
115+
typ.Oid = reader.Uint32()
114116
typ.Name = reader.String()
115117
typ.Owner = reader.String()
116118
typ.Length = reader.Int16()
117119
typ.PassedByVal = reader.Bool()
118-
typ.Typ = types.TypeType(reader.String())
119-
typ.Category = types.TypeCategory(reader.String())
120+
typ.TypType = types.TypeType(reader.String())
121+
typ.TypCategory = types.TypeCategory(reader.String())
120122
typ.IsPreferred = reader.Bool()
121123
typ.IsDefined = reader.Bool()
122124
typ.Delimiter = reader.String()

core/typecollection/typecollection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func (pgs *TypeCollection) GetDomainType(schName, typName string) (*types.Type,
4848
defer pgs.mutex.RUnlock()
4949

5050
if nameMap, ok := pgs.schemaMap[schName]; ok {
51-
if typ, ok := nameMap[typName]; ok && typ.Typ == types.TypeType_Domain {
51+
if typ, ok := nameMap[typName]; ok && typ.TypType == types.TypeType_Domain {
5252
return typ, true
5353
}
5454
}

postgres/parser/parser/parse.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import (
3838
"strings"
3939

4040
"github.com/cockroachdb/errors"
41+
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
4142

4243
"github.com/dolthub/doltgresql/postgres/parser/pgcode"
4344
"github.com/dolthub/doltgresql/postgres/parser/pgerror"
@@ -120,8 +121,10 @@ func (p *Parser) parseOneWithDepth(depth int, sql string) (Statement, error) {
120121
if err != nil {
121122
return Statement{}, err
122123
}
123-
if len(stmts) != 1 {
124+
if len(stmts) > 1 {
124125
return Statement{}, errors.AssertionFailedf("expected 1 statement, but found %d", len(stmts))
126+
} else if len(stmts) == 0 {
127+
return Statement{}, vitess.ErrEmpty
125128
}
126129
return stmts[0], nil
127130
}

postgres/parser/parser/sql/sql_parser.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ func (p *PostgresParser) ParseWithOptions(ctx context.Context, query string, del
5656
return nil, "", "", fmt.Errorf("only a single statement at a time is currently supported")
5757
}
5858
if len(stmts) == 0 {
59-
return nil, q, "", nil
59+
return nil, q, "", vitess.ErrEmpty
6060
}
6161

6262
vitessAST, err := ast.Convert(stmts[0])

server/ast/drop_domain.go

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,28 @@ import (
2020
vitess "github.com/dolthub/vitess/go/vt/sqlparser"
2121

2222
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
23+
pgnodes "github.com/dolthub/doltgresql/server/node"
2324
)
2425

2526
// nodeDropDomain handles *tree.DropDomain nodes.
2627
func nodeDropDomain(node *tree.DropDomain) (vitess.Statement, error) {
27-
return nil, fmt.Errorf("DROP DOMAIN is not supported yet")
28-
//if node == nil {
29-
// return nil, nil
30-
//}
31-
//if len(node.Names) != 1 {
32-
// return nil, fmt.Errorf("dropping multiple domains in DROP DOMAIN is not yet supported")
33-
//}
34-
//name, err := nodeTableName(&node.Names[0])
35-
//if err != nil {
36-
// return nil, err
37-
//}
38-
//if len(name.DbQualifier.String()) > 0 {
39-
// return nil, fmt.Errorf("DROP DOMAIN is currently only supported for the current database")
40-
//}
41-
//return vitess.InjectedStatement{
42-
// Statement: pgnodes.NewDropDomain(node.IfExists, name.SchemaQualifier.String(), name.Name.String(), node.DropBehavior == tree.DropCascade),
43-
// Children: nil,
44-
//}, nil
28+
if node == nil {
29+
return nil, nil
30+
}
31+
if len(node.Names) != 1 {
32+
return nil, fmt.Errorf("dropping multiple domains in DROP DOMAIN is not yet supported")
33+
}
34+
name, err := nodeTableName(&node.Names[0])
35+
if err != nil {
36+
return nil, err
37+
}
38+
return vitess.InjectedStatement{
39+
Statement: pgnodes.NewDropDomain(
40+
node.IfExists,
41+
name.DbQualifier.String(),
42+
name.SchemaQualifier.String(),
43+
name.Name.String(),
44+
node.DropBehavior == tree.DropCascade),
45+
Children: nil,
46+
}, nil
4547
}

server/functions/current_schema.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"github.com/dolthub/go-mysql-server/sql"
1919

2020
"github.com/dolthub/doltgresql/server/functions/framework"
21+
"github.com/dolthub/doltgresql/server/settings"
2122
pgtypes "github.com/dolthub/doltgresql/server/types"
2223
)
2324

@@ -33,7 +34,7 @@ var current_schema = framework.Function0{
3334
IsNonDeterministic: true,
3435
Strict: true,
3536
Callable: func(ctx *sql.Context) (any, error) {
36-
schemas, err := GetCurrentSchemas(ctx)
37+
schemas, err := settings.GetCurrentSchemas(ctx)
3738
if err != nil {
3839
return nil, err
3940
}

server/functions/current_schemas.go

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,11 @@
1515
package functions
1616

1717
import (
18-
"strings"
19-
2018
"github.com/dolthub/go-mysql-server/sql"
2119

2220
"github.com/dolthub/doltgresql/postgres/parser/sessiondata"
2321
"github.com/dolthub/doltgresql/server/functions/framework"
22+
"github.com/dolthub/doltgresql/server/settings"
2423
pgtypes "github.com/dolthub/doltgresql/server/types"
2524
)
2625

@@ -41,7 +40,7 @@ var current_schemas = framework.Function1{
4140
if val.(bool) {
4241
schemas = append(schemas, sessiondata.PgCatalogName)
4342
}
44-
searchPaths, err := GetCurrentSchemas(ctx)
43+
searchPaths, err := settings.GetCurrentSchemas(ctx)
4544
if err != nil {
4645
return nil, err
4746
}
@@ -51,24 +50,3 @@ var current_schemas = framework.Function1{
5150
return schemas, nil
5251
},
5352
}
54-
55-
// GetCurrentSchemas returns all the schemas in the search_path setting, with elements like "$user" excluded
56-
func GetCurrentSchemas(ctx *sql.Context) ([]string, error) {
57-
searchPathVar, err := ctx.GetSessionVariable(ctx, "search_path")
58-
if err != nil {
59-
return nil, err
60-
}
61-
62-
pathElems := strings.Split(searchPathVar.(string), ",")
63-
var path []string
64-
65-
for _, schemaName := range pathElems {
66-
schemaName = strings.Trim(schemaName, " ")
67-
if schemaName == "\"$user\"" {
68-
continue
69-
}
70-
path = append(path, schemaName)
71-
}
72-
73-
return path, nil
74-
}

server/functions/init.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ func Init() {
106106
initRpad()
107107
initRtrim()
108108
initScale()
109+
initSetConfig()
109110
initSetVal()
110111
initShobjDescription()
111112
initSign()

0 commit comments

Comments
 (0)