Skip to content

Commit c662329

Browse files
committed
support enum type cast
1 parent 9d098e6 commit c662329

File tree

3 files changed

+64
-15
lines changed

3 files changed

+64
-15
lines changed

server/functions/enum.go

Lines changed: 60 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ import (
1919

2020
"github.com/dolthub/go-mysql-server/sql"
2121

22+
"github.com/dolthub/doltgresql/core"
23+
"github.com/dolthub/doltgresql/core/id"
2224
"github.com/dolthub/doltgresql/server/functions/framework"
2325
pgtypes "github.com/dolthub/doltgresql/server/types"
2426
"github.com/dolthub/doltgresql/utils"
@@ -40,10 +42,20 @@ var enum_in = framework.Function2{
4042
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Cstring, pgtypes.Oid},
4143
Strict: true,
4244
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) {
43-
// typOid := val2.(id.Internal)
44-
// TODO: get type using given OID, which should give access to enum labels.
45-
// should return the index of label?
46-
return val1.(string), nil
45+
typ, err := getDoltgresTypeFromInternal(ctx, val2.(id.Internal))
46+
if err != nil {
47+
return nil, err
48+
}
49+
if typ.TypCategory != pgtypes.TypeCategory_EnumTypes {
50+
return nil, fmt.Errorf(`"%s" is not an enum type`, typ.Name())
51+
}
52+
53+
value := val1.(string)
54+
if _, exists := typ.EnumLabels[value]; !exists {
55+
return nil, pgtypes.ErrInvalidInputValueForEnum.New(typ.Name(), value)
56+
}
57+
// TODO: should return the index instead of label?
58+
return value, nil
4759
},
4860
}
4961

@@ -54,7 +66,7 @@ var enum_out = framework.Function1{
5466
Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyEnum},
5567
Strict: true,
5668
Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) {
57-
// TODO: should return the index of label?
69+
// TODO: should receive the index instead of label?
5870
return val.(string), nil
5971
},
6072
}
@@ -66,15 +78,28 @@ var enum_recv = framework.Function2{
6678
Parameters: [2]*pgtypes.DoltgresType{pgtypes.Internal, pgtypes.Oid},
6779
Strict: true,
6880
Callable: func(ctx *sql.Context, _ [3]*pgtypes.DoltgresType, val1, val2 any) (any, error) {
69-
// typOid := val2.(id.Internal)
70-
// TODO: get type using given OID, which should give access to enum labels.
71-
// should return the index of label?
81+
// TODO: should return the index instead of label?
7282
data := val1.([]byte)
7383
if len(data) == 0 {
7484
return nil, nil
7585
}
7686
reader := utils.NewReader(data)
77-
return reader.String(), nil
87+
value := reader.String()
88+
if ctx == nil {
89+
// TODO: currently, in some places we use nil context, should fix it.
90+
return value, nil
91+
}
92+
typ, err := getDoltgresTypeFromInternal(ctx, val2.(id.Internal))
93+
if err != nil {
94+
return nil, err
95+
}
96+
if typ.TypCategory != pgtypes.TypeCategory_EnumTypes {
97+
return nil, fmt.Errorf(`"%s" is not an enum type`, typ.Name())
98+
}
99+
if _, exists := typ.EnumLabels[value]; !exists {
100+
return nil, pgtypes.ErrInvalidInputValueForEnum.New(typ.Name(), value)
101+
}
102+
return value, nil
78103
},
79104
}
80105

@@ -85,7 +110,7 @@ var enum_send = framework.Function1{
85110
Parameters: [1]*pgtypes.DoltgresType{pgtypes.AnyEnum},
86111
Strict: true,
87112
Callable: func(ctx *sql.Context, _ [2]*pgtypes.DoltgresType, val any) (any, error) {
88-
// TODO: should return the index of label?
113+
// TODO: should return the index instead of label?
89114
str := val.(string)
90115
writer := utils.NewWriter(uint64(len(str) + 4))
91116
writer.String(str)
@@ -123,3 +148,28 @@ var enum_cmp = framework.Function2{
123148
}
124149
},
125150
}
151+
152+
// getDoltgresTypeFromInternal takes internal ID and returns DoltgresType associated to it.
153+
// It allows retrieving user-defined type and requires valid sql.Context.
154+
func getDoltgresTypeFromInternal(ctx *sql.Context, typID id.Internal) (*pgtypes.DoltgresType, error) {
155+
typCol, err := core.GetTypesCollectionFromContext(ctx)
156+
if err != nil {
157+
return nil, err
158+
}
159+
160+
schName := typID.Segment(0)
161+
sch, err := core.GetCurrentSchema(ctx)
162+
if err != nil {
163+
return nil, err
164+
}
165+
if schName == "" {
166+
schName = sch
167+
}
168+
169+
typName := typID.Segment(1)
170+
typ, found := typCol.GetType(schName, typName)
171+
if !found {
172+
return nil, pgtypes.ErrTypeDoesNotExist.New(typName)
173+
}
174+
return typ, nil
175+
}

testing/go/functions_test.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,10 +2069,6 @@ func TestSelectFromFunctions(t *testing.T) {
20692069
Query: `SELECT * FROM array_to_string(ARRAY[37.89, 1.2], '_');`,
20702070
Expected: []sql.Row{{"37.89_1.2"}},
20712071
},
2072-
{
2073-
Query: `SELECT format_type(874938247, 20);`,
2074-
Expected: []sql.Row{{"???"}},
2075-
},
20762072
{
20772073
Query: `SELECT * from format_type(874938247, 20);`,
20782074
Expected: []sql.Row{{"???"}},

testing/go/types_test.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3292,6 +3292,10 @@ var enumTypeTests = []ScriptTest{
32923292
Query: `SELECT * FROM person WHERE current_mood > 'sad' ORDER BY current_mood;`,
32933293
Expected: []sql.Row{{"Curly", "ok"}, {"Moe", "happy"}},
32943294
},
3295+
{
3296+
Query: `INSERT INTO person VALUES ('Joey', 'invalid');`,
3297+
ExpectedErr: `invalid input value for enum mood: "invalid"`,
3298+
},
32953299
},
32963300
},
32973301
{
@@ -3324,7 +3328,6 @@ var enumTypeTests = []ScriptTest{
33243328
},
33253329
},
33263330
{
3327-
Skip: true,
33283331
Name: "enum type cast",
33293332
SetUpScript: []string{
33303333
`CREATE TYPE mood AS ENUM ('sad', 'ok', 'happy')`,

0 commit comments

Comments
 (0)