Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions server/analyzer/resolve_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ func ResolveTypeForNodes(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node,
same = transform.NewTree
col.Type = dt
}
resolvedDefault, err := resolveDefaultColumnType(ctx, col.Default)
if err != nil {
return nil, transform.NewTree, err
}
resolvedGenerated, err := resolveDefaultColumnType(ctx, col.Generated)
if err != nil {
return nil, transform.NewTree, err
}
resolvedOnUpdate, err := resolveDefaultColumnType(ctx, col.OnUpdate)
if err != nil {
return nil, transform.NewTree, err
}
if resolvedDefault || resolvedGenerated || resolvedOnUpdate {
same = transform.NewTree
}
}
return node, same, nil
case *plan.ModifyColumn:
Expand Down Expand Up @@ -187,3 +202,19 @@ func resolveType(ctx *sql.Context, typ *pgtypes.DoltgresType) (*pgtypes.Doltgres
}
return resolvedTyp, nil
}

// resolveDefaultColumnType resolves the OutType of a *sql.ColumnDefaultValue if it's not nil (and not already resolved).
func resolveDefaultColumnType(ctx *sql.Context, defaultVal *sql.ColumnDefaultValue) (bool, error) {
if defaultVal == nil {
return false, nil
}
if rt, ok := defaultVal.OutType.(*pgtypes.DoltgresType); ok && !rt.IsResolvedType() {
dt, err := resolveType(ctx, rt)
if err != nil {
return false, err
}
defaultVal.OutType = dt
return true, nil
}
return false, nil
}
2 changes: 1 addition & 1 deletion server/analyzer/type_sanitizer.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func TypeSanitizer(ctx *sql.Context, a *analyzer.Analyzer, node sql.Node, scope
}
}
case *plan.ExistsSubquery:
return pgexprs.NewGMSCast(expr), transform.NewTree, nil
return pgexprs.NewExplicitCast(pgexprs.NewGMSCast(expr), pgtypes.Bool), transform.NewTree, nil
case *sql.ColumnDefaultValue:
// Due to how interfaces work, we sometimes pass (*ColumnDefaultValue)(nil), so we have to check for it
if expr != nil && expr.Expr != nil {
Expand Down
27 changes: 24 additions & 3 deletions server/doltgres_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import (
"github.com/dolthub/go-mysql-server/server"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/types"
"github.com/dolthub/vitess/go/mysql"
Expand All @@ -45,6 +46,7 @@ import (
"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/postgres/parser/uuid"
pgexprs "github.com/dolthub/doltgresql/server/expression"
pgtransform "github.com/dolthub/doltgresql/server/transform"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

Expand Down Expand Up @@ -158,14 +160,33 @@ func (h *DoltgresHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, q
return nil, nil, err
}

analyzed, err := h.e.PrepareParsedQuery(sqlCtx, query, query, parsed)
node, err := h.e.PrepareParsedQuery(sqlCtx, query, query, parsed)
if err != nil {
if printErrorStackTraces {
fmt.Printf("unable to prepare query: %+v\n", err)
}
logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error())
err := sql.CastSQLError(err)
return nil, nil, err
return nil, nil, sql.CastSQLError(err)
}
analyzed := node
// We do not analyze expressions with bind variables, since that step comes later and analysis will return invalid results
hasBindVars := false
pgtransform.InspectNodeExprs(node, func(expr sql.Expression) bool {
if _, ok := expr.(*expression.BindVar); ok {
hasBindVars = true
return true
}
return false
})
if !hasBindVars {
analyzed, err = h.e.Analyzer.Analyze(sqlCtx, node, nil, nil)
if err != nil {
if printErrorStackTraces {
fmt.Printf("unable to prepare query: %+v\n", err)
}
logrus.WithField("query", query).Errorf("unable to prepare query: %s", err.Error())
return nil, nil, sql.CastSQLError(err)
}
}

var fields []pgproto3.FieldDescription
Expand Down
3 changes: 3 additions & 0 deletions server/types/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ func init() {
// by DoltgreSQL.
func SerializeType(extendedType sql.ExtendedType) ([]byte, error) {
if doltgresType, ok := extendedType.(*DoltgresType); ok {
if doltgresType.IsUnresolved {
return nil, errors.Errorf(`attempted to serialize the unresolved type: %s`, doltgresType.Name())
}
return doltgresType.Serialize(), nil
}
return nil, errors.Errorf("unknown type to serialize")
Expand Down
7 changes: 7 additions & 0 deletions testing/PostgresDockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ RUN apt update -y && \
postgresql-server-dev-15 && \
update-ca-certificates -f

# install rust (cargo)
ENV RUSTUP_HOME=/root/.rustup \
CARGO_HOME=/root/.cargo
ENV PATH="${CARGO_HOME}/bin:${PATH}"
RUN curl https://sh.rustup.rs -sSf | sh -s -- -y --profile minimal && \
rustc --version && cargo --version

# install go
WORKDIR /root
ENV GO_VERSION=1.25.0
Expand Down
16 changes: 15 additions & 1 deletion testing/go/framework.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,12 @@ type ScriptTestAssertion struct {
// This is checked only if no Expected is defined
ExpectedTag string

// ExpectedColNames is used to check the column names returned from the server.
// ExpectedColNames are used to check the column names returned from the server.
ExpectedColNames []string

// ExpectedColTypes are used to check the column types returned from the server.
ExpectedColTypes []id.Type

// CopyFromSTDIN is used to test the COPY FROM STDIN command.
CopyFromStdInFile string
}
Expand Down Expand Up @@ -274,6 +277,17 @@ func runScript(t *testing.T, ctx context.Context, script ScriptTest, conn *Conne
}
}
}
if assertion.ExpectedColTypes != nil {
fields := rows.FieldDescriptions()
if assert.Len(t, fields, len(assertion.ExpectedColTypes),
"columns returned and types expected are not the same length") {
for i, colId := range assertion.ExpectedColTypes {
assert.Equal(t, id.Cache().ToOID(colId.AsId()), fields[i].DataTypeOID,
`"%s" expected type "%s" but received "%s"`, fields[i].Name,
colId.TypeName(), id.Type(id.Cache().ToInternal(fields[i].DataTypeOID)).TypeName())
}
}
}

// not an exact match but works well enough for our tests
orderBy := strings.Contains(strings.ToLower(assertion.Query), "order by")
Expand Down
41 changes: 41 additions & 0 deletions testing/go/issues_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ import (
"testing"

"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core/id"
pgtypes "github.com/dolthub/doltgresql/server/types"
)

func TestIssues(t *testing.T) {
Expand Down Expand Up @@ -366,5 +369,43 @@ limit 1`,
},
},
},
{
Name: "Issue #2299",
SetUpScript: []string{
"CREATE TYPE team_role AS ENUM ('admin', 'editor', 'member');",
},
Assertions: []ScriptTestAssertion{
{
Query: `CREATE TABLE users (id UUID PRIMARY KEY DEFAULT gen_random_uuid(), role team_role NOT NULL DEFAULT 'member');`,
Expected: []sql.Row{},
},
{
Query: `INSERT INTO users (role) VALUES (DEFAULT);`,
Expected: []sql.Row{},
},
{
Query: `SELECT role FROM users;`,
Expected: []sql.Row{{"member"}},
},
},
},
{
Name: "Issue #2307",
SetUpScript: []string{
"CREATE TABLE test (pk INT4);",
},
Assertions: []ScriptTestAssertion{
{
Query: `SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = 'test');`,
ExpectedColTypes: []id.Type{pgtypes.Bool.ID},
Expected: []sql.Row{{"t"}},
},
{
Query: `SELECT NOT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE tablename = 'test');`,
ExpectedColTypes: []id.Type{pgtypes.Bool.ID},
Expected: []sql.Row{{"f"}},
},
},
},
})
}
5 changes: 5 additions & 0 deletions testing/postgres-client-tests/postgres-client-tests.bats
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,8 @@ teardown() {
# this inserts and updates row and check its updated result
npx tsx src/index.ts
}

@test "rust sqlx" {
cd $BATS_TEST_DIRNAME/rust
cargo run -- $USER $PORT
}
Loading