Skip to content

Commit bfff763

Browse files
committed
Add BIT support
This adds support for reading/writing the BIT type. Since there is no native type for representing bitstrings in Go, this adds a custom Go type that represents one. It uses the same internal storage as DuckDB so that conversions are cheap and require minimal copying. P.S. Strictly speaking, there is the BitString type in the encoding/asn1 package, but that's not meant for general-purpose use. It's really only there so that X509 certificates can be decoded. It's also laid out very differently from DuckDB's BIT type (e.g. it uses 0's for padding instead of 1's).
1 parent 1a6cd82 commit bfff763

File tree

15 files changed

+302
-22
lines changed

15 files changed

+302
-22
lines changed

appender_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -793,6 +793,57 @@ func TestAppenderTimeTZ(t *testing.T) {
793793
require.Equal(t, expected, r)
794794
}
795795

796+
func TestAppenderBit(t *testing.T) {
797+
c, db, conn, a := prepareAppender(t, `CREATE TABLE test (data BIT)`)
798+
defer cleanupAppender(t, c, db, conn, a)
799+
800+
expected := []string{
801+
"10101",
802+
"11110000",
803+
"1",
804+
"111100001100110011",
805+
"00000000000000000000000010110101",
806+
}
807+
for _, bits := range expected {
808+
bitVal, err := NewBitFromString(bits)
809+
require.NoError(t, err)
810+
require.NoError(t, a.AppendRow(*bitVal))
811+
}
812+
813+
require.NoError(t, a.Flush())
814+
815+
// Verify results.
816+
res, err := db.QueryContext(context.Background(), `SELECT data FROM test`)
817+
require.NoError(t, err)
818+
defer closeRowsWrapper(t, res)
819+
820+
i := 0
821+
for res.Next() {
822+
var b Bit
823+
require.NoError(t, res.Scan(&b))
824+
require.Equal(t, expected[i], b.String())
825+
i++
826+
}
827+
require.Equal(t, len(expected), i)
828+
}
829+
830+
func TestAppenderNullBit(t *testing.T) {
831+
c, db, conn, a := prepareAppender(t, `CREATE TABLE test (b BIT)`)
832+
defer cleanupAppender(t, c, db, conn, a)
833+
834+
// Append a nil *Bit.
835+
var nilBit *Bit
836+
require.NoError(t, a.AppendRow(nilBit))
837+
require.NoError(t, a.Flush())
838+
839+
// Verify results.
840+
res := db.QueryRowContext(context.Background(), `SELECT b FROM test`)
841+
842+
var r *Bit
843+
require.NoError(t, res.Scan(&r))
844+
require.Nil(t, r)
845+
}
846+
796847
func TestAppenderBlob(t *testing.T) {
797848
c, db, conn, a := prepareAppender(t, `CREATE TABLE test (data BLOB)`)
798849
defer cleanupAppender(t, c, db, conn, a)

connection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func newConn(conn mapping.Connection, ctxStore *contextStore) *Conn {
3737
// CheckNamedValue implements the driver.NamedValueChecker interface.
3838
func (conn *Conn) CheckNamedValue(nv *driver.NamedValue) error {
3939
switch nv.Value.(type) {
40-
case *big.Int, Interval, []any, []bool, []int8, []int16, []int32, []int64, []int, []uint8, []uint16,
40+
case *big.Int, Interval, Bit, []any, []bool, []int8, []int16, []int32, []int64, []int, []uint8, []uint16,
4141
[]uint32, []uint64, []uint, []float32, []float64, []string, map[string]any:
4242
return nil
4343
}

errors_test.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -150,23 +150,6 @@ func TestErrAppender(t *testing.T) {
150150
testError(t, err, errAppenderDoubleClose.Error())
151151
})
152152

153-
t.Run(unsupportedTypeErrMsg, func(t *testing.T) {
154-
c := newConnectorWrapper(t, ``, nil)
155-
defer closeConnectorWrapper(t, c)
156-
157-
db := sql.OpenDB(c)
158-
_, err := db.Exec(`CREATE TABLE test (bit_col BIT)`)
159-
require.NoError(t, err)
160-
defer closeDbWrapper(t, db)
161-
162-
conn := openDriverConnWrapper(t, c)
163-
defer closeDriverConnWrapper(t, &conn)
164-
165-
a, err := NewAppenderFromConn(conn, "", "test")
166-
defer closeAppenderWrapper(t, a)
167-
testError(t, err, errAppenderCreation.Error(), unsupportedTypeErrMsg)
168-
})
169-
170153
t.Run(columnCountErrMsg, func(t *testing.T) {
171154
c, db, conn, a := prepareAppender(t, `CREATE TABLE test (a VARCHAR, b VARCHAR)`)
172155
defer cleanupAppender(t, c, db, conn, a)

mapping/mapping.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ var (
223223
TimestampStructMembers = bindings.TimestampStructMembers
224224
NewInterval = bindings.NewInterval
225225
IntervalMembers = bindings.IntervalMembers
226+
NewBit = bindings.NewBit
227+
BitMembers = bindings.BitMembers
226228
NewHugeInt = bindings.NewHugeInt
227229
HugeIntMembers = bindings.HugeIntMembers
228230
NewUHugeInt = bindings.NewUHugeInt

rows.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ func (r *rows) getScanType(logicalType mapping.LogicalType, index mapping.IdxT)
132132
return reflectTypeString
133133
case TYPE_BLOB:
134134
return reflectTypeBytes
135+
case TYPE_BIT:
136+
return reflectTypeBit
135137
case TYPE_DECIMAL:
136138
return reflectTypeDecimal
137139
case TYPE_LIST:

statement.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,15 @@ func (s *Stmt) bindUUID(val driver.NamedValue, n int) (mapping.State, error) {
284284
return mapping.StateError, addIndexToError(unsupportedTypeError(unknownTypeErrMsg), n+1)
285285
}
286286

287+
func (s *Stmt) bindBit(val *Bit, n int) (mapping.State, error) {
288+
bit := mapping.NewBit(val.Data)
289+
defer mapping.DestroyBit(&bit)
290+
v := mapping.CreateBit(bit)
291+
defer mapping.DestroyValue(&v)
292+
state := mapping.BindValue(*s.preparedStmt, mapping.IdxT(n+1), v)
293+
return state, nil
294+
}
295+
287296
// Used for binding Array, List, Struct. In the future, also Map and Union
288297
func (s *Stmt) bindCompositeValue(val driver.NamedValue, n int) (mapping.State, error) {
289298
lt, err := s.paramLogicalType(n + 1)
@@ -339,6 +348,7 @@ func (s *Stmt) bindComplexValue(val driver.NamedValue, n int, t Type, name strin
339348
return mapping.StateError, addIndexToError(unsupportedTypeError(unknownTypeErrMsg), n+1)
340349
}
341350

351+
//nolint:gocyclo
342352
func (s *Stmt) bindValue(val driver.NamedValue, n int) (mapping.State, error) {
343353
// For some queries, we cannot resolve the parameter type when preparing the query.
344354
// E.g., for "SELECT * FROM (VALUES (?, ?)) t(a, b)", we cannot know the parameter types from the SQL statement alone.
@@ -422,6 +432,8 @@ func (s *Stmt) bindValue(val driver.NamedValue, n int) (mapping.State, error) {
422432
return mapping.StateError, inferErr
423433
}
424434
return mapping.BindInterval(*s.preparedStmt, mapping.IdxT(n+1), i), nil
435+
case Bit:
436+
return s.bindBit(&v, n)
425437
case nil:
426438
return mapping.BindNull(*s.preparedStmt, mapping.IdxT(n+1)), nil
427439
}

type.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ const (
5151
// FIXME: Implement support for these types.
5252
var unsupportedTypeToStringMap = map[Type]string{
5353
TYPE_INVALID: "INVALID",
54-
TYPE_BIT: "BIT",
5554
TYPE_ANY: "ANY",
5655
TYPE_BIGNUM: "BIGNUM",
5756
}

type_info.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ func (info *typeInfo) Details() TypeDetails {
211211
// Valid types are:
212212
// TYPE_[BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, UTINYINT, USMALLINT, UINTEGER,
213213
// UBIGINT, FLOAT, DOUBLE, TIMESTAMP, DATE, TIME, INTERVAL, HUGEINT, UHUGEINT, VARCHAR,
214-
// BLOB, TIMESTAMP_S, TIMESTAMP_MS, TIMESTAMP_NS, UUID, TIMESTAMP_TZ, ANY].
214+
// BLOB, BIT, TIMESTAMP_S, TIMESTAMP_MS, TIMESTAMP_NS, UUID, TIMESTAMP_TZ, ANY].
215215
func NewTypeInfo(t Type) (TypeInfo, error) {
216216
name, inMap := unsupportedTypeToStringMap[t]
217217
if inMap && t != TYPE_ANY {
@@ -418,7 +418,7 @@ func (info *typeInfo) logicalType() mapping.LogicalType {
418418
case TYPE_BOOLEAN, TYPE_TINYINT, TYPE_SMALLINT, TYPE_INTEGER, TYPE_BIGINT, TYPE_UTINYINT, TYPE_USMALLINT,
419419
TYPE_UINTEGER, TYPE_UBIGINT, TYPE_FLOAT, TYPE_DOUBLE, TYPE_TIMESTAMP, TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS,
420420
TYPE_TIMESTAMP_NS, TYPE_TIMESTAMP_TZ, TYPE_DATE, TYPE_TIME, TYPE_TIME_TZ, TYPE_INTERVAL, TYPE_HUGEINT,
421-
TYPE_UHUGEINT, TYPE_VARCHAR, TYPE_BLOB, TYPE_UUID, TYPE_ANY:
421+
TYPE_UHUGEINT, TYPE_VARCHAR, TYPE_BLOB, TYPE_BIT, TYPE_UUID, TYPE_ANY:
422422
return mapping.CreateLogicalType(info.Type)
423423
case TYPE_DECIMAL:
424424
return mapping.CreateDecimalType(info.decimalWidth, info.decimalScale)

type_info_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ var testPrimitiveSQLValues = map[Type]testTypeValues{
3838
TYPE_UHUGEINT: {input: `45::UHUGEINT`, output: `45`},
3939
TYPE_VARCHAR: {input: `'hello world'::VARCHAR`, output: `hello world`},
4040
TYPE_BLOB: {input: `'\xAA'::BLOB`, output: `[170]`},
41+
TYPE_BIT: {input: `'10101'::BIT`, output: `10101`},
4142
TYPE_TIMESTAMP_S: {input: `TIMESTAMP_S '1992-09-20 11:30:00'`, output: `1992-09-20 11:30:00 +0000 UTC`},
4243
TYPE_TIMESTAMP_MS: {input: `TIMESTAMP_MS '1992-09-20 11:30:00.123'`, output: `1992-09-20 11:30:00.123 +0000 UTC`},
4344
TYPE_TIMESTAMP_NS: {input: `TIMESTAMP_NS '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00.123456789 +0000 UTC`},

types.go

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
)
1818

1919
// duckdb-go exports the following type wrappers:
20-
// UUID, Map, Interval, Decimal, Union, Composite (optional, used to scan LIST and STRUCT).
20+
// UUID, Bit, Map, Interval, Decimal, Union, Composite (optional, used to scan LIST and STRUCT).
2121

2222
// Pre-computed reflect type values to avoid repeated allocations.
2323
var (
@@ -44,6 +44,7 @@ var (
4444
reflectTypeUnion = reflect.TypeFor[Union]()
4545
reflectTypeAny = reflect.TypeFor[any]()
4646
reflectTypeUUID = reflect.TypeFor[UUID]()
47+
reflectTypeBit = reflect.TypeFor[Bit]()
4748
reflectTypeHugeInt = reflect.TypeFor[mapping.HugeInt]()
4849
)
4950

@@ -137,6 +138,97 @@ func uuidToHugeInt(uuid UUID) mapping.HugeInt {
137138
return mapping.NewHugeInt(lower, int64(upper^(1<<63)))
138139
}
139140

141+
// Bit represents a DuckDB BIT value as a sequence of bits.
// Data stores DuckDB's internal format: a padding-count prefix byte followed by
// the bit bytes (right-aligned with 1-padded MSB bits).
// For example, "10101" (5 bits) is stored as [3, 11110101] where 3 is the padding count.
// The zero value (nil Data) holds no bits.
type Bit struct {
	Data []byte
}

// NewBitFromString creates a Bit from a string of '0' and '1' characters.
// It returns an error if s is empty or contains any other character.
func NewBitFromString(s string) (*Bit, error) {
	if len(s) == 0 {
		return nil, fmt.Errorf("empty bit string")
	}

	numBytes := (len(s) + 7) / 8
	padding := (8 - (len(s) % 8)) % 8
	data := make([]byte, numBytes+1)
	data[0] = byte(padding)

	// Set padding bits to 1
	if padding > 0 {
		data[1] = byte(0xFF) << (8 - padding)
	}

	// Every accepted character is a single byte, so the range index i is the
	// bit's position in the string; the first other character aborts the loop.
	for i, c := range s {
		switch c {
		case '1':
			bitPos := padding + i
			byteIdx := bitPos/8 + 1 // +1 skips the padding-count prefix byte.
			bitIdx := 7 - (bitPos % 8)
			data[byteIdx] |= 1 << bitIdx
		case '0':
			// The buffer is zero-initialized; nothing to set.
		default:
			return nil, fmt.Errorf("invalid character in bit string: %c", c)
		}
	}

	return &Bit{Data: data}, nil
}

// Validate checks that Data is a valid DuckDB bit encoding: the padding count
// (first byte) must be 0-7, and the padding bits in the first data byte must
// all be set to 1. A non-zero padding count with no data byte following it is
// rejected as well, because the padding would refer to a byte that does not
// exist. Empty Data is accepted.
func (b Bit) Validate() error {
	if len(b.Data) == 0 {
		return nil
	}
	padding := int(b.Data[0])
	if padding > 7 {
		return fmt.Errorf("invalid padding count %d, must be 0-7", padding)
	}
	if padding == 0 {
		return nil
	}
	if len(b.Data) == 1 {
		return fmt.Errorf("padding count %d given without any data bytes", padding)
	}
	expectedMask := byte(0xFF) << (8 - padding)
	if (b.Data[1] & expectedMask) != expectedMask {
		return fmt.Errorf("padding bits must be 1s, expected high %d bits of first byte to be set", padding)
	}
	return nil
}

// Len returns the number of bits.
// It returns 0 when Data holds no bit bytes, and also when the padding count
// exceeds the number of available bits (a malformed encoding), so the result
// is never negative.
func (b Bit) Len() int {
	if len(b.Data) <= 1 {
		return 0
	}
	n := (len(b.Data)-1)*8 - int(b.Data[0])
	if n < 0 {
		return 0
	}
	return n
}

// String returns the bit string representation (e.g., "10101").
// It returns the empty string when Len reports 0.
func (b Bit) String() string {
	length := b.Len()
	if length == 0 {
		return ""
	}
	var sb strings.Builder
	sb.Grow(length)
	padding := int(b.Data[0])
	bitData := b.Data[1:]
	// Walk the bits that follow the padding, most significant bit first.
	// padding+length equals len(bitData)*8, so every index stays in range.
	for bitPos := padding; bitPos < padding+length; bitPos++ {
		byteIdx := bitPos / 8
		bitIdx := 7 - (bitPos % 8)
		if (bitData[byteIdx] & (1 << bitIdx)) != 0 {
			sb.WriteByte('1')
		} else {
			sb.WriteByte('0')
		}
	}
	return sb.String()
}
231+
140232
func hugeIntToNative(hugeInt *mapping.HugeInt) *big.Int {
141233
lower, upper := mapping.HugeIntMembers(hugeInt)
142234
i := big.NewInt(upper)

0 commit comments

Comments
 (0)