Skip to content

Commit a1c982b

Browse files
authored
Fix: handle sql.NullTime parameters (#195)
* handle sql.NullTime parameters * Match SQL sizes for sql.Nullxxx integer types * handle custom nullable Valuer implementations
1 parent 3ed002a commit a1c982b

File tree

2 files changed

+61
-5
lines changed

2 files changed

+61
-5
lines changed

alwaysencrypted_test.go

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ type aeColumnInfo struct {
4141
sampleValue interface{}
4242
}
4343

44+
type customValuer struct {
45+
}
46+
47+
func (n customValuer) Value() (driver.Value, error) {
48+
return nil, nil
49+
}
50+
4451
func TestAlwaysEncryptedE2E(t *testing.T) {
4552
params := testConnParams(t)
4653
if !params.ColumnEncryption {
@@ -53,7 +60,11 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
5360
{"int", "INT", ColumnEncryptionDeterministic, int32(1)},
5461
{"nchar(10) COLLATE Latin1_General_BIN2", "NCHAR", ColumnEncryptionDeterministic, NChar("ncharval")},
5562
{"tinyint", "TINYINT", ColumnEncryptionRandomized, byte(2)},
63+
{"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: false}},
64+
{"tinyint", "TINYINT", ColumnEncryptionDeterministic, sql.NullByte{Valid: true, Byte: 1}},
5665
{"smallint", "SMALLINT", ColumnEncryptionDeterministic, int16(-3)},
66+
{"smallint", "SMALLINT", ColumnEncryptionRandomized, sql.NullInt16{Valid: false}},
67+
{"smallint", "SMALLINT", ColumnEncryptionDeterministic, sql.NullInt16{Valid: true, Int16: 32000}},
5768
{"bigint", "BIGINT", ColumnEncryptionRandomized, int64(4)},
5869
// We can't use fractional float/real values due to rounding errors in the round trip
5970
{"real", "REAL", ColumnEncryptionDeterministic, float32(5)},
@@ -67,9 +78,13 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
6778
{"datetime2(7)", "DATETIME2", ColumnEncryptionDeterministic, civil.DateTimeOf(dt)},
6879
{"nvarchar(max)", "NVARCHAR", ColumnEncryptionRandomized, NVarCharMax("nvarcharmaxval")},
6980
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: false}},
81+
{"int", "INT", ColumnEncryptionDeterministic, sql.NullInt32{Valid: true, Int32: -75000}},
7082
{"bigint", "BIGINT", ColumnEncryptionDeterministic, sql.NullInt64{Int64: 128, Valid: true}},
83+
{"bigint", "BIGINT", ColumnEncryptionRandomized, sql.NullInt64{Valid: false}},
7184
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}},
7285
{"uniqueidentifier", "UNIQUEIDENTIFIER", ColumnEncryptionRandomized, NullUniqueIdentifier{Valid: false}},
86+
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: false}},
87+
{"datetimeoffset(7)", "DATETIMEOFFSET", ColumnEncryptionDeterministic, sql.NullTime{Valid: true, Time: time.Now()}},
7388
}
7489
for _, test := range providerTests {
7590
// turn off key caching
@@ -108,7 +123,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
108123
_, _ = query.WriteString(fmt.Sprintf("CREATE TABLE [%s] (", tableName))
109124
_, _ = insert.WriteString(fmt.Sprintf("INSERT INTO [%s] VALUES (", tableName))
110125
_, _ = sel.WriteString("select top(1) ")
111-
insertArgs := make([]interface{}, len(encryptableColumns)+1)
126+
insertArgs := make([]interface{}, len(encryptableColumns)+2)
112127
for i, ec := range encryptableColumns {
113128
encType := "RANDOMIZED"
114129
null := ""
@@ -128,11 +143,13 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
128143
insert.WriteString(fmt.Sprintf("@p%d,", i+1))
129144
sel.WriteString(fmt.Sprintf("col%d,", i))
130145
}
131-
_, _ = query.WriteString("unencryptedcolumn nvarchar(100)")
146+
_, _ = query.WriteString("unencryptedcolumn nvarchar(100),")
147+
_, _ = query.WriteString("nullableCustomValuer int NULL")
132148
_, _ = query.WriteString(")")
133149
insertArgs[len(encryptableColumns)] = "unencryptedvalue"
134-
insert.WriteString(fmt.Sprintf("@p%d)", len(encryptableColumns)+1))
135-
sel.WriteString(fmt.Sprintf("unencryptedcolumn from [%s]", tableName))
150+
insertArgs[len(encryptableColumns)+1] = customValuer{}
151+
insert.WriteString(fmt.Sprintf("@p%d,@p%d)", len(encryptableColumns)+1, len(encryptableColumns)+2))
152+
sel.WriteString(fmt.Sprintf("unencryptedcolumn, nullableCustomValuer from [%s]", tableName))
136153
_, err = conn.Exec(query.String())
137154
assert.NoError(t, err, "Failed to create encrypted table")
138155
defer func() { _, _ = conn.Exec("DROP TABLE IF EXISTS " + tableName) }()
@@ -152,13 +169,15 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
152169
}
153170

