Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion server/doltgres_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,18 @@ func (h *DoltgresHandler) InitSessionParameterDefault(ctx context.Context, c *my
// convertBindParameters handles the conversion from bind parameters to variable values.
func (h *DoltgresHandler) convertBindParameters(ctx *sql.Context, types []uint32, formatCodes []int16, values [][]byte) (map[string]sqlparser.Expr, error) {
bindings := make(map[string]sqlparser.Expr, len(values))
// It's valid to send just one format code that should be used by all values, so we extend the slice in that case
if len(formatCodes) > 0 && len(formatCodes) < len(values) {
if len(formatCodes) > 1 {
return nil, errors.Errorf(`format codes have length "%d" but values have length "%d"`, len(formatCodes), len(values))
}
formatCode := formatCodes[0]
formatCodes = make([]int16, len(values))
formatCodes[0] = formatCode
for i := 1; i < len(values); i++ {
formatCodes[i] = formatCode
}
}
for i := range values {
formatCode := int16(0)
if formatCodes != nil {
Expand Down Expand Up @@ -542,7 +554,8 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, isPrepared bool)
typmod = doltgresType.GetAttTypMod() // pg_attribute.atttypmod
if isPrepared {
switch doltgresType.ID {
case pgtypes.Bytea.ID, pgtypes.Int16.ID, pgtypes.Int32.ID, pgtypes.Int64.ID, pgtypes.Uuid.ID:
case pgtypes.Bytea.ID, pgtypes.Date.ID, pgtypes.Int16.ID, pgtypes.Int32.ID, pgtypes.Int64.ID,
pgtypes.Timestamp.ID, pgtypes.TimestampTZ.ID, pgtypes.Uuid.ID:
formatCode = 1
}
}
Expand Down Expand Up @@ -805,6 +818,22 @@ func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row, isExecute bool) ([]
binary.BigEndian.PutUint16(buf, uint16(v.(int16)))
o[i] = buf
continue
case pgtypes.Timestamp.ID, pgtypes.TimestampTZ.ID:
postgresEpoch := time.UnixMilli(946684800000).UTC() // Jan 1, 2000 @ Midnight
deltaInMicroseconds := v.(time.Time).UTC().UnixMicro() - postgresEpoch.UnixMicro()
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf, uint64(deltaInMicroseconds))
o[i] = buf
continue
case pgtypes.Date.ID:
postgresEpoch := time.UnixMilli(946684800000).UTC() // Jan 1, 2000 @ Midnight
deltaInMilliseconds := v.(time.Time).UTC().UnixMilli() - postgresEpoch.UnixMilli()
buf := make([]byte, 4)
const millisecondsPerDay = 86400000
days := deltaInMilliseconds / millisecondsPerDay
binary.BigEndian.PutUint32(buf, uint32(days))
o[i] = buf
continue
case pgtypes.Uuid.ID:
buf, err := v.(uuid.UUID).MarshalBinary()
if err != nil {
Expand Down
6 changes: 1 addition & 5 deletions server/functions/make_timestamp.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,7 @@ var make_timestamp = framework.Function6{
IsNonDeterministic: true,
Strict: true,
Callable: func(ctx *sql.Context, _ [7]*pgtypes.DoltgresType, val1, val2, val3, val4, val5, val6 any) (any, error) {
loc, err := GetServerLocation(ctx)
if err != nil {
return time.Time{}, err
}
return getTimestampInServerLocation(val1.(int32), val2.(int32), val3.(int32), val4.(int32), val5.(int32), val6.(float64), loc)
return getTimestampInServerLocation(val1.(int32), val2.(int32), val3.(int32), val4.(int32), val5.(int32), val6.(float64), time.UTC)
},
}

Expand Down
6 changes: 5 additions & 1 deletion server/functions/to_date.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package functions

import (
"time"

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/server/functions/framework"
Expand All @@ -37,6 +39,8 @@ var to_date_text_text = framework.Function2{
format := val2.(string)

// Parse the date using PostgreSQL format patterns
return getDateTimeFromFormat(ctx, input, format)
t, err := getDateTimeFromFormat(ctx, input, format)
// We return a version of the time but with the timezone completely stripped since they do not use timezones
return time.Date(t.Year(), t.Month(), t.Day(), 0, 0, 0, 0, time.UTC), err
},
}
27 changes: 23 additions & 4 deletions testing/go/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ import (

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/postgres/parser/duration"
"github.com/dolthub/doltgresql/postgres/parser/sem/tree"
"github.com/dolthub/doltgresql/postgres/parser/timeofday"
"github.com/dolthub/doltgresql/postgres/parser/uuid"
dserver "github.com/dolthub/doltgresql/server"
Expand Down Expand Up @@ -589,6 +590,20 @@ func NormalizeExpectedRow(fds []pgconn.FieldDescription, rows []sql.Row) []sql.R
panic(err)
}
newRow[i] = ret
} else if dt.ID == types.Date.ID {
newRow[i] = row[i]
if row[i] != nil {
if t, _, err := tree.ParseDTimestampTZ(nil, row[i].(string), tree.TimeFamilyPrecisionToRoundDuration(6), time.UTC); err == nil {
newRow[i] = functions.FormatDateTimeWithBC(t.Time.UTC(), "2006-01-02", dt.ID == types.TimestampTZ.ID)
}
}
} else if dt.ID == types.Timestamp.ID || dt.ID == types.TimestampTZ.ID {
newRow[i] = row[i]
if row[i] != nil {
if t, _, err := tree.ParseDTimestampTZ(nil, row[i].(string), tree.TimeFamilyPrecisionToRoundDuration(6), time.UTC); err == nil {
newRow[i] = functions.FormatDateTimeWithBC(t.Time.UTC(), "2006-01-02 15:04:05.999999", dt.ID == types.TimestampTZ.ID)
}
}
} else {
newRow[i] = NormalizeIntsAndFloats(row[i])
}
Expand Down Expand Up @@ -663,7 +678,7 @@ func NormalizeValToString(dt *types.DoltgresType, v any) any {
panic(err)
}
return val
case types.Interval.ID, types.Uuid.ID, types.Date.ID, types.Time.ID, types.Timestamp.ID:
case types.Interval.ID, types.Time.ID, types.Uuid.ID:
// These values need to be normalized into the appropriate types
// before being converted to string type using the Doltgres
// IoOutput method.
Expand All @@ -675,12 +690,16 @@ func NormalizeValToString(dt *types.DoltgresType, v any) any {
panic(err)
}
return tVal
case types.TimestampTZ.ID:
case types.Date.ID:
if v == nil {
return nil
}
return functions.FormatDateTimeWithBC(v.(time.Time), "2006-01-02", false)
case types.Timestamp.ID, types.TimestampTZ.ID:
if v == nil {
return nil
}
// timestamptz returns a value in server timezone
return functions.FormatDateTimeWithBC(v.(time.Time), "2006-01-02 15:04:05.999999", true)
return functions.FormatDateTimeWithBC(v.(time.Time).UTC(), "2006-01-02 15:04:05.999999", dt.ID == types.TimestampTZ.ID)
}

switch val := v.(type) {
Expand Down
2 changes: 1 addition & 1 deletion testing/postgres-client-tests/postgres-client-tests.bats
Original file line number Diff line number Diff line change
Expand Up @@ -103,5 +103,5 @@ teardown() {

@test "rust sqlx" {
cd $BATS_TEST_DIRNAME/rust
cargo run -- $USER $PORT
RUSTFLAGS=-Awarnings cargo run -- $USER $PORT
}
2 changes: 1 addition & 1 deletion testing/postgres-client-tests/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ edition = "2024"

[dependencies]
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "tls-native-tls"] }
sqlx = { version = "0.8", features = ["runtime-tokio", "postgres", "tls-native-tls", "uuid", "chrono"] }
61 changes: 55 additions & 6 deletions testing/postgres-client-tests/rust/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
use sqlx::postgres::PgPoolOptions;
use sqlx::types::Uuid;
use sqlx::types::chrono::Utc;
use sqlx::types::chrono::DateTime;
use sqlx::types::chrono::NaiveDate;

// Row type for the `events` table created later in this test
// (CREATE TABLE events (id UUID PRIMARY KEY, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), event_date DATE)).
// `sqlx::FromRow` derives the column-name -> field mapping used by `query_as`.
#[derive(sqlx::FromRow)]
struct Event {
    // events.id — UUID primary key; field is only read via decoding,
    // which is why the test build passes RUSTFLAGS=-Awarnings.
    id: sqlx::types::Uuid,
    // events.created_at — TIMESTAMPTZ NOT NULL DEFAULT NOW(); decoded as UTC.
    created_at: DateTime<Utc>,
    // events.event_date — DATE column; nullable in the schema, hence Option.
    event_date: Option<NaiveDate>,
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
Expand All @@ -15,12 +26,50 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.max_connections(5)
.connect(&database_url)
.await?;
let exists: bool = sqlx::query_scalar(
"SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = $1);"
)
.bind("test_table")
.fetch_one(&pool)
.await?;

let exists: bool = sqlx::query_scalar("SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = $1);")
.bind("test_table")
.fetch_one(&pool)
.await?;
println!("exists={exists}");

sqlx::query("DROP TABLE IF EXISTS users, events;")
.execute(&pool)
.await?;

sqlx::query("CREATE TABLE users (id uuid default gen_random_uuid(), name text, email text);")
.execute(&pool)
.await?;

sqlx::query("INSERT INTO users (name, email) VALUES ($1, $2)")
.bind("Alice")
.bind("alice@example.com")
.execute(&pool)
.await?;

let some_uuid: Uuid = sqlx::query_scalar("SELECT id FROM users WHERE email = $1 LIMIT 1")
.bind("alice@example.com")
.fetch_one(&pool)
.await?;

sqlx::query("UPDATE users SET name = $1 WHERE id = $2")
.bind("Bob")
.bind(some_uuid)
.execute(&pool)
.await?;

sqlx::query("CREATE TABLE events (id UUID PRIMARY KEY, created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), event_date DATE);")
.execute(&pool)
.await?;

let some_id: Uuid = sqlx::query_scalar("INSERT INTO events (id, event_date) VALUES (gen_random_uuid(), '2026-02-12') RETURNING id;")
.fetch_one(&pool)
.await?;

let __event = sqlx::query_as::<_, Event>("SELECT * FROM events WHERE id = $1")
.bind(some_id)
.fetch_one(&pool)
.await?;

Ok(())
}