Skip to content

Commit ca7b2e9

Browse files
committed
Fixed issues 2316 and 2318
1 parent a639beb commit ca7b2e9

File tree

7 files changed

+116
-19
lines changed

7 files changed

+116
-19
lines changed

server/doltgres_handler.go

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,18 @@ func (h *DoltgresHandler) InitSessionParameterDefault(ctx context.Context, c *my
288288
// convertBindParameters handles the conversion from bind parameters to variable values.
289289
func (h *DoltgresHandler) convertBindParameters(ctx *sql.Context, types []uint32, formatCodes []int16, values [][]byte) (map[string]sqlparser.Expr, error) {
290290
bindings := make(map[string]sqlparser.Expr, len(values))
291+
// It's valid to send just one format code that should be used by all values, so we extend the slice in that case
292+
if len(formatCodes) > 0 && len(formatCodes) < len(values) {
293+
if len(formatCodes) > 1 {
294+
return nil, errors.Errorf(`format codes have length "%d" but values have length "%d"`, len(formatCodes), len(values))
295+
}
296+
formatCode := formatCodes[0]
297+
formatCodes = make([]int16, len(values))
298+
formatCodes[0] = formatCode
299+
for i := 1; i < len(values); i++ {
300+
formatCodes[i] = formatCode
301+
}
302+
}
291303
for i := range values {
292304
formatCode := int16(0)
293305
if formatCodes != nil {
@@ -542,7 +554,8 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, isPrepared bool)
542554
typmod = doltgresType.GetAttTypMod() // pg_attribute.atttypmod
543555
if isPrepared {
544556
switch doltgresType.ID {
545-
case pgtypes.Bytea.ID, pgtypes.Int16.ID, pgtypes.Int32.ID, pgtypes.Int64.ID, pgtypes.Uuid.ID:
557+
case pgtypes.Bytea.ID, pgtypes.Date.ID, pgtypes.Int16.ID, pgtypes.Int32.ID, pgtypes.Int64.ID,
558+
pgtypes.Timestamp.ID, pgtypes.TimestampTZ.ID, pgtypes.Uuid.ID:
546559
formatCode = 1
547560
}
548561
}
@@ -805,6 +818,22 @@ func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row, isExecute bool) ([]
805818
binary.BigEndian.PutUint16(buf, uint16(v.(int16)))
806819
o[i] = buf
807820
continue
821+
case pgtypes.Timestamp.ID, pgtypes.TimestampTZ.ID:
822+
postgresEpoch := time.UnixMilli(946684800000).UTC() // Jan 1, 2000 @ Midnight
823+
deltaInMicroseconds := v.(time.Time).UTC().UnixMicro() - postgresEpoch.UnixMicro()
824+
buf := make([]byte, 8)
825+
binary.BigEndian.PutUint64(buf, uint64(deltaInMicroseconds))
826+
o[i] = buf
827+
continue
828+
case pgtypes.Date.ID:
829+
postgresEpoch := time.UnixMilli(946684800000).UTC() // Jan 1, 2000 @ Midnight
830+
deltaInMilliseconds := v.(time.Time).UTC().UnixMilli() - postgresEpoch.UnixMilli()
831+
buf := make([]byte, 4)
832+
const millisecondsPerDay = 86400000
833+
days := deltaInMilliseconds / millisecondsPerDay
834+
binary.BigEndian.PutUint32(buf, uint32(days))
835+
o[i] = buf
836+
continue
808837
case pgtypes.Uuid.ID:
809838
buf, err := v.(uuid.UUID).MarshalBinary()
810839
if err != nil {

server/functions/make_timestamp.go

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,7 @@ var make_timestamp = framework.Function6{
4242
IsNonDeterministic: true,
4343
Strict: true,
4444
Callable: func(ctx *sql.Context, _ [7]*pgtypes.DoltgresType, val1, val2, val3, val4, val5, val6 any) (any, error) {
45-
loc, err := GetServerLocation(ctx)
46-
if err != nil {
47-
return time.Time{}, err
48-
}
49-
return getTimestampInServerLocation(val1.(int32), val2.(int32), val3.(int32), val4.(int32), val5.(int32), val6.(float64), loc)
45+
return getTimestampInServerLocation(val1.(int32), val2.(int32), val3.(int32), val4.(int32), val5.(int32), val6.(float64), time.UTC)
5046
},
5147
}
5248

server/functions/to_date.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
package functions
1616

1717
import (
18+
"time"
19+
1820
"github.com/dolthub/go-mysql-server/sql"
1921

2022
"github.com/dolthub/doltgresql/server/functions/framework"
@@ -37,6 +39,8 @@ var to_date_text_text = framework.Function2{
3739
format := val2.(string)
3840

3941
// Parse the date using PostgreSQL format patterns
40-
return getDateTimeFromFormat(ctx, input, format)
42+
t, err := getDateTimeFromFormat(ctx, input, format)
43+
// We return a version of the time but with the timezone completely stripped since they do not use timezones
44+
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC), err
4145
},
4246
}

testing/go/framework.go

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import (
4242

4343
"github.com/dolthub/doltgresql/core/id"
4444
"github.com/dolthub/doltgresql/postgres/parser/duration"
45+
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
4546
"github.com/dolthub/doltgresql/postgres/parser/timeofday"
4647
"github.com/dolthub/doltgresql/postgres/parser/uuid"
4748
dserver "github.com/dolthub/doltgresql/server"
@@ -589,6 +590,20 @@ func NormalizeExpectedRow(fds []pgconn.FieldDescription, rows []sql.Row) []sql.R
589590
panic(err)
590591
}
591592
newRow[i] = ret
593+
} else if dt.ID == types.Date.ID {
594+
newRow[i] = row[i]
595+
if row[i] != nil {
596+
if t, _, err := tree.ParseDTimestampTZ(nil, row[i].(string), tree.TimeFamilyPrecisionToRoundDuration(6), time.UTC); err == nil {
597+
newRow[i] = functions.FormatDateTimeWithBC(t.Time.UTC(), "2006-01-02", dt.ID == types.TimestampTZ.ID)
598+
}
599+
}
600+
} else if dt.ID == types.Timestamp.ID || dt.ID == types.TimestampTZ.ID {
601+
newRow[i] = row[i]
602+
if row[i] != nil {
603+
if t, _, err := tree.ParseDTimestampTZ(nil, row[i].(string), tree.TimeFamilyPrecisionToRoundDuration(6), time.UTC); err == nil {
604+
newRow[i] = functions.FormatDateTimeWithBC(t.Time.UTC(), "2006-01-02 15:04:05.999999", dt.ID == types.TimestampTZ.ID)
605+
}
606+
}
592607
} else {
593608
newRow[i] = NormalizeIntsAndFloats(row[i])
594609
}
@@ -663,7 +678,7 @@ func NormalizeValToString(dt *types.DoltgresType, v any) any {
663678
panic(err)
664679
}
665680
return val
666-
case types.Interval.ID, types.Uuid.ID, types.Date.ID, types.Time.ID, types.Timestamp.ID:
681+
case types.Interval.ID, types.Time.ID, types.Uuid.ID:
667682
// These values need to be normalized into the appropriate types
668683
// before being converted to string type using the Doltgres
669684
// IoOutput method.
@@ -675,12 +690,16 @@ func NormalizeValToString(dt *types.DoltgresType, v any) any {
675690
panic(err)
676691
}
677692
return tVal
678-
case types.TimestampTZ.ID:
693+
case types.Date.ID:
694+
if v == nil {
695+
return nil
696+
}
697+
return functions.FormatDateTimeWithBC(v.(time.Time), "2006-01-02", false)
698+
case types.Timestamp.ID, types.TimestampTZ.ID:
679699
if v == nil {
680700
return nil
681701
}
682-
// timestamptz returns a value in server timezone
683-
return functions.FormatDateTimeWithBC(v.(time.Time), "2006-01-02 15:04:05.999999", true)
702+
return functions.FormatDateTimeWithBC(v.(time.Time).UTC(), "2006-01-02 15:04:05.999999", dt.ID == types.TimestampTZ.ID)
684703
}
685704

686705
switch val := v.(type) {

testing/postgres-client-tests/postgres-client-tests.bats

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,5 +103,5 @@ teardown() {
103103

104104
@test "rust sqlx" {
105105
cd $BATS_TEST_DIRNAME/rust
106-
cargo run -- $USER $PORT
106+
RUSTFLAGS=-Awarnings cargo run -- $USER $PORT
107107
}

testing/postgres-client-tests/rust/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,4 @@ edition = "2024"
55

66
[dependencies]
77
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
8-
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "tls-native-tls"] }
8+
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "tls-native-tls", "uuid", "chrono"] }

