diff --git a/runtime/sam/expr/coerce/cast.go b/runtime/sam/expr/coerce/cast.go index e0afa9ad82..b12241049f 100644 --- a/runtime/sam/expr/coerce/cast.go +++ b/runtime/sam/expr/coerce/cast.go @@ -23,6 +23,12 @@ func ToUint(val super.Value, typUint super.Type) (uint64, bool) { v := val.Float() min, max, check := FromFloatOverflowCheck(val.Type(), typUint) return uint64(v), !check || v >= min && v <= max + case id == super.IDBool: + var v uint64 + if val.Bool() { + v = 1 + } + return v, true case id == super.IDString: v, err := strconv.ParseUint(val.AsString(), 10, UintBits(typUint)) return v, err == nil @@ -45,6 +51,12 @@ func ToInt(val super.Value, typInt super.Type) (int64, bool) { v := val.Float() min, max, check := FromFloatOverflowCheck(val.Type(), typInt) return int64(v), !check || v >= min && v <= max + case id == super.IDBool: + var v int64 + if val.Bool() { + v = 1 + } + return v, true case id == super.IDString: v, err := strconv.ParseInt(val.AsString(), 10, IntBits(typInt)) return v, err == nil @@ -63,6 +75,10 @@ func ToFloat(val super.Value, typ super.Type) (float64, bool) { v = float64(val.Int()) case super.IsFloat(fromId): v = val.Float() + case fromId == super.IDBool: + if val.Bool() { + v = 1 + } case fromId == super.IDString: var err error if v, err = byteconv.ParseFloat64(val.Bytes()); err != nil { diff --git a/runtime/vam/expr/cast/number.go b/runtime/vam/expr/cast/number.go index f269668d69..1fd0113011 100644 --- a/runtime/vam/expr/cast/number.go +++ b/runtime/vam/expr/cast/number.go @@ -79,6 +79,8 @@ func toNumeric[T numeric](vec vector.Any, typ super.Type, index []uint32) ([]T, return checkAndCastNumbers[float64, T](vec.Values, min, max, index) } return castNumbers[float64, T](vec.Values, index), nil + case *vector.Bool: + return boolToNumeric[T](vec, index), nil default: panic(vec) } @@ -123,6 +125,21 @@ func castNumbers[E numeric, T numeric](s []E, index []uint32) []T { return out } +func boolToNumeric[T numeric](vec *vector.Bool, index []uint32) []T { + n := lengthOf(vec, index) + out := make([]T, n) + for i := range n { + idx := i + if index != nil { + idx = index[i] + } + if vec.Bits.IsSet(idx) { + out[i] = 1 + } + } + return out +} + func castStringToNumber(vec vector.Any, typ super.Type, index []uint32) (vector.Any, []uint32) { svec := vec.(*vector.String) switch id := typ.ID(); { diff --git a/runtime/ztests/expr/cast/float.yaml b/runtime/ztests/expr/cast/float.yaml index 88a6939db6..fd708ecf41 100644 --- a/runtime/ztests/expr/cast/float.yaml +++ b/runtime/ztests/expr/cast/float.yaml @@ -9,6 +9,8 @@ input: | 1::uint64 1000000000::uint64 2::=named + false + true output: | 1.5::float16 @@ -29,3 +31,9 @@ output: | 2.::float16 2.::float32 2. + 0.::float16 + 0.::float32 + 0. + 1.::float16 + 1.::float32 + 1. diff --git a/runtime/ztests/expr/cast/int.yaml b/runtime/ztests/expr/cast/int.yaml index 06acbbbb5e..3356f49d20 100644 --- a/runtime/ztests/expr/cast/int.yaml +++ b/runtime/ztests/expr/cast/int.yaml @@ -12,7 +12,8 @@ input: | -1. 1.5::=named 1e8 - + false + true output: | -1::int8 @@ -47,3 +48,11 @@ output: | error({message:"cannot cast to int16",on:100000000.}) 100000000::int32 100000000 + 0::int8 + 0::int16 + 0::int32 + 0 + 1::int8 + 1::int16 + 1::int32 + 1 diff --git a/runtime/ztests/expr/cast/uint.yaml b/runtime/ztests/expr/cast/uint.yaml new file mode 100644 index 0000000000..de51f29590 --- /dev/null +++ b/runtime/ztests/expr/cast/uint.yaml @@ -0,0 +1,18 @@ +spq: | + values this::uint8, this::uint16, this::uint32, this::uint64 + +vector: true + +input: | + false + true + +output: | + 0::uint8 + 0::uint16 + 0::uint32 + 0::uint64 + 1::uint8 + 1::uint16 + 1::uint32 + 1::uint64