@@ -25,7 +25,6 @@ import (
2525 "github.com/dolthub/doltgresql/core/id"
2626 "github.com/dolthub/doltgresql/postgres/parser/parser"
2727 "github.com/dolthub/doltgresql/postgres/parser/sem/tree"
28- "github.com/dolthub/doltgresql/postgres/parser/types"
2928 "github.com/dolthub/doltgresql/server/functions/framework"
3029 pgnodes "github.com/dolthub/doltgresql/server/node"
3130 "github.com/dolthub/doltgresql/server/plpgsql"
@@ -43,21 +42,11 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
4342 var retType * pgtypes.DoltgresType
4443 if len (node .RetType ) == 0 {
4544 retType = pgtypes .Void
46- } else if ! node .ReturnsTable { // Return types may specify "trigger", but this doesn't apply elsewhere
47- switch typ := node .RetType [0 ].Type .(type ) {
48- case * types.T :
49- retType = pgtypes .NewUnresolvedDoltgresType ("" , strings .ToLower (typ .Name ()))
50- case * tree.UnresolvedObjectName :
51- if typ .NumParts == 1 && typ .SQLString () == "trigger" {
52- retType = pgtypes .Trigger
53- } else {
54- _ , retType , err = nodeResolvableTypeReference (ctx , typ )
55- if err != nil {
56- return nil , err
57- }
58- }
59- default :
60- return nil , fmt .Errorf ("unsupported ResolvableTypeReference type: %T" , typ )
45+ } else if ! node .ReturnsTable {
46+ // Return types may specify "trigger", but this doesn't apply elsewhere
47+ _ , retType , err = nodeResolvableTypeReference (ctx , node .RetType [0 ].Type , true )
48+ if err != nil {
49+ return nil , err
6150 }
6251 } else {
6352 retType = createAnonymousCompositeType (node .RetType )
@@ -67,16 +56,9 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
6756 paramTypes := make ([]* pgtypes.DoltgresType , len (node .Args ))
6857 for i , arg := range node .Args {
6958 paramNames [i ] = arg .Name .String ()
70- switch argType := arg .Type .(type ) {
71- case * types.T :
72- paramTypes [i ] = pgtypes .NewUnresolvedDoltgresType ("" , strings .ToLower (argType .Name ()))
73- case * tree.UnresolvedObjectName :
74- _ , paramTypes [i ], err = nodeResolvableTypeReference (ctx , argType )
75- if err != nil {
76- return nil , err
77- }
78- default :
79- paramTypes [i ] = pgtypes .NewUnresolvedDoltgresType ("" , strings .ToLower (argType .SQLString ()))
59+ _ , paramTypes [i ], err = nodeResolvableTypeReference (ctx , arg .Type , false )
60+ if err != nil {
61+ return nil , err
8062 }
8163 }
8264 var strict bool
@@ -99,6 +81,23 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme
9981 if err != nil {
10082 return nil , err
10183 }
84+ // parse types
85+ for i , op := range parsedBody {
86+ switch op .OpCode {
87+ case plpgsql .OpCode_Declare :
88+ // ParseType uses casting to parse the given type, but
89+ // some special types cannot be cast. Eg: `user_defined_table_type%ROWTYPE`
90+ if declareTyp , err := parser .ParseType (op .PrimaryData ); err == nil {
91+ if _ , dt , err := nodeResolvableTypeReference (ctx , declareTyp , false ); err == nil && dt != nil {
92+ dtName := dt .Name ()
93+ if dt .Schema () != "" {
94+ dtName = fmt .Sprintf ("%s.%s" , dt .Schema (), dtName )
95+ }
96+ parsedBody [i ].PrimaryData = dtName
97+ }
98+ }
99+ }
100+ }
102101 case "sql" :
103102 as , ok := options [tree .OptionAs1 ]
104103 if ! ok {
0 commit comments