|
| 1 | +package function |
| 2 | + |
| 3 | +import ( |
| 4 | + "maps" |
| 5 | + "slices" |
| 6 | + |
| 7 | + "github.com/brimdata/super" |
| 8 | + "github.com/brimdata/super/runtime/sam/expr" |
| 9 | + "github.com/brimdata/super/scode" |
| 10 | + "github.com/brimdata/super/sup" |
| 11 | +) |
| 12 | + |
| 13 | +type cast struct { |
| 14 | + sctx *super.Context |
| 15 | +} |
| 16 | + |
| 17 | +func (c *cast) Call(args []super.Value) super.Value { |
| 18 | + from, to := args[0], args[1] |
| 19 | + if from.IsError() { |
| 20 | + return from |
| 21 | + } |
| 22 | + switch toUnder := to.Under(); toUnder.Type().ID() { |
| 23 | + case super.IDString: |
| 24 | + typ, err := c.sctx.LookupTypeNamed(toUnder.AsString(), super.TypeUnder(from.Type())) |
| 25 | + if err != nil { |
| 26 | + return c.sctx.WrapError("cannot cast to named type: "+err.Error(), from) |
| 27 | + } |
| 28 | + return super.NewValue(typ, from.Bytes()) |
| 29 | + case super.IDType: |
| 30 | + typ, err := c.sctx.LookupByValue(toUnder.Bytes()) |
| 31 | + if err != nil { |
| 32 | + panic(err) |
| 33 | + } |
| 34 | + return c.cast(from, typ) |
| 35 | + } |
| 36 | + return c.sctx.WrapError("cast target must be a type or type name", to) |
| 37 | +} |
| 38 | + |
| 39 | +func (c *cast) cast(from super.Value, to super.Type) super.Value { |
| 40 | + if from.IsNull() { |
| 41 | + return super.NewValue(to, nil) |
| 42 | + } |
| 43 | + switch fromType := from.Type(); { |
| 44 | + case fromType == to: |
| 45 | + return from |
| 46 | + case fromType.ID() == to.ID(): |
| 47 | + return super.NewValue(to, from.Bytes()) |
| 48 | + } |
| 49 | + switch to := to.(type) { |
| 50 | + case *super.TypeRecord: |
| 51 | + return c.toRecord(from, to) |
| 52 | + case *super.TypeArray, *super.TypeSet: |
| 53 | + return c.toArrayOrSet(from, to) |
| 54 | + case *super.TypeMap: |
| 55 | + return c.toMap(from, to) |
| 56 | + case *super.TypeUnion: |
| 57 | + return c.toUnion(from, to) |
| 58 | + case *super.TypeError: |
| 59 | + return c.toError(from, to) |
| 60 | + case *super.TypeNamed: |
| 61 | + return c.toNamed(from, to) |
| 62 | + default: |
| 63 | + from = from.Under() |
| 64 | + if from.IsNull() { |
| 65 | + return super.NewValue(to, nil) |
| 66 | + } |
| 67 | + caster := expr.LookupPrimitiveCaster(c.sctx, to) |
| 68 | + if caster == nil { |
| 69 | + return c.error(from, to) |
| 70 | + } |
| 71 | + return caster.Eval(from) |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +func (c *cast) error(from super.Value, to super.Type) super.Value { |
| 76 | + return c.sctx.WrapError("cannot cast to "+sup.FormatType(to), from) |
| 77 | +} |
| 78 | + |
| 79 | +func (c *cast) toRecord(from super.Value, to *super.TypeRecord) super.Value { |
| 80 | + from = from.Under() |
| 81 | + if !super.IsRecordType(from.Type()) { |
| 82 | + return c.error(from, to) |
| 83 | + } |
| 84 | + var b scode.Builder |
| 85 | + var fields []super.Field |
| 86 | + for i, f := range to.Fields { |
| 87 | + var val2 super.Value |
| 88 | + if fieldVal := from.Deref(f.Name); fieldVal != nil { |
| 89 | + val2 = c.cast(*fieldVal, f.Type) |
| 90 | + } else { |
| 91 | + val2 = super.NewValue(f.Type, nil) |
| 92 | + } |
| 93 | + if t := val2.Type(); t != f.Type { |
| 94 | + if fields == nil { |
| 95 | + fields = slices.Clone(to.Fields) |
| 96 | + } |
| 97 | + fields[i].Type = t |
| 98 | + } |
| 99 | + b.Append(val2.Bytes()) |
| 100 | + } |
| 101 | + if fields != nil { |
| 102 | + to = c.sctx.MustLookupTypeRecord(fields) |
| 103 | + } |
| 104 | + return super.NewValue(to, b.Bytes()) |
| 105 | +} |
| 106 | + |
| 107 | +func (c *cast) toArrayOrSet(from super.Value, to super.Type) super.Value { |
| 108 | + from = from.Under() |
| 109 | + fromInner := super.InnerType(from.Type()) |
| 110 | + toInner := super.InnerType(to) |
| 111 | + if fromInner == nil { |
| 112 | + // XXX Should also return an error if casting from fromInner to |
| 113 | + // toInner will always fail. |
| 114 | + return c.error(from, to) |
| 115 | + } |
| 116 | + types := map[super.Type]struct{}{} |
| 117 | + var vals []super.Value |
| 118 | + for it := from.Iter(); !it.Done(); { |
| 119 | + val := c.castNext(&it, fromInner, toInner) |
| 120 | + types[val.Type()] = struct{}{} |
| 121 | + vals = append(vals, val) |
| 122 | + } |
| 123 | + if len(vals) == 0 { |
| 124 | + return super.NewValue(to, from.Bytes()) |
| 125 | + } |
| 126 | + inner := c.maybeConvertToUnion(vals, types) |
| 127 | + if inner != toInner { |
| 128 | + if to.Kind() == super.ArrayKind { |
| 129 | + to = c.sctx.LookupTypeArray(inner) |
| 130 | + } else { |
| 131 | + to = c.sctx.LookupTypeSet(inner) |
| 132 | + } |
| 133 | + } |
| 134 | + var bytes scode.Bytes |
| 135 | + for _, val := range vals { |
| 136 | + bytes = scode.Append(bytes, val.Bytes()) |
| 137 | + } |
| 138 | + if to.Kind() == super.SetKind { |
| 139 | + bytes = super.NormalizeSet(bytes) |
| 140 | + } |
| 141 | + return super.NewValue(to, bytes) |
| 142 | +} |
| 143 | + |
| 144 | +func (c *cast) castNext(it *scode.Iter, from, to super.Type) super.Value { |
| 145 | + val := super.NewValue(from, it.Next()) |
| 146 | + return c.cast(val, to) |
| 147 | +} |
| 148 | + |
| 149 | +func (c *cast) maybeConvertToUnion(vals []super.Value, types map[super.Type]struct{}) super.Type { |
| 150 | + typesSlice := slices.Collect(maps.Keys(types)) |
| 151 | + if len(typesSlice) == 1 { |
| 152 | + return typesSlice[0] |
| 153 | + } |
| 154 | + union := c.sctx.LookupTypeUnion(typesSlice) |
| 155 | + for i, val := range vals { |
| 156 | + vals[i] = c.toUnion(val, union) |
| 157 | + } |
| 158 | + return union |
| 159 | +} |
| 160 | + |
| 161 | +func (c *cast) toMap(from super.Value, to *super.TypeMap) super.Value { |
| 162 | + from = from.Under() |
| 163 | + fromType, ok := from.Type().(*super.TypeMap) |
| 164 | + if !ok { |
| 165 | + return c.error(from, to) |
| 166 | + } |
| 167 | + keyTypes := map[super.Type]struct{}{} |
| 168 | + valTypes := map[super.Type]struct{}{} |
| 169 | + var keyVals, valVals []super.Value |
| 170 | + for it := from.Iter(); !it.Done(); { |
| 171 | + keyVal := c.castNext(&it, fromType.KeyType, to.KeyType) |
| 172 | + keyVals = append(keyVals, keyVal) |
| 173 | + keyTypes[keyVal.Type()] = struct{}{} |
| 174 | + valVal := c.castNext(&it, fromType.ValType, to.ValType) |
| 175 | + valTypes[valVal.Type()] = struct{}{} |
| 176 | + valVals = append(valVals, valVal) |
| 177 | + } |
| 178 | + if len(keyVals) == 0 { |
| 179 | + return super.NewValue(to, from.Bytes()) |
| 180 | + } |
| 181 | + keyType := c.maybeConvertToUnion(keyVals, keyTypes) |
| 182 | + valType := c.maybeConvertToUnion(valVals, valTypes) |
| 183 | + if keyType != to.KeyType || valType != to.ValType { |
| 184 | + to = c.sctx.LookupTypeMap(keyType, valType) |
| 185 | + } |
| 186 | + var bytes scode.Bytes |
| 187 | + for i := range keyVals { |
| 188 | + bytes = scode.Append(bytes, keyVals[i].Bytes()) |
| 189 | + bytes = scode.Append(bytes, valVals[i].Bytes()) |
| 190 | + } |
| 191 | + return super.NewValue(to, super.NormalizeMap(bytes)) |
| 192 | +} |
| 193 | + |
| 194 | +func (c *cast) toUnion(from super.Value, to *super.TypeUnion) super.Value { |
| 195 | + tag := expr.BestUnionTag(from.Type(), to) |
| 196 | + if tag < 0 { |
| 197 | + from2 := from.Deunion() |
| 198 | + tag = expr.BestUnionTag(from2.Type(), to) |
| 199 | + if tag < 0 { |
| 200 | + return c.error(from, to) |
| 201 | + } |
| 202 | + from = from2 |
| 203 | + } |
| 204 | + bytes := from.Bytes() |
| 205 | + if bytes != nil { |
| 206 | + bytes = scode.Append(scode.Append(nil, super.EncodeInt(int64(tag))), bytes) |
| 207 | + } |
| 208 | + return super.NewValue(to, bytes) |
| 209 | +} |
| 210 | + |
| 211 | +func (c *cast) toError(from super.Value, to *super.TypeError) super.Value { |
| 212 | + from = c.cast(from, to.Type) |
| 213 | + if from.Type() != to.Type { |
| 214 | + return from |
| 215 | + } |
| 216 | + return super.NewValue(to, from.Bytes()) |
| 217 | +} |
| 218 | + |
| 219 | +func (c *cast) toNamed(from super.Value, to *super.TypeNamed) super.Value { |
| 220 | + from = c.cast(from, to.Type) |
| 221 | + if from.Type() != to.Type { |
| 222 | + return from |
| 223 | + } |
| 224 | + return super.NewValue(to, from.Bytes()) |
| 225 | +} |
0 commit comments