@@ -19,6 +19,8 @@ import (
1919
2020 "github.com/dolthub/go-mysql-server/sql"
2121
22+ "github.com/dolthub/doltgresql/core"
23+ "github.com/dolthub/doltgresql/core/id"
2224 "github.com/dolthub/doltgresql/server/functions/framework"
2325 pgtypes "github.com/dolthub/doltgresql/server/types"
2426 "github.com/dolthub/doltgresql/utils"
@@ -40,10 +42,20 @@ var enum_in = framework.Function2{
4042 Parameters : [2 ]* pgtypes.DoltgresType {pgtypes .Cstring , pgtypes .Oid },
4143 Strict : true ,
4244 Callable : func (ctx * sql.Context , _ [3 ]* pgtypes.DoltgresType , val1 , val2 any ) (any , error ) {
43- // typOid := val2.(id.Internal)
44- // TODO: get type using given OID, which should give access to enum labels.
45- // should return the index of label?
46- return val1 .(string ), nil
45+ typ , err := getDoltgresTypeFromInternal (ctx , val2 .(id.Internal ))
46+ if err != nil {
47+ return nil , err
48+ }
49+ if typ .TypCategory != pgtypes .TypeCategory_EnumTypes {
50+ return nil , fmt .Errorf (`"%s" is not an enum type` , typ .Name ())
51+ }
52+
53+ value := val1 .(string )
54+ if _ , exists := typ .EnumLabels [value ]; ! exists {
55+ return nil , pgtypes .ErrInvalidInputValueForEnum .New (typ .Name (), value )
56+ }
57+ // TODO: should return the index instead of label?
58+ return value , nil
4759 },
4860}
4961
@@ -54,7 +66,7 @@ var enum_out = framework.Function1{
5466 Parameters : [1 ]* pgtypes.DoltgresType {pgtypes .AnyEnum },
5567 Strict : true ,
5668 Callable : func (ctx * sql.Context , _ [2 ]* pgtypes.DoltgresType , val any ) (any , error ) {
57- // TODO: should return the index of label?
69+ // TODO: should receive the index instead of label?
5870 return val .(string ), nil
5971 },
6072}
@@ -66,15 +78,28 @@ var enum_recv = framework.Function2{
6678 Parameters : [2 ]* pgtypes.DoltgresType {pgtypes .Internal , pgtypes .Oid },
6779 Strict : true ,
6880 Callable : func (ctx * sql.Context , _ [3 ]* pgtypes.DoltgresType , val1 , val2 any ) (any , error ) {
69- // typOid := val2.(id.Internal)
70- // TODO: get type using given OID, which should give access to enum labels.
71- // should return the index of label?
81+ // TODO: should return the index instead of label?
7282 data := val1 .([]byte )
7383 if len (data ) == 0 {
7484 return nil , nil
7585 }
7686 reader := utils .NewReader (data )
77- return reader .String (), nil
87+ value := reader .String ()
88+ if ctx == nil {
89+ // TODO: currently, in some places we use nil context, should fix it.
90+ return value , nil
91+ }
92+ typ , err := getDoltgresTypeFromInternal (ctx , val2 .(id.Internal ))
93+ if err != nil {
94+ return nil , err
95+ }
96+ if typ .TypCategory != pgtypes .TypeCategory_EnumTypes {
97+ return nil , fmt .Errorf (`"%s" is not an enum type` , typ .Name ())
98+ }
99+ if _ , exists := typ .EnumLabels [value ]; ! exists {
100+ return nil , pgtypes .ErrInvalidInputValueForEnum .New (typ .Name (), value )
101+ }
102+ return value , nil
78103 },
79104}
80105
@@ -85,7 +110,7 @@ var enum_send = framework.Function1{
85110 Parameters : [1 ]* pgtypes.DoltgresType {pgtypes .AnyEnum },
86111 Strict : true ,
87112 Callable : func (ctx * sql.Context , _ [2 ]* pgtypes.DoltgresType , val any ) (any , error ) {
88- // TODO: should return the index of label?
113+ // TODO: should return the index instead of label?
89114 str := val .(string )
90115 writer := utils .NewWriter (uint64 (len (str ) + 4 ))
91116 writer .String (str )
@@ -123,3 +148,28 @@ var enum_cmp = framework.Function2{
123148 }
124149 },
125150}
151+
152+ // getDoltgresTypeFromInternal takes internal ID and returns DoltgresType associated to it.
153+ // It allows retrieving user-defined type and requires valid sql.Context.
154+ func getDoltgresTypeFromInternal (ctx * sql.Context , typID id.Internal ) (* pgtypes.DoltgresType , error ) {
155+ typCol , err := core .GetTypesCollectionFromContext (ctx )
156+ if err != nil {
157+ return nil , err
158+ }
159+
160+ schName := typID .Segment (0 )
161+ sch , err := core .GetCurrentSchema (ctx )
162+ if err != nil {
163+ return nil , err
164+ }
165+ if schName == "" {
166+ schName = sch
167+ }
168+
169+ typName := typID .Segment (1 )
170+ typ , found := typCol .GetType (schName , typName )
171+ if ! found {
172+ return nil , pgtypes .ErrTypeDoesNotExist .New (typName )
173+ }
174+ return typ , nil
175+ }
0 commit comments