testing/postgres-client-tests/rust/src/main.rs

Lines changed: 55 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,15 @@
11
use sqlx::postgres::PgPoolOptions;
2+
use sqlx::types::Uuid;
3+
use sqlx::types::chrono::Utc;
4+
use sqlx::types::chrono::DateTime;
5+
use sqlx::types::chrono::NaiveDate;
6+
7+
#[derive(sqlx::FromRow)]
8+
struct Event {
9+
id: sqlx::types::Uuid,
10+
created_at: DateTime<Utc>,
11+
event_date: Option<NaiveDate>,
12+
}
213

314
#[tokio::main]
415
async fn main() -> Result<(), Box<dyn std::error::Error>> {
@@ -15,12 +26,50 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
1526
.max_connections(5)
1627
.connect(&database_url)
1728
.await?;
18-
let exists: bool = sqlx::query_scalar(
19-
"SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = $1);"
20-
)
21-
.bind("test_table")
22-
.fetch_one(&pool)
23-
.await?;
29+
30+
let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = $1);")
31+
.bind("test_table")
32+
.fetch_one(&pool)
33+
.await?;
2434
println!("exists={exists}");
35+
36+
sqlx::query("DROP TABLE IF EXISTS users, events;")
37+
.execute(&pool)
38+
.await?;
39+
40+
sqlx::query("CREATE TABLE users (id uuid default gen_random_uuid(), name text, email text);")
41+
.execute(&pool)
42+
.await?;
43+
44+
sqlx::query("INSERT INTO users (name, email) VALUES ($1, $2)")
45+
.bind("Alice")
46+
.bind("alice@example.com")
47+
.execute(&pool)
48+
.await?;
49+
50+
let some_uuid: Uuid = sqlx::query_scalar("SELECT id FROM users WHERE email = $1 LIMIT 1")
51+
.bind("alice@example.com")
52+
.fetch_one(&pool)
53+
.await?;
54+
55+
sqlx::query("UPDATE users SET name = $1 WHERE id = $2")
56+
.bind("Bob")
57+
.bind(some_uuid)
58+
.execute(&pool)
59+
.await?;
60+
61+
sqlx::query("CREATE TABLE events (id UUID PRIMARY KEY, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), event_date DATE);")
62+
.execute(&pool)
63+
.await?;
64+
65+
let some_id: Uuid = sqlx::query_scalar("INSERT INTO events (id, event_date) VALUES (gen_random_uuid(), '2026-02-12') RETURNING id;")
66+
.fetch_one(&pool)
67+
.await?;
68+
69+
let __event = sqlx::query_as::<_, Event>("SELECT * FROM events WHERE id = $1")
70+
.bind(some_id)
71+
.fetch_one(&pool)
72+
.await?;
73+
2574
Ok(())
2675
}

0 commit comments

Comments
 (0)