Skip to content

Commit b244179

Browse files
committed
Add BIT support
This adds support for reading and writing the BIT type. Since Go has no native type for representing bitstrings, this adds a custom Go type that represents one. It uses the same internal storage layout as DuckDB, so conversions are cheap and require minimal copying. P.S. Strictly speaking, there is a BitString type in encoding/asn1, but it is not meant for general-purpose use; it really only exists so that X.509 certificates can be decoded. It is also laid out very differently from DuckDB's BIT type (e.g. it uses 0s for padding instead of 1s).
1 parent fe23e71 commit b244179

File tree

15 files changed

+316
-23
lines changed

15 files changed

+316
-23
lines changed

appender_test.go

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,70 @@ func TestAppenderTimeTZ(t *testing.T) {
818818
require.Equal(t, expected, r)
819819
}
820820

821+
func TestAppenderBit(t *testing.T) {
822+
c, db, conn, a := prepareAppender(t, `CREATE TABLE test (data BIT)`)
823+
defer cleanupAppender(t, c, db, conn, a)
824+
825+
expected := []string{
826+
"10101",
827+
"11110000",
828+
"1",
829+
"111100001100110011",
830+
"00000000000000000000000010110101",
831+
}
832+
for _, bits := range expected {
833+
bitVal, err := NewBitFromString(bits)
834+
require.NoError(t, err)
835+
require.NoError(t, a.AppendRow(bitVal))
836+
}
837+
838+
require.NoError(t, a.Flush())
839+
840+
// Verify results.
841+
res, err := db.QueryContext(context.Background(), `SELECT data FROM test`)
842+
require.NoError(t, err)
843+
defer closeRowsWrapper(t, res)
844+
845+
i := 0
846+
for res.Next() {
847+
var b Bit
848+
require.NoError(t, res.Scan(&b))
849+
require.Equal(t, expected[i], b.String())
850+
i++
851+
}
852+
require.Equal(t, len(expected), i)
853+
}
854+
855+
func TestAppenderNullBit(t *testing.T) {
856+
c, db, conn, a := prepareAppender(t, `CREATE TABLE test (b BIT)`)
857+
defer cleanupAppender(t, c, db, conn, a)
858+
859+
// Append a nil *Bit.
860+
var nilBit *Bit
861+
require.NoError(t, a.AppendRow(nilBit))
862+
863+
// Append a non-nil Bit.
864+
nonNilBit, err := NewBitFromString("10101")
865+
require.NoError(t, err)
866+
require.NoError(t, a.AppendRow(nonNilBit))
867+
868+
require.NoError(t, a.Flush())
869+
870+
// Verify results.
871+
rows, err := db.QueryContext(context.Background(), `SELECT b FROM test`)
872+
require.NoError(t, err)
873+
defer closeRowsWrapper(t, rows)
874+
875+
require.True(t, rows.Next())
876+
var r *Bit
877+
require.NoError(t, rows.Scan(&r))
878+
require.Nil(t, r)
879+
880+
require.True(t, rows.Next())
881+
require.NoError(t, rows.Scan(&r))
882+
require.Equal(t, &nonNilBit, r)
883+
}
884+
821885
func TestAppenderBlob(t *testing.T) {
822886
c, db, conn, a := prepareAppender(t, `CREATE TABLE test (data BLOB)`)
823887
defer cleanupAppender(t, c, db, conn, a)

connection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func newConn(conn mapping.Connection, ctxStore *contextStore) *Conn {
3737
// CheckNamedValue implements the driver.NamedValueChecker interface.
3838
func (conn *Conn) CheckNamedValue(nv *driver.NamedValue) error {
3939
switch nv.Value.(type) {
40-
case *big.Int, Interval, []any, []bool, []int8, []int16, []int32, []int64, []int, []uint8, []uint16,
40+
case *big.Int, Interval, Bit, []any, []bool, []int8, []int16, []int32, []int64, []int, []uint8, []uint16,
4141
[]uint32, []uint64, []uint, []float32, []float64, []string, map[string]any:
4242
return nil
4343
}

errors_test.go

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -150,23 +150,6 @@ func TestErrAppender(t *testing.T) {
150150
testError(t, err, errAppenderDoubleClose.Error())
151151
})
152152

153-
t.Run(unsupportedTypeErrMsg, func(t *testing.T) {
154-
c := newConnectorWrapper(t, ``, nil)
155-
defer closeConnectorWrapper(t, c)
156-
157-
db := sql.OpenDB(c)
158-
_, err := db.Exec(`CREATE TABLE test (bit_col BIT)`)
159-
require.NoError(t, err)
160-
defer closeDbWrapper(t, db)
161-
162-
conn := openDriverConnWrapper(t, c)
163-
defer closeDriverConnWrapper(t, &conn)
164-
165-
a, err := NewAppenderFromConn(conn, "", "test")
166-
defer closeAppenderWrapper(t, a)
167-
testError(t, err, errAppenderCreation.Error(), unsupportedTypeErrMsg)
168-
})
169-
170153
t.Run(columnCountErrMsg, func(t *testing.T) {
171154
c, db, conn, a := prepareAppender(t, `CREATE TABLE test (a VARCHAR, b VARCHAR)`)
172155
defer cleanupAppender(t, c, db, conn, a)

mapping/mapping.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88

99
type Type = bindings.Type
1010

11+
//nolint:dupl
1112
const (
1213
TypeInvalid = bindings.TypeInvalid
1314
TypeBoolean = bindings.TypeBoolean
@@ -196,6 +197,7 @@ type (
196197

197198
// Helper functions for types without internal pointers.
198199

200+
//nolint:dupl
199201
var (
200202
NewDate = bindings.NewDate
201203
DateMembers = bindings.DateMembers
@@ -223,6 +225,8 @@ var (
223225
TimestampStructMembers = bindings.TimestampStructMembers
224226
NewInterval = bindings.NewInterval
225227
IntervalMembers = bindings.IntervalMembers
228+
NewBit = bindings.NewBit
229+
BitMembers = bindings.BitMembers
226230
NewHugeInt = bindings.NewHugeInt
227231
HugeIntMembers = bindings.HugeIntMembers
228232
NewUHugeInt = bindings.NewUHugeInt

rows.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,8 @@ func (r *rows) getScanType(logicalType mapping.LogicalType, index mapping.IdxT)
132132
return reflectTypeString
133133
case TYPE_BLOB:
134134
return reflectTypeBytes
135+
case TYPE_BIT:
136+
return reflectTypeBit
135137
case TYPE_DECIMAL:
136138
return reflectTypeDecimal
137139
case TYPE_LIST:

statement.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,6 +293,15 @@ func (s *Stmt) bindUUID(val driver.NamedValue, n int) (mapping.State, error) {
293293
return mapping.StateError, addIndexToError(unsupportedTypeError(unknownTypeErrMsg), n+1)
294294
}
295295

296+
func (s *Stmt) bindBit(val *Bit, n int) (mapping.State, error) {
297+
bit := mapping.NewBit(val.Data)
298+
defer mapping.DestroyBit(&bit)
299+
v := mapping.CreateBit(bit)
300+
defer mapping.DestroyValue(&v)
301+
state := mapping.BindValue(*s.preparedStmt, mapping.IdxT(n+1), v)
302+
return state, nil
303+
}
304+
296305
// Used for binding Array, List, Struct. In the future, also Map and Union
297306
func (s *Stmt) bindCompositeValue(val driver.NamedValue, n int) (mapping.State, error) {
298307
lt, err := s.paramLogicalType(n + 1)
@@ -435,6 +444,8 @@ func (s *Stmt) bindValue(val driver.NamedValue, n int) (mapping.State, error) {
435444
return mapping.StateError, inferErr
436445
}
437446
return mapping.BindInterval(*s.preparedStmt, mapping.IdxT(n+1), i), nil
447+
case Bit:
448+
return s.bindBit(&v, n)
438449
case nil:
439450
return mapping.BindNull(*s.preparedStmt, mapping.IdxT(n+1)), nil
440451
}

type.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,6 @@ const (
5151
// FIXME: Implement support for these types.
5252
var unsupportedTypeToStringMap = map[Type]string{
5353
TYPE_INVALID: "INVALID",
54-
TYPE_BIT: "BIT",
5554
TYPE_ANY: "ANY",
5655
}
5756

type_info.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,7 @@ func (info *typeInfo) Details() TypeDetails {
211211
// Valid types are:
212212
// TYPE_[BOOLEAN, TINYINT, SMALLINT, INTEGER, BIGINT, UTINYINT, USMALLINT, UINTEGER,
213213
// UBIGINT, FLOAT, DOUBLE, TIMESTAMP, DATE, TIME, INTERVAL, HUGEINT, UHUGEINT, VARCHAR,
214-
// BLOB, TIMESTAMP_S, TIMESTAMP_MS, TIMESTAMP_NS, UUID, TIMESTAMP_TZ, ANY].
214+
// BLOB, BIT, TIMESTAMP_S, TIMESTAMP_MS, TIMESTAMP_NS, UUID, TIMESTAMP_TZ, ANY].
215215
func NewTypeInfo(t Type) (TypeInfo, error) {
216216
name, inMap := unsupportedTypeToStringMap[t]
217217
if inMap && t != TYPE_ANY {
@@ -418,7 +418,7 @@ func (info *typeInfo) logicalType() mapping.LogicalType {
418418
case TYPE_BOOLEAN, TYPE_TINYINT, TYPE_SMALLINT, TYPE_INTEGER, TYPE_BIGINT, TYPE_UTINYINT, TYPE_USMALLINT,
419419
TYPE_UINTEGER, TYPE_UBIGINT, TYPE_FLOAT, TYPE_DOUBLE, TYPE_TIMESTAMP, TYPE_TIMESTAMP_S, TYPE_TIMESTAMP_MS,
420420
TYPE_TIMESTAMP_NS, TYPE_TIMESTAMP_TZ, TYPE_DATE, TYPE_TIME, TYPE_TIME_TZ, TYPE_INTERVAL, TYPE_HUGEINT,
421-
TYPE_UHUGEINT, TYPE_VARCHAR, TYPE_BLOB, TYPE_UUID, TYPE_ANY:
421+
TYPE_UHUGEINT, TYPE_VARCHAR, TYPE_BLOB, TYPE_BIT, TYPE_UUID, TYPE_ANY:
422422
return mapping.CreateLogicalType(info.Type)
423423
case TYPE_DECIMAL:
424424
return mapping.CreateDecimalType(info.decimalWidth, info.decimalScale)

type_info_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ var testPrimitiveSQLValues = map[Type]testTypeValues{
3939
TYPE_BIGNUM: {input: `46::BIGNUM`, output: `46`},
4040
TYPE_VARCHAR: {input: `'hello world'::VARCHAR`, output: `hello world`},
4141
TYPE_BLOB: {input: `'\xAA'::BLOB`, output: `[170]`},
42+
TYPE_BIT: {input: `'10101'::BIT`, output: `10101`},
4243
TYPE_TIMESTAMP_S: {input: `TIMESTAMP_S '1992-09-20 11:30:00'`, output: `1992-09-20 11:30:00 +0000 UTC`},
4344
TYPE_TIMESTAMP_MS: {input: `TIMESTAMP_MS '1992-09-20 11:30:00.123'`, output: `1992-09-20 11:30:00.123 +0000 UTC`},
4445
TYPE_TIMESTAMP_NS: {input: `TIMESTAMP_NS '1992-09-20 11:30:00.123456789'`, output: `1992-09-20 11:30:00.123456789 +0000 UTC`},

types.go

Lines changed: 93 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
)
1818

1919
// duckdb-go exports the following type wrappers:
20-
// UUID, Map, Interval, Decimal, Union, Composite (optional, used to scan LIST and STRUCT).
20+
// UUID, Bit, Map, Interval, Decimal, Union, Composite (optional, used to scan LIST and STRUCT).
2121

2222
// Pre-computed reflect type values to avoid repeated allocations.
2323
var (
@@ -44,6 +44,7 @@ var (
4444
reflectTypeUnion = reflect.TypeFor[Union]()
4545
reflectTypeAny = reflect.TypeFor[any]()
4646
reflectTypeUUID = reflect.TypeFor[UUID]()
47+
reflectTypeBit = reflect.TypeFor[Bit]()
4748
)
4849

4950
type numericType interface {
@@ -136,6 +137,97 @@ func uuidToHugeInt(uuid UUID) mapping.HugeInt {
136137
return mapping.NewHugeInt(lower, int64(upper^(1<<63)))
137138
}
138139

140+
// Bit represents a DuckDB BIT value as a sequence of bits.
//
// Data stores DuckDB's internal format: a padding-count prefix byte followed by
// the bit bytes (right-aligned with 1-padded MSB bits).
// For example, "10101" (5 bits) is stored as [3, 11110101] where 3 is the padding count.
//
// The zero value (nil Data) is an empty bit string.
type Bit struct {
	Data []byte
}

// NewBitFromString creates a Bit from a string of '0' and '1' characters.
//
// An empty string yields the zero Bit. Any character other than '0' or '1'
// results in an error and the zero Bit.
func NewBitFromString(s string) (Bit, error) {
	if len(s) == 0 {
		return Bit{}, nil
	}

	numBytes := (len(s) + 7) / 8
	padding := (8 - (len(s) % 8)) % 8

	// One extra leading byte holds the padding count.
	data := make([]byte, numBytes+1)
	data[0] = byte(padding)

	// The unused high bits of the first data byte are set to 1.
	if padding > 0 {
		data[1] = byte(0xFF) << (8 - padding)
	}

	for i, c := range s {
		switch c {
		case '1':
			// Bits are laid out MSB-first, starting right after the padding.
			bitPos := padding + i
			byteIdx := bitPos/8 + 1
			bitIdx := 7 - (bitPos % 8)
			data[byteIdx] |= 1 << bitIdx
		case '0':
			// Already zero from make.
		default:
			return Bit{}, fmt.Errorf("invalid character in bit string: %c", c)
		}
	}

	return Bit{Data: data}, nil
}

// Validate checks that Data is a valid DuckDB bit encoding: the padding count
// (first byte) must be 0-7, and the padding bits in the first data byte must
// all be set to 1.
//
// Data with at most one byte carries no bit bytes and is accepted as-is.
func (b Bit) Validate() error {
	if len(b.Data) <= 1 {
		return nil
	}
	padding := int(b.Data[0])
	if padding > 7 {
		return fmt.Errorf("invalid padding count %d, must be 0-7", padding)
	}
	if padding > 0 {
		expectedMask := byte(0xFF) << (8 - padding)
		if (b.Data[1] & expectedMask) != expectedMask {
			return fmt.Errorf("padding bits must be 1s, expected high %d bits of first byte to be set", padding)
		}
	}
	return nil
}

// Len returns the number of bits.
//
// It never returns a negative value: if Data holds no bit bytes, or the
// padding count exceeds the number of stored bits, Len reports 0.
func (b Bit) Len() int {
	if len(b.Data) <= 1 {
		return 0
	}
	n := (len(b.Data)-1)*8 - int(b.Data[0])
	if n < 0 {
		// Malformed padding count; treat as empty rather than returning a
		// negative length that callers would have to special-case.
		return 0
	}
	return n
}

// String returns the bit string representation (e.g., "10101").
//
// An empty or malformed Bit (see Len) renders as the empty string.
func (b Bit) String() string {
	length := b.Len()
	if length <= 0 {
		return ""
	}

	var sb strings.Builder
	sb.Grow(length)

	padding := int(b.Data[0])
	bitData := b.Data[1:]
	// length > 0 implies padding+i < len(bitData)*8 for every i below, so the
	// byte index stays in range.
	for i := 0; i < length; i++ {
		bitPos := padding + i
		byteIdx := bitPos / 8
		bitIdx := 7 - (bitPos % 8)
		if (bitData[byteIdx] & (1 << bitIdx)) != 0 {
			sb.WriteByte('1')
		} else {
			sb.WriteByte('0')
		}
	}
	return sb.String()
}
230+
139231
func hugeIntToNative(hugeInt *mapping.HugeInt) *big.Int {
140232
lower, upper := mapping.HugeIntMembers(hugeInt)
141233
i := big.NewInt(upper)

0 commit comments

Comments
 (0)