Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions postgres/parser/parser/sql.y
Original file line number Diff line number Diff line change
Expand Up @@ -4313,11 +4313,11 @@ create_function_stmt:
}
| CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list
{
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()}
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()}
}
| CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list
{
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()}
$$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsTable: true, RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()}
}
| CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list create_function_option_list
{
Expand All @@ -4329,11 +4329,11 @@ create_function_stmt:
}
| CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list
{
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()}
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()}
}
| CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list
{
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()}
$$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsTable: true, RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()}
}

opt_returns_table_col_def_list:
Expand Down
17 changes: 9 additions & 8 deletions postgres/parser/sem/tree/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@ var _ Statement = &CreateFunction{}

// CreateFunction represents a CREATE FUNCTION statement.
type CreateFunction struct {
Name *UnresolvedObjectName
Replace bool
Args RoutineArgs
SetOf bool
RetType []SimpleColumnDef
Options []RoutineOption
Name *UnresolvedObjectName
Replace bool
Args RoutineArgs
ReturnsSetOf bool
ReturnsTable bool
RetType []SimpleColumnDef
Options []RoutineOption
}

// Format implements the NodeFormatter interface.
Expand All @@ -47,9 +48,9 @@ func (node *CreateFunction) Format(ctx *FmtCtx) {
ctx.WriteString(" )")
}
if node.RetType != nil {
if len(node.RetType) == 1 && node.RetType[0].Name == "" {
if !node.ReturnsTable {
ctx.WriteString("RETURNS ")
if node.SetOf {
if node.ReturnsSetOf {
ctx.WriteString("SETOF ")
}
ctx.WriteString(node.RetType[0].Type.SQLString())
Expand Down
49 changes: 42 additions & 7 deletions server/ast/create_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/cockroachdb/errors"
vitess "github.com/dolthub/vitess/go/vt/sqlparser"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/postgres/parser/parser"
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
"github.com/dolthub/doltgresql/postgres/parser/types"
Expand All @@ -38,20 +39,26 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
}
// Grab the general information that we'll need to create the function
tableName := node.Name.ToTableName()
retType := pgtypes.Void
if len(node.RetType) == 1 {
var retType *pgtypes.DoltgresType
if len(node.RetType) == 0 {
retType = pgtypes.Void
} else if !node.ReturnsTable {
switch typ := node.RetType[0].Type.(type) {
case *types.T:
retType = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(typ.Name()))
default:
sqlString := strings.ToLower(typ.SQLString())
if sqlString == "trigger" {
case *tree.UnresolvedObjectName:
if strings.ToLower(typ.SQLString()) == "trigger" {
retType = pgtypes.Trigger
} else {
retType = pgtypes.NewUnresolvedDoltgresType("", sqlString)
retType = pgtypes.NewUnresolvedDoltgresType(typ.Schema(), typ.Object())
}
default:
return nil, fmt.Errorf("unsupported ResolvableTypeReference type: %T", typ)
}
} else {
retType = createAnonymousCompositeType(node.RetType)
}

paramNames := make([]string, len(node.Args))
paramTypes := make([]*pgtypes.DoltgresType, len(node.Args))
for i, arg := range node.Args {
Expand Down Expand Up @@ -121,11 +128,39 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
parsedBody,
sqlDef,
sqlDefParsed,
node.SetOf,
node.ReturnsSetOf,
),
}, nil
}

// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return
// type for a function, as represented by the |fieldTypes| that were specified in the function
// definition.
func createAnonymousCompositeType(fieldTypes []tree.SimpleColumnDef) *pgtypes.DoltgresType {
attrs := make([]pgtypes.CompositeAttribute, len(fieldTypes))
for i, fieldType := range fieldTypes {
attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, fieldType.Name.String(),
id.NewType("", fieldType.Type.SQLString()), int16(i), "")
}

typeIdString := "table("
for i, attr := range attrs {
if i > 0 {
typeIdString += ","
}
typeIdString += attr.Name
typeIdString += ":"
typeIdString += attr.TypeID.TypeName()
}
typeIdString += ")"

// NOTE: there is no schema needed, since these types are anonymous and can't be directly referenced
typeId := id.NewType("", typeIdString)

// NOTE: The available ctx is an ast.Context, not a sql.Context instance, so we can't use it here
return pgtypes.NewCompositeType(nil, id.Null, id.NullType, typeId, attrs)
}

// handleLanguageSQL handles parsing SQL definition strings in both CREATE FUNCTION and CREATE PROCEDURE.
func handleLanguageSQL(definition string, paramNames []string, paramTypes []*pgtypes.DoltgresType) (string, vitess.Statement, error) {
stmt, err := parser.ParseOne(definition)
Expand Down
53 changes: 48 additions & 5 deletions server/functions/framework/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,19 @@
package framework

import (
"github.com/dolthub/go-mysql-server/sql"
"strings"

"github.com/dolthub/doltgresql/core/extensions"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core"
"github.com/dolthub/doltgresql/core/extensions"
"github.com/dolthub/doltgresql/core/id"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

const anonymousCompositePrefix = "table("
const anonymousCompositeSuffix = ")"

// FunctionProvider is the special sql.FunctionProvider for Doltgres that allows us to handle functions that
// are created by users.
type FunctionProvider struct{}
Expand Down Expand Up @@ -67,10 +71,18 @@ func (fp *FunctionProvider) Function(ctx *sql.Context, name string) (sql.Functio

overloadTree := NewOverloads()
for _, overload := range overloads {
returnType, err := typesCollection.GetType(ctx, overload.ReturnType)
if err != nil || returnType == nil {
return nil, false
var returnType *pgtypes.DoltgresType
if isAnonymousCompositeType(overload.ReturnType) {
// If this is an anonymous composite type, then we can't load it
// from typesCollection, so we create it dynamically when needed.
returnType = createAnonymousCompositeType(ctx, overload.ReturnType)
} else {
returnType, err = typesCollection.GetType(ctx, overload.ReturnType)
if err != nil || returnType == nil {
return nil, false
}
}

paramTypes := make([]*pgtypes.DoltgresType, len(overload.ParameterTypes))
for i, paramType := range overload.ParameterTypes {
paramTypes[i], err = typesCollection.GetType(ctx, paramType)
Expand Down Expand Up @@ -127,3 +139,34 @@ func (fp *FunctionProvider) Function(ctx *sql.Context, name string) (sql.Functio
},
}, true
}

// isAnonymousCompositeType return true if |returnType| represents an anonymous composite return type
// for a function (i.e. the function was declared as "RETURNS TABLE(...)").
func isAnonymousCompositeType(returnType id.Type) bool {
typeName := returnType.TypeName()
return strings.HasPrefix(typeName, anonymousCompositePrefix) &&
strings.HasSuffix(typeName, anonymousCompositeSuffix)
}

// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return type for a function,
// as represented by |returnType|.
func createAnonymousCompositeType(ctx *sql.Context, returnType id.Type) *pgtypes.DoltgresType {
typeName := returnType.TypeName()
attributeTypes := typeName[len(anonymousCompositePrefix) : len(typeName)-len(anonymousCompositeSuffix)]
attributeTypesSlice := strings.Split(attributeTypes, ",")

attrs := make([]pgtypes.CompositeAttribute, len(attributeTypesSlice))
for i, attributeNameAndType := range attributeTypesSlice {
split := strings.Split(attributeNameAndType, ":")
if len(split) != 2 {
// TODO: We could return an error here, but the only place this function is
// called (FunctionProvider.Function) would require updating the
// sql.FunctionProvider interface in GMS, too.
panic("unexpected anonymous composite type attribute syntax: " + attributeNameAndType)
}

typeId := id.NewType("", split[1])
attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, split[0], typeId, int16(i), "")
}
return pgtypes.NewCompositeType(ctx, id.Null, id.NullType, returnType, attrs)
}
5 changes: 4 additions & 1 deletion server/plpgsql/json_convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,10 @@ func jsonConvert(jsonBlock plpgSQL_block) (Block, error) {
block.Records = make([]Record, v.Record.DatumNumber)
copy(block.Records, oldRecords)
}
block.Records[v.Record.DatumNumber-1].Name = v.Record.RefName

if v.Record.DatumNumber > 0 {
block.Records[v.Record.DatumNumber-1].Name = v.Record.RefName
}
case v.RecordField != nil:
if int(v.RecordField.RecordParentNumber) > len(block.Records) {
return Block{}, errors.New("invalid record parent number")
Expand Down
154 changes: 154 additions & 0 deletions testing/go/create_function_plpgsql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,160 @@ $$ LANGUAGE plpgsql;`},
},
},
},
{
Name: "RETURNS SETOF with type from other schema",
SetUpScript: []string{
`CREATE SCHEMA sch1;`,
`CREATE TYPE sch1.user_summary AS (
user_id integer,
username text,
is_active boolean);`,
`CREATE OR REPLACE FUNCTION func2() RETURNS SETOF sch1.user_summary
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY SELECT 1, 'username', true;
RETURN QUERY SELECT 2, 'another', false;
END;
$$;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT func2();",
Expected: []sql.Row{
{"(1,username,t)"},
{"(2,another,f)"},
},
},
{
Query: "SELECT func2(), func2();",
Expected: []sql.Row{
{"(1,username,t)", "(1,username,t)"},
{"(2,another,f)", "(2,another,f)"},
},
},
},
},
{
Name: "RETURNS TABLE",
SetUpScript: []string{
`CREATE FUNCTION func2() RETURNS TABLE(user_id integer, username text, is_active boolean)
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY SELECT 1, 'username', true;
RETURN QUERY SELECT 2, 'another', false;
END;
$$;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT func2();",
Expected: []sql.Row{
{"(1,username,t)"},
{"(2,another,f)"},
},
},
{
Query: "SELECT func2(), func2();",
Expected: []sql.Row{
{"(1,username,t)", "(1,username,t)"},
{"(2,another,f)", "(2,another,f)"},
},
},
},
},
{
Name: "RETURNS TABLE with single field",
SetUpScript: []string{
`CREATE FUNCTION func2() RETURNS TABLE(username text)
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY SELECT 'username1';
RETURN QUERY SELECT 'username2';
END;
$$;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT func2();",
Expected: []sql.Row{
{"(username1)"},
{"(username2)"},
},
},
{
Query: "SELECT func2(), func2();",
Expected: []sql.Row{
{"(username1)", "(username1)"},
{"(username2)", "(username2)"},
},
},
},
},
{
Name: "RETURNS TABLE with types from other schema",
SetUpScript: []string{
`CREATE SCHEMA sch1;`,
`CREATE TYPE sch1.mytype AS (
user_id integer,
username text);`,
`CREATE FUNCTION func2() RETURNS TABLE(foo sch1.mytype)
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY SELECT 1, 'username1';
RETURN QUERY SELECT 2, 'username2';
END;
$$;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT func2();",
Expected: []sql.Row{
{"(1,username1)"},
{"(2,username2)"},
},
},
{
Query: "SELECT func2(), func2();",
Expected: []sql.Row{
{"(1,username1)", "(1,username1)"},
{"(2,username2)", "(2,username2)"},
},
},
},
},
{
Name: "RETURNS TABLE with param",
SetUpScript: []string{
`CREATE OR REPLACE FUNCTION func3(user_id integer) RETURNS TABLE(user_id integer, username text, is_active boolean)
LANGUAGE plpgsql
AS $$
BEGIN
RETURN QUERY SELECT user_id, 'username', true;
RETURN QUERY SELECT user_id, 'another', false;
END;
$$;`,
},
Assertions: []ScriptTestAssertion{
{
Query: "SELECT func3(111);",
Expected: []sql.Row{
{"(111,username,t)"},
{"(111,another,f)"},
},
},
{
Query: "SELECT func3(111), func3(222);",
Expected: []sql.Row{
{"(111,username,t)", "(222,username,t)"},
{"(111,another,f)", "(222,another,f)"},
},
},
},
},
{
Name: "RAISE",
SetUpScript: []string{
Expand Down