diff --git a/postgres/parser/parser/sql.y b/postgres/parser/parser/sql.y index 4ebca72521..fce51fedcb 100644 --- a/postgres/parser/parser/sql.y +++ b/postgres/parser/parser/sql.y @@ -4313,11 +4313,11 @@ create_function_stmt: } | CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list { - $$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()} + $$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $7.typeReference()}}, Options: $8.routineOptions()} } | CREATE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list { - $$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()} + $$.val = &tree.CreateFunction{Name: $3.unresolvedObjectName(), Args: $4.routineArgs(), ReturnsTable: true, RetType: $8.simpleColumnDefs(), Options: $10.routineOptions()} } | CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list create_function_option_list { @@ -4329,11 +4329,11 @@ create_function_stmt: } | CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS SETOF typename create_function_option_list { - $$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), SetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()} + $$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsSetOf: true, RetType: []tree.SimpleColumnDef{tree.SimpleColumnDef{Type: $9.typeReference()}}, Options: $10.routineOptions()} } | CREATE OR REPLACE FUNCTION routine_name opt_routine_arg_with_default_list RETURNS TABLE '(' opt_returns_table_col_def_list ')' create_function_option_list { - $$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()} + $$.val = &tree.CreateFunction{Name: $5.unresolvedObjectName(), Replace: true, Args: $6.routineArgs(), ReturnsTable: true, RetType: $10.simpleColumnDefs(), Options: $12.routineOptions()} } opt_returns_table_col_def_list: diff --git a/postgres/parser/sem/tree/create_function.go b/postgres/parser/sem/tree/create_function.go index 5742fcab52..9b9a1f1d31 100644 --- a/postgres/parser/sem/tree/create_function.go +++ b/postgres/parser/sem/tree/create_function.go @@ -25,12 +25,13 @@ var _ Statement = &CreateFunction{} // CreateFunction represents a CREATE FUNCTION statement. type CreateFunction struct { - Name *UnresolvedObjectName - Replace bool - Args RoutineArgs - SetOf bool - RetType []SimpleColumnDef - Options []RoutineOption + Name *UnresolvedObjectName + Replace bool + Args RoutineArgs + ReturnsSetOf bool + ReturnsTable bool + RetType []SimpleColumnDef + Options []RoutineOption } // Format implements the NodeFormatter interface. @@ -47,9 +48,9 @@ func (node *CreateFunction) Format(ctx *FmtCtx) { ctx.WriteString(" )") } if node.RetType != nil { - if len(node.RetType) == 1 && node.RetType[0].Name == "" { + if !node.ReturnsTable { ctx.WriteString("RETURNS ") - if node.SetOf { + if node.ReturnsSetOf { ctx.WriteString("SETOF ") } ctx.WriteString(node.RetType[0].Type.SQLString()) diff --git a/server/ast/create_function.go b/server/ast/create_function.go index 5dc4069d1e..65ed8a80f6 100644 --- a/server/ast/create_function.go +++ b/server/ast/create_function.go @@ -21,6 +21,7 @@ import ( "github.com/cockroachdb/errors" vitess "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/postgres/parser/parser" "github.com/dolthub/doltgresql/postgres/parser/sem/tree" "github.com/dolthub/doltgresql/postgres/parser/types" @@ -38,20 +39,26 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme } // Grab the general information that we'll need to create the function tableName := node.Name.ToTableName() - retType := pgtypes.Void - if len(node.RetType) == 1 { + var retType *pgtypes.DoltgresType + if len(node.RetType) == 0 { + retType = pgtypes.Void + } else if !node.ReturnsTable { switch typ := node.RetType[0].Type.(type) { case *types.T: retType = pgtypes.NewUnresolvedDoltgresType("", strings.ToLower(typ.Name())) - default: - sqlString := strings.ToLower(typ.SQLString()) - if sqlString == "trigger" { + case *tree.UnresolvedObjectName: + if strings.ToLower(typ.SQLString()) == "trigger" { retType = pgtypes.Trigger } else { - retType = pgtypes.NewUnresolvedDoltgresType("", sqlString) + retType = pgtypes.NewUnresolvedDoltgresType(typ.Schema(), typ.Object()) } + default: + return nil, fmt.Errorf("unsupported ResolvableTypeReference type: %T", typ) } + } else { + retType = createAnonymousCompositeType(node.RetType) } + paramNames := make([]string, len(node.Args)) paramTypes := make([]*pgtypes.DoltgresType, len(node.Args)) for i, arg := range node.Args { @@ -121,11 +128,39 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme parsedBody, sqlDef, sqlDefParsed, - node.SetOf, + node.ReturnsSetOf, ), }, nil } +// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return +// type for a function, as represented by the |fieldTypes| that were specified in the function +// definition. +func createAnonymousCompositeType(fieldTypes []tree.SimpleColumnDef) *pgtypes.DoltgresType { + attrs := make([]pgtypes.CompositeAttribute, len(fieldTypes)) + for i, fieldType := range fieldTypes { + attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, fieldType.Name.String(), + id.NewType("", fieldType.Type.SQLString()), int16(i), "") + } + + typeIdString := "table(" + for i, attr := range attrs { + if i > 0 { + typeIdString += "," + } + typeIdString += attr.Name + typeIdString += ":" + typeIdString += attr.TypeID.TypeName() + } + typeIdString += ")" + + // NOTE: there is no schema needed, since these types are anonymous and can't be directly referenced + typeId := id.NewType("", typeIdString) + + // NOTE: The available ctx is an ast.Context, not a sql.Context instance, so we can't use it here + return pgtypes.NewCompositeType(nil, id.Null, id.NullType, typeId, attrs) +} + // handleLanguageSQL handles parsing SQL definition strings in both CREATE FUNCTION and CREATE PROCEDURE. func handleLanguageSQL(definition string, paramNames []string, paramTypes []*pgtypes.DoltgresType) (string, vitess.Statement, error) { stmt, err := parser.ParseOne(definition) diff --git a/server/functions/framework/provider.go b/server/functions/framework/provider.go index bf7143047e..a1f3b272cd 100644 --- a/server/functions/framework/provider.go +++ b/server/functions/framework/provider.go @@ -15,15 +15,19 @@ package framework import ( - "github.com/dolthub/go-mysql-server/sql" + "strings" - "github.com/dolthub/doltgresql/core/extensions" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/extensions" "github.com/dolthub/doltgresql/core/id" pgtypes "github.com/dolthub/doltgresql/server/types" ) +const anonymousCompositePrefix = "table(" +const anonymousCompositeSuffix = ")" + // FunctionProvider is the special sql.FunctionProvider for Doltgres that allows us to handle functions that // are created by users. type FunctionProvider struct{} @@ -67,10 +71,18 @@ func (fp *FunctionProvider) Function(ctx *sql.Context, name string) (sql.Functio overloadTree := NewOverloads() for _, overload := range overloads { - returnType, err := typesCollection.GetType(ctx, overload.ReturnType) - if err != nil || returnType == nil { - return nil, false + var returnType *pgtypes.DoltgresType + if isAnonymousCompositeType(overload.ReturnType) { + // If this is an anonymous composite type, then we can't load it + // from typesCollection, so we create it dynamically when needed. + returnType = createAnonymousCompositeType(ctx, overload.ReturnType) + } else { + returnType, err = typesCollection.GetType(ctx, overload.ReturnType) + if err != nil || returnType == nil { + return nil, false + } } + paramTypes := make([]*pgtypes.DoltgresType, len(overload.ParameterTypes)) for i, paramType := range overload.ParameterTypes { paramTypes[i], err = typesCollection.GetType(ctx, paramType) @@ -127,3 +139,34 @@ func (fp *FunctionProvider) Function(ctx *sql.Context, name string) (sql.Functio }, }, true } + +// isAnonymousCompositeType return true if |returnType| represents an anonymous composite return type +// for a function (i.e. the function was declared as "RETURNS TABLE(...)"). +func isAnonymousCompositeType(returnType id.Type) bool { + typeName := returnType.TypeName() + return strings.HasPrefix(typeName, anonymousCompositePrefix) && + strings.HasSuffix(typeName, anonymousCompositeSuffix) +} + +// createAnonymousCompositeType creates a new DoltgresType for the anonymous composite return type for a function, +// as represented by |returnType|. +func createAnonymousCompositeType(ctx *sql.Context, returnType id.Type) *pgtypes.DoltgresType { + typeName := returnType.TypeName() + attributeTypes := typeName[len(anonymousCompositePrefix) : len(typeName)-len(anonymousCompositeSuffix)] + attributeTypesSlice := strings.Split(attributeTypes, ",") + + attrs := make([]pgtypes.CompositeAttribute, len(attributeTypesSlice)) + for i, attributeNameAndType := range attributeTypesSlice { + split := strings.Split(attributeNameAndType, ":") + if len(split) != 2 { + // TODO: We could return an error here, but the only place this function is + // called (FunctionProvider.Function) would require updating the + // sql.FunctionProvider interface in GMS, too. + panic("unexpected anonymous composite type attribute syntax: " + attributeNameAndType) + } + + typeId := id.NewType("", split[1]) + attrs[i] = pgtypes.NewCompositeAttribute(nil, id.Null, split[0], typeId, int16(i), "") + } + return pgtypes.NewCompositeType(ctx, id.Null, id.NullType, returnType, attrs) +} diff --git a/server/plpgsql/json_convert.go b/server/plpgsql/json_convert.go index 6775a1545d..ad929c5838 100644 --- a/server/plpgsql/json_convert.go +++ b/server/plpgsql/json_convert.go @@ -36,7 +36,10 @@ func jsonConvert(jsonBlock plpgSQL_block) (Block, error) { block.Records = make([]Record, v.Record.DatumNumber) copy(block.Records, oldRecords) } - block.Records[v.Record.DatumNumber-1].Name = v.Record.RefName + + if v.Record.DatumNumber > 0 { + block.Records[v.Record.DatumNumber-1].Name = v.Record.RefName + } case v.RecordField != nil: if int(v.RecordField.RecordParentNumber) > len(block.Records) { return Block{}, errors.New("invalid record parent number") diff --git a/testing/go/create_function_plpgsql_test.go b/testing/go/create_function_plpgsql_test.go index cc948346d5..1192c0465e 100644 --- a/testing/go/create_function_plpgsql_test.go +++ b/testing/go/create_function_plpgsql_test.go @@ -524,6 +524,160 @@ $$ LANGUAGE plpgsql;`}, }, }, }, + { + Name: "RETURNS SETOF with type from other schema", + SetUpScript: []string{ + `CREATE SCHEMA sch1;`, + `CREATE TYPE sch1.user_summary AS ( + user_id integer, + username text, + is_active boolean);`, + `CREATE OR REPLACE FUNCTION func2() RETURNS SETOF sch1.user_summary + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT 1, 'username', true; + RETURN QUERY SELECT 2, 'another', false; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func2();", + Expected: []sql.Row{ + {"(1,username,t)"}, + {"(2,another,f)"}, + }, + }, + { + Query: "SELECT func2(), func2();", + Expected: []sql.Row{ + {"(1,username,t)", "(1,username,t)"}, + {"(2,another,f)", "(2,another,f)"}, + }, + }, + }, + }, + { + Name: "RETURNS TABLE", + SetUpScript: []string{ + `CREATE FUNCTION func2() RETURNS TABLE(user_id integer, username text, is_active boolean) + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT 1, 'username', true; + RETURN QUERY SELECT 2, 'another', false; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func2();", + Expected: []sql.Row{ + {"(1,username,t)"}, + {"(2,another,f)"}, + }, + }, + { + Query: "SELECT func2(), func2();", + Expected: []sql.Row{ + {"(1,username,t)", "(1,username,t)"}, + {"(2,another,f)", "(2,another,f)"}, + }, + }, + }, + }, + { + Name: "RETURNS TABLE with single field", + SetUpScript: []string{ + `CREATE FUNCTION func2() RETURNS TABLE(username text) + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT 'username1'; + RETURN QUERY SELECT 'username2'; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func2();", + Expected: []sql.Row{ + {"(username1)"}, + {"(username2)"}, + }, + }, + { + Query: "SELECT func2(), func2();", + Expected: []sql.Row{ + {"(username1)", "(username1)"}, + {"(username2)", "(username2)"}, + }, + }, + }, + }, + { + Name: "RETURNS TABLE with types from other schema", + SetUpScript: []string{ + `CREATE SCHEMA sch1;`, + `CREATE TYPE sch1.mytype AS ( + user_id integer, + username text);`, + `CREATE FUNCTION func2() RETURNS TABLE(foo sch1.mytype) + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT 1, 'username1'; + RETURN QUERY SELECT 2, 'username2'; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func2();", + Expected: []sql.Row{ + {"(1,username1)"}, + {"(2,username2)"}, + }, + }, + { + Query: "SELECT func2(), func2();", + Expected: []sql.Row{ + {"(1,username1)", "(1,username1)"}, + {"(2,username2)", "(2,username2)"}, + }, + }, + }, + }, + { + Name: "RETURNS TABLE with param", + SetUpScript: []string{ + `CREATE OR REPLACE FUNCTION func3(user_id integer) RETURNS TABLE(user_id integer, username text, is_active boolean) + LANGUAGE plpgsql + AS $$ + BEGIN + RETURN QUERY SELECT user_id, 'username', true; + RETURN QUERY SELECT user_id, 'another', false; + END; + $$;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT func3(111);", + Expected: []sql.Row{ + {"(111,username,t)"}, + {"(111,another,f)"}, + }, + }, + { + Query: "SELECT func3(111), func3(222);", + Expected: []sql.Row{ + {"(111,username,t)", "(222,username,t)"}, + {"(111,another,f)", "(222,another,f)"}, + }, + }, + }, + }, { Name: "RAISE", SetUpScript: []string{