Skip to content

Commit 4f44a79

Browse files
authored
add new cast implementation for complex types (#6278)
Casting of complex types is implemented in runtime/sam/expr/shaper.go. Replace it with a new implementation in runtime/sam/expr/function/cast.go with these differences:

* Always generates a structured error when casting fails
* Preserves target type field order when casting records
* Is able to cast maps
* Is more maintainable
1 parent b623ff4 commit 4f44a79

File tree

12 files changed

+310
-30
lines changed

12 files changed

+310
-30
lines changed

book/src/super-sql/functions/types/cast.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,8 @@ cast(this, <{b:string}>)
198198
{a:3}
199199
{b:4}
200200
# expected output
201-
{a:1,b:"2"}
202-
{a:3}
201+
{b:"2"}
202+
{b:null::string}
203203
{b:"4"}
204204
```
205205

compiler/rungen/vexpr.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,11 @@ func (b *Builder) compileVamCast(args []dag.Expr) (vamexpr.Evaluator, error) {
253253
return cast, nil
254254
}
255255
}
256-
return b.compileVamShaper(args, expr.Cast)
256+
e, err := b.compileCall(&dag.Call{Tag: "cast", Args: args})
257+
if err != nil {
258+
return nil, err
259+
}
260+
return vamexpr.NewSamExpr(e), nil
257261
}
258262

259263
func (b *Builder) compileVamShaper(args []dag.Expr, tf expr.ShaperTransform) (vamexpr.Evaluator, error) {

runtime/sam/expr/function/cast.go

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
package function
2+
3+
import (
4+
"maps"
5+
"slices"
6+
7+
"github.com/brimdata/super"
8+
"github.com/brimdata/super/runtime/sam/expr"
9+
"github.com/brimdata/super/scode"
10+
"github.com/brimdata/super/sup"
11+
)
12+
13+
// cast implements the "cast" function. It carries the type context that its
// methods use to look up result types and to build structured error values.
type cast struct {
14+
sctx *super.Context
15+
}
16+
17+
// Call casts args[0] to the target given by args[1] and returns the result.
//
// If args[0] is an error value it is returned unchanged. If the target's
// underlying value is a string, it is used as a type name: a named type is
// looked up over the underlying type of args[0] and args[0]'s bytes are
// returned with that type (a failed lookup yields a structured error wrapping
// args[0]). If the target is a type value, args[0] is cast to that type.
// Any other target yields a structured error wrapping the target.
func (c *cast) Call(args []super.Value) super.Value {
18+
from, to := args[0], args[1]
19+
if from.IsError() {
20+
return from
21+
}
22+
switch toUnder := to.Under(); toUnder.Type().ID() {
23+
case super.IDString:
24+
typ, err := c.sctx.LookupTypeNamed(toUnder.AsString(), super.TypeUnder(from.Type()))
25+
if err != nil {
26+
return c.sctx.WrapError("cannot cast to named type: "+err.Error(), from)
27+
}
28+
// Only the type changes; the value's bytes are reused as is.
return super.NewValue(typ, from.Bytes())
29+
case super.IDType:
30+
typ, err := c.sctx.LookupByValue(toUnder.Bytes())
31+
if err != nil {
32+
// NOTE(review): presumably a type value that cannot be looked up is an
// internal invariant violation rather than bad user input — confirm.
panic(err)
33+
}
34+
return c.cast(from, typ)
35+
}
36+
return c.sctx.WrapError("cast target must be a type or type name", to)
37+
}
38+
39+
// cast converts from to type to and returns the result.
//
// A null from yields a null of type to. A value whose type already is to is
// returned as is, and a value whose type ID equals to's ID is retyped without
// touching its bytes. Otherwise the conversion is dispatched on the concrete
// kind of to. Target types without a dedicated case go through the primitive
// caster for to; if there is none, a structured error is returned.
func (c *cast) cast(from super.Value, to super.Type) super.Value {
40+
if from.IsNull() {
41+
return super.NewValue(to, nil)
42+
}
43+
// Fast paths that need no per-element work.
switch fromType := from.Type(); {
44+
case fromType == to:
45+
return from
46+
case fromType.ID() == to.ID():
47+
return super.NewValue(to, from.Bytes())
48+
}
49+
switch to := to.(type) {
50+
case *super.TypeRecord:
51+
return c.toRecord(from, to)
52+
case *super.TypeArray, *super.TypeSet:
53+
return c.toArrayOrSet(from, to)
54+
case *super.TypeMap:
55+
return c.toMap(from, to)
56+
case *super.TypeUnion:
57+
return c.toUnion(from, to)
58+
case *super.TypeError:
59+
return c.toError(from, to)
60+
case *super.TypeNamed:
61+
return c.toNamed(from, to)
62+
default:
63+
from = from.Under()
64+
// Re-check for null now that from has been replaced by its underlying value.
if from.IsNull() {
65+
return super.NewValue(to, nil)
66+
}
67+
caster := expr.LookupPrimitiveCaster(c.sctx, to)
68+
if caster == nil {
69+
return c.error(from, to)
70+
}
71+
return caster.Eval(from)
72+
}
73+
}
74+
75+
// error returns a structured error value that wraps from with the message
// "cannot cast to " followed by the formatted target type to.
func (c *cast) error(from super.Value, to super.Type) super.Value {
76+
return c.sctx.WrapError("cannot cast to "+sup.FormatType(to), from)
77+
}
78+
79+
// toRecord casts from to the record type to.
//
// It returns a structured error if from's underlying value is not a record.
// Otherwise the result has exactly one field per field of to, in to's order:
// a field found in from by name is cast to the target field's type, and a
// field missing from from becomes a null of the target field's type. Fields
// of from that to does not name are not copied. If casting a field produces a
// value whose type differs from the requested field type, the result's record
// type is rebuilt with the actual field types.
func (c *cast) toRecord(from super.Value, to *super.TypeRecord) super.Value {
80+
from = from.Under()
81+
if !super.IsRecordType(from.Type()) {
82+
return c.error(from, to)
83+
}
84+
var b scode.Builder
85+
// fields stays nil unless some field's resulting type differs from the target.
var fields []super.Field
86+
for i, f := range to.Fields {
87+
var val2 super.Value
88+
if fieldVal := from.Deref(f.Name); fieldVal != nil {
89+
val2 = c.cast(*fieldVal, f.Type)
90+
} else {
91+
val2 = super.NewValue(f.Type, nil)
92+
}
93+
if t := val2.Type(); t != f.Type {
94+
if fields == nil {
95+
// Clone on first mismatch so that to.Fields itself is never modified.
fields = slices.Clone(to.Fields)
96+
}
97+
fields[i].Type = t
98+
}
99+
b.Append(val2.Bytes())
100+
}
101+
if fields != nil {
102+
to = c.sctx.MustLookupTypeRecord(fields)
103+
}
104+
return super.NewValue(to, b.Bytes())
105+
}
106+
107+
// toArrayOrSet casts from to the array or set type to.
//
// It returns a structured error if super.InnerType reports no inner type for
// from's underlying type. Otherwise each element of from is cast to to's inner
// type. An input with no elements is returned with type to and its bytes
// unchanged. If the cast elements do not all have to's inner type, the result
// type is an array or set (matching to's kind) over the single actual element
// type or over a union of the actual element types. Set results are passed
// through super.NormalizeSet.
func (c *cast) toArrayOrSet(from super.Value, to super.Type) super.Value {
108+
from = from.Under()
109+
fromInner := super.InnerType(from.Type())
110+
toInner := super.InnerType(to)
111+
if fromInner == nil {
112+
// XXX Should also return an error if casting from fromInner to
113+
// toInner will always fail.
114+
return c.error(from, to)
115+
}
116+
// types collects the distinct types of the cast elements.
types := map[super.Type]struct{}{}
117+
var vals []super.Value
118+
for it := from.Iter(); !it.Done(); {
119+
val := c.castNext(&it, fromInner, toInner)
120+
types[val.Type()] = struct{}{}
121+
vals = append(vals, val)
122+
}
123+
if len(vals) == 0 {
124+
return super.NewValue(to, from.Bytes())
125+
}
126+
inner := c.maybeConvertToUnion(vals, types)
127+
if inner != toInner {
128+
if to.Kind() == super.ArrayKind {
129+
to = c.sctx.LookupTypeArray(inner)
130+
} else {
131+
to = c.sctx.LookupTypeSet(inner)
132+
}
133+
}
134+
var bytes scode.Bytes
135+
for _, val := range vals {
136+
bytes = scode.Append(bytes, val.Bytes())
137+
}
138+
if to.Kind() == super.SetKind {
139+
bytes = super.NormalizeSet(bytes)
140+
}
141+
return super.NewValue(to, bytes)
142+
}
143+
144+
// castNext takes the next item from it, treats it as a value of type from,
// and returns the result of casting that value to type to.
func (c *cast) castNext(it *scode.Iter, from, to super.Type) super.Value {
145+
val := super.NewValue(from, it.Next())
146+
return c.cast(val, to)
147+
}
148+
149+
// maybeConvertToUnion returns the element type to use for vals, where types
// is the set of types occurring in vals.
//
// If types holds exactly one type, that type is returned and vals is left
// untouched. Otherwise a union of all types in the set is looked up, every
// element of vals is replaced in place by its conversion to that union, and
// the union type is returned. The callers in this file only invoke it with at
// least one value.
func (c *cast) maybeConvertToUnion(vals []super.Value, types map[super.Type]struct{}) super.Type {
150+
typesSlice := slices.Collect(maps.Keys(types))
151+
if len(typesSlice) == 1 {
152+
return typesSlice[0]
153+
}
154+
union := c.sctx.LookupTypeUnion(typesSlice)
155+
for i, val := range vals {
156+
vals[i] = c.toUnion(val, union)
157+
}
158+
return union
159+
}
160+
161+
// toMap casts from to the map type to.
//
// It returns a structured error if from's underlying type is not a map.
// Otherwise the entries of from are read as alternating key and value items;
// each key is cast to to's key type and each value to to's value type. An
// input with no entries is returned with type to and its bytes unchanged. If
// the cast keys or values do not all have the requested type, the result's
// map type is rebuilt from the single actual type, or a union of the actual
// types, for keys and values respectively. The encoded entries are passed
// through super.NormalizeMap before being returned.
func (c *cast) toMap(from super.Value, to *super.TypeMap) super.Value {
162+
from = from.Under()
163+
fromType, ok := from.Type().(*super.TypeMap)
164+
if !ok {
165+
return c.error(from, to)
166+
}
167+
// keyTypes and valTypes collect the distinct types of the cast keys and values.
keyTypes := map[super.Type]struct{}{}
168+
valTypes := map[super.Type]struct{}{}
169+
var keyVals, valVals []super.Value
170+
for it := from.Iter(); !it.Done(); {
171+
// Each loop iteration consumes two items from the iterator: a key, then its value.
keyVal := c.castNext(&it, fromType.KeyType, to.KeyType)
172+
keyVals = append(keyVals, keyVal)
173+
keyTypes[keyVal.Type()] = struct{}{}
174+
valVal := c.castNext(&it, fromType.ValType, to.ValType)
175+
valTypes[valVal.Type()] = struct{}{}
176+
valVals = append(valVals, valVal)
177+
}
178+
if len(keyVals) == 0 {
179+
return super.NewValue(to, from.Bytes())
180+
}
181+
keyType := c.maybeConvertToUnion(keyVals, keyTypes)
182+
valType := c.maybeConvertToUnion(valVals, valTypes)
183+
if keyType != to.KeyType || valType != to.ValType {
184+
to = c.sctx.LookupTypeMap(keyType, valType)
185+
}
186+
var bytes scode.Bytes
187+
for i := range keyVals {
188+
bytes = scode.Append(bytes, keyVals[i].Bytes())
189+
bytes = scode.Append(bytes, valVals[i].Bytes())
190+
}
191+
return super.NewValue(to, super.NormalizeMap(bytes))
192+
}
193+
194+
// toUnion casts from to the union type to.
//
// The tag is chosen with expr.BestUnionTag for from's type. If no tag is
// found, the lookup is retried with from.Deunion(); if that also fails, a
// structured error wrapping the original from is returned. The result has
// type to and, when the value's bytes are non-nil, consists of the encoded
// tag followed by those bytes.
func (c *cast) toUnion(from super.Value, to *super.TypeUnion) super.Value {
195+
tag := expr.BestUnionTag(from.Type(), to)
196+
if tag < 0 {
197+
from2 := from.Deunion()
198+
tag = expr.BestUnionTag(from2.Type(), to)
199+
if tag < 0 {
200+
return c.error(from, to)
201+
}
202+
from = from2
203+
}
204+
bytes := from.Bytes()
205+
// Nil bytes are passed through untagged — presumably so that a null input
// stays null in the union type; confirm against super.Value's null encoding.
if bytes != nil {
206+
bytes = scode.Append(scode.Append(nil, super.EncodeInt(int64(tag))), bytes)
207+
}
208+
return super.NewValue(to, bytes)
209+
}
210+
211+
// toError casts from to the error type to. It first casts from to to.Type,
// the type that to wraps. If that cast yields a value of a different type
// than to.Type, that value is returned unchanged; otherwise the cast value's
// bytes are returned with type to.
func (c *cast) toError(from super.Value, to *super.TypeError) super.Value {
212+
from = c.cast(from, to.Type)
213+
if from.Type() != to.Type {
214+
return from
215+
}
216+
return super.NewValue(to, from.Bytes())
217+
}
218+
219+
// toNamed casts from to the named type to. It first casts from to to.Type,
// the type that to names. If that cast yields a value of a different type
// than to.Type, that value is returned unchanged; otherwise the cast value's
// bytes are returned with type to.
func (c *cast) toNamed(from super.Value, to *super.TypeNamed) super.Value {
220+
from = c.cast(from, to.Type)
221+
if from.Type() != to.Type {
222+
return from
223+
}
224+
return super.NewValue(to, from.Bytes())
225+
}

runtime/sam/expr/function/function.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ func New(sctx *super.Context, name string, narg int) (expr.Function, error) {
2828
argmin = 2
2929
argmax = 2
3030
f = &Bucket{sctx: sctx, name: name}
31+
case "cast":
32+
argmin = 2
33+
argmax = 2
34+
f = &cast{sctx}
3135
case "ceil":
3236
f = &Ceil{sctx: sctx}
3337
case "cidr_match":

runtime/sam/expr/functions_test.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package expr_test
22

33
import (
4-
"errors"
54
"testing"
65

76
"github.com/brimdata/super/runtime/sam/expr/function"
@@ -139,20 +138,22 @@ func TestLen(t *testing.T) {
139138
func TestCast(t *testing.T) {
140139
// Constant type argument
141140
testSuccessful(t, "cast(1, <uint64>)", "", "1::uint64")
142-
testError(t, "cast(1, 2)", errors.New("shaper type argument is not a type: 2"))
141+
testSuccessful(t, "cast(1, 2)", "", `error({message:"cast target must be a type or type name",on:2})`)
143142

144143
// Constant name argument
145144
testSuccessful(t, `cast(1, "my_int64")`, "", "1::=my_int64")
146-
testError(t, `cast(1, "uint64")`, errors.New(`bad type name "uint64": primitive type name`))
145+
testSuccessful(t, `cast(1, "uint64")`, "",
146+
`error({message:"cannot cast to named type: bad type name \"uint64\": primitive type name",on:1})`)
147147

148148
// Variable type argument
149149
testSuccessful(t, "cast(1, type)", "{type:<uint64>}", "1::uint64")
150-
testSuccessful(t, "cast(1, type)", "{type:2}", `error({message:"shaper type argument is not a type",on:2})`)
150+
testSuccessful(t, "cast(1, type)", "{type:2}",
151+
`error({message:"cast target must be a type or type name",on:2})`)
151152

152153
// Variable name argument
153154
testSuccessful(t, "cast(1, name)", `{name:"my_int64"}`, "1::=my_int64")
154-
testSuccessful(t, "cast(1, name)", `{name:"uint64"}`, `error("bad type name \"uint64\": primitive type name")`)
155-
155+
testSuccessful(t, "cast(1, name)", `{name:"uint64"}`,
156+
`error({message:"cannot cast to named type: bad type name \"uint64\": primitive type name",on:1})`)
156157
testCompilationError(t, "cast()", function.ErrTooFewArgs)
157158
testCompilationError(t, "cast(1, 2, 3)", function.ErrTooManyArgs)
158159
}

runtime/sam/expr/shaper.go

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ const (
2727

2828
func NewShaperTransform(s string) ShaperTransform {
2929
switch s {
30-
case "cast":
31-
return Cast
3230
case "crop":
3331
return Crop
3432
case "fill":
@@ -209,7 +207,7 @@ func shaperType(sctx *super.Context, tf ShaperTransform, in, out super.Type) (su
209207
}
210208
return out, nil
211209
}
212-
if bestUnionTag(in, outUnder) > -1 {
210+
if BestUnionTag(in, outUnder) > -1 {
213211
return out, nil
214212
}
215213
} else if inUnder == outUnder {
@@ -293,14 +291,14 @@ func shaperFields(sctx *super.Context, tf ShaperTransform, in, out *super.TypeRe
293291
return fields, nil
294292
}
295293

296-
// bestUnionTag tries to return the most specific union tag for in
294+
// BestUnionTag tries to return the most specific union tag for in
297295
// within out. It returns -1 if out is not a union or contains no type
298296
// compatible with in. (Types are compatible if they have the same underlying
299-
// type.) If out contains in, bestUnionTag returns its tag.
300-
// Otherwise, if out contains in's underlying type, bestUnionTag returns
301-
// its tag. Finally, bestUnionTag returns the smallest tag in
297+
// type.) If out contains in, BestUnionTag returns its tag.
298+
// Otherwise, if out contains in's underlying type, BestUnionTag returns
299+
// its tag. Finally, BestUnionTag returns the smallest tag in
302300
// out whose type is compatible with in.
303-
func bestUnionTag(in, out super.Type) int {
301+
func BestUnionTag(in, out super.Type) int {
304302
outUnion, ok := super.TypeUnder(out).(*super.TypeUnion)
305303
if !ok {
306304
return -1
@@ -395,7 +393,7 @@ Switch:
395393
}
396394
return step{op: castFromUnion, toType: out, children: steps}, nil
397395
}
398-
if tag := bestUnionTag(in, out); tag != -1 {
396+
if tag := BestUnionTag(in, out); tag != -1 {
399397
return step{op: castToUnion, fromType: in, toTag: tag, toType: out}, nil
400398
}
401399
return step{}, fmt.Errorf("createStep: incompatible types %s and %s", sup.FormatType(in), sup.FormatType(out))

runtime/sam/expr/shaper_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,14 @@ func TestBestUnionTag(t *testing.T) {
1818
u8named3, err := sctx.LookupTypeNamed("u8named3", u8)
1919
require.NoError(t, err)
2020

21-
assert.Equal(t, -1, bestUnionTag(u8, nil))
22-
assert.Equal(t, -1, bestUnionTag(u8, u8))
23-
assert.Equal(t, -1, bestUnionTag(super.TypeUint16, sctx.LookupTypeUnion([]super.Type{u8})))
21+
assert.Equal(t, -1, BestUnionTag(u8, nil))
22+
assert.Equal(t, -1, BestUnionTag(u8, u8))
23+
assert.Equal(t, -1, BestUnionTag(super.TypeUint16, sctx.LookupTypeUnion([]super.Type{u8})))
2424

2525
test := func(expected, needle super.Type, haystack []super.Type) {
2626
t.Helper()
2727
union := sctx.LookupTypeUnion(haystack)
28-
typ, err := union.Type(bestUnionTag(needle, union))
28+
typ, err := union.Type(BestUnionTag(needle, union))
2929
if assert.NoError(t, err) {
3030
assert.Equal(t, expected, typ)
3131
}

runtime/ztests/expr/cast/map.yaml

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
spq: |
2+
values this::|{int64:|{string:int64}|}|
3+
4+
vector: true
5+
6+
input: |
7+
|{}|
8+
|{1:|{"2":3,4:"5"}|,"7":8,"9":|{10:[],[]:11,null:null}|}|
9+
null
10+
1
11+
12+
output: |
13+
|{}|
14+
|{1:|{"2":3,"4":5}|,7:error({message:"cannot cast to |{string:int64}|",on:8}),9:|{null::string:null,"10":error({message:"cannot cast to int64",on:[]::[null]}),"[]":11}|}|
15+
null::|{int64:|{string:int64}|}|
16+
error({message:"cannot cast to |{int64:|{string:int64}|}|",on:1})

0 commit comments

Comments
 (0)