154171
var unencryptedColumnValue string
155-
scanValues := make([]interface{}, len(encryptableColumns)+1)
172+
var nullint sql.NullInt32
173+
scanValues := make([]interface{}, len(encryptableColumns)+2)
156174
for v := range scanValues {
157175
if v < len(encryptableColumns) {
158176
scanValues[v] = new(interface{})
159177
}
160178
}
161179
scanValues[len(encryptableColumns)] = &unencryptedColumnValue
180+
scanValues[len(encryptableColumns)+1] = &nullint
162181
err = rows.Scan(scanValues...)
163182
defer rows.Close()
164183
if err != nil {
@@ -182,6 +201,7 @@ func TestAlwaysEncryptedE2E(t *testing.T) {
182201
assert.Equalf(t, expectedStrVal, strVal, "Incorrect value for col%d. ", i)
183202
}
184203
assert.Equalf(t, "unencryptedvalue", unencryptedColumnValue, "Got wrong value for unencrypted column")
204+
assert.False(t, nullint.Valid, "custom valuer should have null value")
185205
_ = rows.Next()
186206
err = rows.Err()
187207
assert.NoError(t, err, "rows.Err() has non-nil values")

mssql.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,19 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
983983
return
984984
}
985985
switch valuer := val.(type) {
986+
// sql.Nullxxx integer types return an int64. We want the original type, to match the SQL type size.
987+
case sql.NullByte:
988+
if valuer.Valid {
989+
return s.makeParam(valuer.Byte)
990+
}
991+
case sql.NullInt16:
992+
if valuer.Valid {
993+
return s.makeParam(valuer.Int16)
994+
}
995+
case sql.NullInt32:
996+
if valuer.Valid {
997+
return s.makeParam(valuer.Int32)
998+
}
986999
case UniqueIdentifier:
9871000
case NullUniqueIdentifier:
9881001
default:
@@ -1052,9 +1065,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
10521065
res.ti.Size = 8
10531066
res.buffer = []byte{}
10541067
case sql.NullInt32:
1068+
// only null values should be getting here
10551069
res.ti.TypeId = typeIntN
10561070
res.ti.Size = 4
10571071
res.buffer = []byte{}
1072+
case sql.NullInt16:
1073+
// only null values should be getting here
1074+
res.buffer = []byte{}
1075+
res.ti.Size = 2
1076+
res.ti.TypeId = typeIntN
1077+
case sql.NullByte:
1078+
// only null values should be getting here
1079+
res.buffer = []byte{}
1080+
res.ti.Size = 1
1081+
res.ti.TypeId = typeIntN
10581082
case byte:
10591083
res.ti.TypeId = typeIntN
10601084
res.buffer = []byte{val}
@@ -1110,6 +1134,18 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
11101134
res.buffer = encodeDateTime(val)
11111135
res.ti.Size = len(res.buffer)
11121136
}
1137+
case sql.NullTime: // only null values reach here
1138+
res.buffer = []byte{}
1139+
res.ti.Size = 8
1140+
if s.c.sess.loginAck.TDSVersion >= verTDS73 {
1141+
res.ti.TypeId = typeDateTimeOffsetN
1142+
res.ti.Scale = 7
1143+
} else {
1144+
res.ti.TypeId = typeDateTimeN
1145+
}
1146+
case driver.Valuer:
1147+
// We have a custom Valuer implementation with a nil value
1148+
return s.makeParam(nil)
11131149
default:
11141150
return s.makeParamExtra(val)
11151151
}

0 commit comments

Comments
 (0)