diff --git a/core/init.go b/core/init.go index 87d454ead6..527f3f1551 100644 --- a/core/init.go +++ b/core/init.go @@ -20,6 +20,7 @@ import ( "github.com/dolthub/doltgresql/core/conflicts" "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/core/typecollection" "github.com/dolthub/doltgresql/server/plpgsql" gmstypes "github.com/dolthub/go-mysql-server/sql/types" @@ -35,4 +36,6 @@ func Init() { conflicts.ClearContextValues = ClearContextValues plpgsql.GetTypesCollectionFromContext = GetTypesCollectionFromContext id.RegisterListener(sequenceIDListener{}, id.Section_Table) + typecollection.GetSqlTableFromContext = GetSqlTableFromContext + typecollection.GetSchemaName = GetSchemaName } diff --git a/core/typecollection/typecollection.go b/core/typecollection/typecollection.go index 6d9dedd023..0687234201 100644 --- a/core/typecollection/typecollection.go +++ b/core/typecollection/typecollection.go @@ -25,6 +25,7 @@ import ( "github.com/dolthub/dolt/go/store/hash" "github.com/dolthub/dolt/go/store/prolly" "github.com/dolthub/dolt/go/store/prolly/tree" + "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/core/rootobject/objinterface" @@ -101,8 +102,10 @@ func (pgs *TypeCollection) DropType(ctx context.Context, names ...id.Type) (err } // GetAllTypes returns a map containing all types in the collection, grouped by the schema they're contained in. -// Each type array is also sorted by the type name. It includes built-in types. +// Each type array is also sorted by the type name. It includes built-in types, but does not include types referring to +// a table's row type. func (pgs *TypeCollection) GetAllTypes(ctx context.Context) (typeMap map[string][]*pgtypes.DoltgresType, schemaNames []string, totalCount int, err error) { + // TODO: this should probably get tables as well since tables create composite types matching their rows schemaNamesMap := make(map[string]struct{}) typeMap = make(map[string][]*pgtypes.DoltgresType) err = pgs.IterateTypes(ctx, func(t *pgtypes.DoltgresType) (stop bool, err error) { @@ -158,9 +161,21 @@ func (pgs *TypeCollection) GetType(ctx context.Context, name id.Type) (*pgtypes. } // The initial load is from the internal map h, err := pgs.underlyingMap.Get(ctx, string(name)) - if err != nil || h.IsEmpty() { + if err != nil { return nil, err } + if h.IsEmpty() { + // If it's not a built-in type or created type, then check if it's a composite table row type + sqlCtx, ok := ctx.(*sql.Context) + if !ok { + return nil, nil + } + tbl, schema, err := pgs.getTable(sqlCtx, name.SchemaName(), name.TypeName()) + if err != nil || tbl == nil { + return nil, err + } + return pgs.tableToType(sqlCtx, tbl, schema) + } data, err := pgs.ns.ReadBytes(ctx, h) if err != nil { return nil, err @@ -180,7 +195,7 @@ func (pgs *TypeCollection) HasType(ctx context.Context, name id.Type) bool { if _, ok := pgtypes.IDToBuiltInDoltgresType[name]; ok { return true } - + // Now we'll check our created types if _, ok := pgs.accessedMap[name]; ok { return true } @@ -188,11 +203,18 @@ func (pgs *TypeCollection) HasType(ctx context.Context, name id.Type) bool { if err == nil && ok { return true } - return false + // If it's not a built-in type or created type, then check if it's a composite table row type + sqlCtx, ok := ctx.(*sql.Context) + if !ok { + return false + } + tbl, _, err := pgs.getTable(sqlCtx, name.SchemaName(), name.TypeName()) + return err == nil && tbl != nil } // resolveName returns the fully resolved name of the given type. Returns an error if the name is ambiguous. func (pgs *TypeCollection) resolveName(ctx context.Context, schemaName string, typeName string) (id.Type, error) { + // TODO: this should probably check table names as well since tables create composite types matching their rows // First check for an exact match in the built-in types inputID := id.NewType(schemaName, typeName) if _, ok := pgtypes.IDToBuiltInDoltgresType[inputID]; ok { @@ -251,6 +273,7 @@ func (pgs *TypeCollection) resolveName(ctx context.Context, schemaName string, t // IterateTypes iterates over all types in the collection. func (pgs *TypeCollection) IterateTypes(ctx context.Context, f func(typ *pgtypes.DoltgresType) (stop bool, err error)) error { + // TODO: this should probably iterate tables as well since tables create composite types matching their rows // We can iterate the built-in types first for _, t := range pgtypes.GetAllBuitInTypes() { stop, err := f(t) @@ -368,3 +391,47 @@ func (pgs *TypeCollection) writeCache(ctx context.Context) (err error) { clear(pgs.accessedMap) return nil } + +// getTable returns the SQL table that matches the given schema and table name. Returns a nil table if one is not found. +// This is intended for use with tableToType. +func (*TypeCollection) getTable(ctx *sql.Context, schema string, tblName string) (tbl sql.Table, actualSchema string, err error) { + actualSchema, err = GetSchemaName(ctx, nil, schema) + tbl, err = GetSqlTableFromContext(ctx, "", doltdb.TableName{ + Name: tblName, + Schema: actualSchema, + }) + if err != nil || tbl == nil { + return nil, "", err + } + if schTbl, ok := tbl.(sql.DatabaseSchemaTable); ok { + actualSchema = schTbl.DatabaseSchema().SchemaName() + } + return tbl, actualSchema, nil +} + +// tableToType handles type creation related to a table's composite row type. +// https://www.postgresql.org/docs/15/sql-createtable.html +func (*TypeCollection) tableToType(ctx *sql.Context, tbl sql.Table, schema string) (*pgtypes.DoltgresType, error) { + tblName := tbl.Name() + tblSch := tbl.Schema() + typeID := id.NewType(schema, tblName) + relID := id.NewTable(schema, tblName).AsId() + arrayID := id.NewType(schema, "_"+tblName) + attrs := make([]pgtypes.CompositeAttribute, len(tblSch)) + for i, col := range tblSch { + collation := "" // TODO: what should we use for the collation? + colType, ok := col.Type.(*pgtypes.DoltgresType) + if !ok { + // TODO: perhaps we should use a better error message stating that it uses a non-Doltgres type? + return nil, pgtypes.ErrTypeDoesNotExist.New(tblName) + } + attrs[i] = pgtypes.NewCompositeAttribute(ctx, relID, col.Name, colType.ID, int16(i+1), collation) + } + return pgtypes.NewCompositeType(ctx, relID, arrayID, typeID, attrs), nil +} + +// GetSqlTableFromContext is a forward declaration to get around import cycles +var GetSqlTableFromContext func(ctx *sql.Context, databaseName string, tableName doltdb.TableName) (sql.Table, error) + +// GetSchemaName is a forward declaration to get around import cycles +var GetSchemaName func(ctx *sql.Context, db sql.Database, schemaName string) (string, error) diff --git a/server/expression/explicit_cast.go b/server/expression/explicit_cast.go index 892ea74faa..9b4ac18502 100644 --- a/server/expression/explicit_cast.go +++ b/server/expression/explicit_cast.go @@ -23,6 +23,7 @@ import ( "github.com/dolthub/go-mysql-server/sql/expression" vitess "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/server/functions/framework" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -98,6 +99,49 @@ func (c *ExplicitCast) Eval(ctx *sql.Context, row sql.Row) (any, error) { if castFunction == nil { if fromType.ID == pgtypes.Unknown.ID { castFunction = framework.UnknownLiteralCast + } else if fromType.IsRecordType() && c.castToType.IsCompositeType() { // TODO: should this only be in explicit, or assignment and implicit too? + // Casting to a record type will always work for any composite type. + // TODO: is the above statement true for all cases? + // When casting to a composite type, then we must match the arity and have valid casts for every position. + if c.castToType.IsRecordType() { + castFunction = framework.IdentityCast + } else { + castFunction = func(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { + vals, ok := val.([]pgtypes.RecordValue) + if !ok { + // TODO: better error message + return nil, errors.New("casting input error from record type") + } + if len(targetType.CompositeAttrs) != len(vals) { + return nil, errors.Newf("cannot cast type %s to %s", "", targetType.Name()) + } + typeCollection, err := core.GetTypesCollectionFromContext(ctx) + if err != nil { + return nil, err + } + outputVals := make([]pgtypes.RecordValue, len(vals)) + for i := range vals { + valType, ok := vals[i].Type.(*pgtypes.DoltgresType) + if !ok { + // TODO: if this is a GMS type, then we should cast to a Doltgres type here + return nil, errors.New("cannot cast record containing GMS type") + } + outputVals[i].Type, err = typeCollection.GetType(ctx, targetType.CompositeAttrs[i].TypeID) + if err != nil { + return nil, err + } + innerExplicit := ExplicitCast{ + sqlChild: NewUnsafeLiteral(vals[i].Value, valType), + castToType: outputVals[i].Type.(*pgtypes.DoltgresType), + } + outputVals[i].Value, err = innerExplicit.Eval(ctx, nil) + if err != nil { + return nil, err + } + } + return outputVals, nil + } + } } else { return nil, errors.Errorf("EXPLICIT CAST: cast from `%s` to `%s` does not exist: %s", fromType.String(), c.castToType.String(), c.sqlChild.String()) diff --git a/server/functions/framework/cast.go b/server/functions/framework/cast.go index 40a6ca3bae..8292c59a4f 100644 --- a/server/functions/framework/cast.go +++ b/server/functions/framework/cast.go @@ -134,7 +134,7 @@ func GetExplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp // parameters). If one of the types are a string type, then we do not use the identity, and use the I/O conversions // below. if fromType.ID == toType.ID && toType.TypCategory != pgtypes.TypeCategory_StringTypes && fromType.TypCategory != pgtypes.TypeCategory_StringTypes { - return identityCast + return IdentityCast } // All types have a built-in explicit cast from string types: https://www.postgresql.org/docs/15/sql-createcast.html if fromType.TypCategory == pgtypes.TypeCategory_StringTypes { @@ -175,7 +175,7 @@ func GetAssignmentCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresT // We check for the identity after checking the maps, as the identity may be overridden (such as for types that have // parameters). If the "to" type is a string type, then we do not use the identity, and use the I/O conversion below. if fromType.ID == toType.ID && fromType.TypCategory != pgtypes.TypeCategory_StringTypes { - return identityCast + return IdentityCast } // All types have a built-in assignment cast to string types: https://www.postgresql.org/docs/15/sql-createcast.html if toType.TypCategory == pgtypes.TypeCategory_StringTypes { @@ -202,7 +202,7 @@ func GetImplicitCast(fromType *pgtypes.DoltgresType, toType *pgtypes.DoltgresTyp // We check for the identity after checking the maps, as the identity may be overridden (such as for types that have // parameters). if fromType.ID == toType.ID { - return identityCast + return IdentityCast } return nil } @@ -282,8 +282,8 @@ func getCast(mutex *sync.RWMutex, return nil } -// identityCast returns the input value. -func identityCast(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { +// IdentityCast returns the input value. +func IdentityCast(ctx *sql.Context, val any, targetType *pgtypes.DoltgresType) (any, error) { return val, nil } diff --git a/server/functions/framework/compiled_function.go b/server/functions/framework/compiled_function.go index 48fac3c7fd..19e84f8c5c 100644 --- a/server/functions/framework/compiled_function.go +++ b/server/functions/framework/compiled_function.go @@ -489,7 +489,7 @@ func (c *CompiledFunction) resolveOperator(argTypes []*pgtypes.DoltgresType, ove rightUnknownType := argTypes[1].ID == pgtypes.Unknown.ID if (leftUnknownType && !rightUnknownType) || (!leftUnknownType && rightUnknownType) { var typ *pgtypes.DoltgresType - casts := []pgtypes.TypeCastFunction{identityCast, identityCast} + casts := []pgtypes.TypeCastFunction{IdentityCast, IdentityCast} if leftUnknownType { casts[0] = UnknownLiteralCast typ = argTypes[1] @@ -577,7 +577,7 @@ func (c *CompiledFunction) typeCompatibleOverloads(fnOverloads []Overload, argTy for i := range argTypes { paramType := overload.argTypes[i] if paramType.IsValidForPolymorphicType(argTypes[i]) { - overloadCasts[i] = identityCast + overloadCasts[i] = IdentityCast polymorphicParameters = append(polymorphicParameters, paramType) polymorphicTargets = append(polymorphicTargets, argTypes[i]) } else { diff --git a/server/functions/record.go b/server/functions/record.go index 377116f2e1..890148760c 100644 --- a/server/functions/record.go +++ b/server/functions/record.go @@ -81,6 +81,8 @@ var record_send = framework.Function1{ if !ok { return nil, fmt.Errorf("expected []RecordValue, but got %T", val) } + // TODO: converting from a string back to the record doesn't work as we lose type information, so we need to + // figure out how to retain this information output, err := pgtypes.RecordToString(ctx, values) if err != nil { return nil, err diff --git a/server/node/create_type.go b/server/node/create_type.go index 66a92554b3..338d175e3c 100644 --- a/server/node/create_type.go +++ b/server/node/create_type.go @@ -128,7 +128,10 @@ func (c *CreateType) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, error) { newType = types.NewEnumType(ctx, arrayID, typeID, enumLabelMap) // TODO: store labels somewhere case types.TypeType_Composite: - relID := id.Null // TODO: create relation with c.AsTypes + // TODO: non-composite types have a zero oid for their relID, which for us would be a null ID. + // We need to find a way to distinguish a null ID from a composite type that does not reference a table + // (which is what relID points to if it represents a table row's composite type) + relID := id.Null attrs := make([]types.CompositeAttribute, len(c.AsTypes)) for i, a := range c.AsTypes { attrs[i] = types.NewCompositeAttribute(ctx, relID, a.AttrName, a.Typ.ID, int16(i+1), a.Collation) diff --git a/server/plpgsql/interpreter_stack.go b/server/plpgsql/interpreter_stack.go index 1228ff4d9b..87ee8bd347 100644 --- a/server/plpgsql/interpreter_stack.go +++ b/server/plpgsql/interpreter_stack.go @@ -28,7 +28,7 @@ import ( // interacted with. InterpreterVariableReference are, instead, the avenue of interaction as a variable may be an // aggregate type (such as a record). type interpreterVariable struct { - Record sql.Schema + Record sql.Schema // TODO: all records carry their type information alongside the value, so this is redundant Type *pgtypes.DoltgresType Value any } @@ -119,6 +119,18 @@ func (is *InterpreterStack) GetVariable(name string) InterpreterVariableReferenc Type: iv.Record[fieldIdx].Type.(*pgtypes.DoltgresType), Value: &(iv.Value.(sql.Row)[fieldIdx]), } + } else if iv.Type.IsCompositeType() { + for fieldIdx := range iv.Type.CompositeAttrs { + if iv.Type.CompositeAttrs[fieldIdx].Name == fieldName { + vals := iv.Value.([]pgtypes.RecordValue) + return InterpreterVariableReference{ + Type: vals[fieldIdx].Type.(*pgtypes.DoltgresType), + Value: &(vals[fieldIdx].Value), + } + } + } + // The field could not be found + return InterpreterVariableReference{} } else { // Can't access fields on an empty record return InterpreterVariableReference{} diff --git a/server/plpgsql/statements.go b/server/plpgsql/statements.go index 2a350eb6d2..c49a004b02 100644 --- a/server/plpgsql/statements.go +++ b/server/plpgsql/statements.go @@ -357,17 +357,12 @@ func substituteVariableReferences(expression string, stack *InterpreterStack) (n token := scanResult.Tokens[i] substring := expression[token.Start:token.End] // varMap lowercases everything, so we'll lowercase our substring to enable case-insensitivity - if fieldNames, ok := varMap[strings.ToLower(substring)]; ok { - // If there's a '.', then we'll check if this is accessing a record's field (`NEW.val1` for example) - if len(fieldNames) > 0 && i+2 < len(scanResult.Tokens) && scanResult.Tokens[i+1].Token == '.' { - possibleFieldSubstring := expression[scanResult.Tokens[i+2].Start:scanResult.Tokens[i+2].End] - for _, fieldName := range fieldNames { - if fieldName == strings.ToLower(possibleFieldSubstring) { - substring += "." + fieldName - i += 2 - break - } - } + if _, ok := varMap[strings.ToLower(substring)]; ok { + // If there's a '.', then we'll assume this is accessing a record's field (`NEW.val1` for example) + for i+2 < len(scanResult.Tokens) && scanResult.Tokens[i+1].Token == '.' { + nextFieldSubstring := expression[scanResult.Tokens[i+2].Start:scanResult.Tokens[i+2].End] + substring += "." + nextFieldSubstring + i += 2 } // Variables cannot have a '(' after their name as that would classify them as functions, so we have to // explicitly check for that. This is because variables and functions can share names, for example: diff --git a/server/types/composite.go b/server/types/composite.go index 20fbc62b96..ff67b8db0f 100644 --- a/server/types/composite.go +++ b/server/types/composite.go @@ -62,20 +62,20 @@ func NewCompositeType(ctx *sql.Context, relID id.Id, arrayID, typeID id.Type, at // CompositeAttribute represents a composite type attribute. // This is a partial pg_attribute row entry. type CompositeAttribute struct { - relID id.Id // ID of the relation it belongs to - name string - typeID id.Type // ID of DoltgresType - num int16 // number of the column in relation - collation string + RelID id.Id // ID of the relation it belongs to + Name string + TypeID id.Type // ID of DoltgresType + Num int16 // 1-based number of the column in relation + Collation string } -// NewCompositeAttribute creates new instance of composite type attribute. +// NewCompositeAttribute creates new instance of composite type attribute. `num` is 1-based rather than 0-based. func NewCompositeAttribute(ctx *sql.Context, relID id.Id, name string, typeID id.Type, num int16, collation string) CompositeAttribute { return CompositeAttribute{ - relID: relID, - name: name, - typeID: typeID, - num: num, - collation: collation, + RelID: relID, + Name: name, + TypeID: typeID, + Num: num, + Collation: collation, } } diff --git a/server/types/serialization.go b/server/types/serialization.go index 44d61df41d..908491cfce 100644 --- a/server/types/serialization.go +++ b/server/types/serialization.go @@ -122,11 +122,11 @@ func DeserializeType(serializedType []byte) (sql.ExtendedType, error) { num := reader.Int16() collation := reader.String() typ.CompositeAttrs[k] = CompositeAttribute{ - relID: relID, - name: name, - typeID: id.Type(typeID), - num: num, - collation: collation, + RelID: relID, + Name: name, + TypeID: id.Type(typeID), + Num: num, + Collation: collation, } } } @@ -196,11 +196,11 @@ func (t *DoltgresType) Serialize() []byte { writer.VariableUint(uint64(len(t.CompositeAttrs))) if len(t.CompositeAttrs) > 0 { for _, l := range t.CompositeAttrs { - writer.Id(l.relID) - writer.String(l.name) - writer.Id(l.typeID.AsId()) - writer.Int16(l.num) - writer.String(l.collation) + writer.Id(l.RelID) + writer.String(l.Name) + writer.Id(l.TypeID.AsId()) + writer.Int16(l.Num) + writer.String(l.Collation) } } writer.String(t.InternalName) diff --git a/server/types/type.go b/server/types/type.go index 4b029cb282..be6776a9ac 100644 --- a/server/types/type.go +++ b/server/types/type.go @@ -520,7 +520,12 @@ func (t *DoltgresType) IsArrayType() bool { // IsCompositeType returns true if the type is a composite type, such as an anonymous record, or a // user-created composite type. func (t *DoltgresType) IsCompositeType() bool { - return t.ID.TypeName() == "record" || t.TypType == TypeType_Composite + return t.TypType == TypeType_Composite || t.IsRecordType() +} + +// IsRecordType returns true if the type is an anonymous record type. +func (t *DoltgresType) IsRecordType() bool { + return t.TypType == TypeType_Pseudo && t.ID.TypeName() == "record" } // IsEmptyType returns true if the type is not valid. @@ -931,6 +936,8 @@ func (t *DoltgresType) DeserializeValue(ctx context.Context, val []byte) (any, e } } else if t.TypType == TypeType_Enum { return globalFunctionRegistry.GetFunction(t.ReceiveFunc).CallVariadic(nil, val, t.ID.AsId()) + } else if t.IsCompositeType() { + return globalFunctionRegistry.GetFunction(t.ReceiveFunc).CallVariadic(nil, val, t.ID.AsId(), t.attTypMod) } else { return globalFunctionRegistry.GetFunction(t.ReceiveFunc).CallVariadic(nil, val) } diff --git a/testing/go/create_function_plpgsql_test.go b/testing/go/create_function_plpgsql_test.go index 15f9c06bd9..a2623a5faa 100644 --- a/testing/go/create_function_plpgsql_test.go +++ b/testing/go/create_function_plpgsql_test.go @@ -931,5 +931,72 @@ $$;`, }, }, }, + { + Name: "Table as type for functions", + SetUpScript: []string{ + // TODO: test case sensitivity of parameter names + `CREATE TABLE test (id INT4 PRIMARY KEY, name TEXT NOT NULL, qty INT4 NOT NULL, price REAL NOT NULL);`, + `INSERT INTO test VALUES (1, 'apple', 3, 2.5), (2, 'banana', 5, 1.2);`, + `CREATE FUNCTION total(t test) RETURNS REAL AS $$ BEGIN RETURN t.qty * t.price; END; $$ LANGUAGE plpgsql;`, + `CREATE FUNCTION priceHike(t test, pricehike REAL) RETURNS test AS $$ BEGIN RETURN (t.id, t.name, t.qty, t.price + pricehike)::test; END; $$ LANGUAGE plpgsql;`, + `CREATE FUNCTION singleReturn() RETURNS test AS $$ DECLARE result test; BEGIN SELECT * INTO result FROM test WHERE id = 1; RETURN result; END; $$ LANGUAGE plpgsql;`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT total(t) FROM test AS t;`, + Skip: true, // TODO: cannot take the table itself as a parameter yet, perhaps as an aggregation func? + Expected: []sql.Row{ + {7.5}, + {6.0}, + }, + }, + { + Query: `SELECT priceHike(t, 10.0) FROM test AS t;`, + Skip: true, // TODO: cannot take the table itself as a parameter yet, perhaps as an aggregation func? + Expected: []sql.Row{ + {"(1,apple,3,12.5)"}, + {"(2,banana,5,11.2)"}, + }, + }, + { + Query: `SELECT priceHike(ROW(3, 'orange', 1, 1.8)::test, 100.0);`, + Expected: []sql.Row{ + {"(3,orange,1,101.8)"}, + }, + }, + { + Query: `SELECT singleReturn();`, + Skip: true, // TODO: better PL/pgSQL internal support for non-trigger composite types + Expected: []sql.Row{ + {"(1,apple,3,2.5)"}, + }, + }, + }, + }, + { + Name: "Table as type for columns", + Focus: true, + SetUpScript: []string{ + `CREATE TABLE t1 (v1 INT4 PRIMARY KEY, v2 TEXT NOT NULL);`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `CREATE TABLE t2 (v1 INT4 PRIMARY KEY, v2 t1 NOT NULL);`, + Expected: []sql.Row{}, + }, + { + Query: `INSERT INTO t2 VALUES (1, ROW(0, 'hello')::t1), (2, ROW(10, 'world')::t1);`, + Expected: []sql.Row{}, + }, + { + Query: `SELECT * FROM t2 ORDER BY v1;`, + Skip: true, // TODO: need to implement record_recv + Expected: []sql.Row{ + {1, "(0,hello)"}, + {2, "(10,world)"}, + }, + }, + }, + }, }) } diff --git a/testing/go/import_dumps_test.go b/testing/go/import_dumps_test.go index d4cc96914d..6636b0611f 100644 --- a/testing/go/import_dumps_test.go +++ b/testing/go/import_dumps_test.go @@ -33,7 +33,7 @@ import ( // TestImportingDumps are regression tests against dumps taken from various sources. func TestImportingDumps(t *testing.T) { - t.Skip("The majority fail for now") + //t.Skip("The majority fail for now") RunImportTests(t, []ImportTest{ { Name: "Scrubbed-1", @@ -54,6 +54,7 @@ func TestImportingDumps(t *testing.T) { }, { Name: "AlexTransit/venderctl", + Focus: true, SQLFilename: "AlexTransit_venderctl.sql", }, { diff --git a/testing/go/trigger_test.go b/testing/go/trigger_test.go index 956bbb5197..0fba12954b 100755 --- a/testing/go/trigger_test.go +++ b/testing/go/trigger_test.go @@ -639,5 +639,36 @@ $$ LANGUAGE plpgsql;`, }, }, }, + { + Name: "Table as type", + Focus: true, + SetUpScript: []string{ + `CREATE TABLE test (id INT4 PRIMARY KEY, name TEXT NOT NULL, qty INT4 NOT NULL, price REAL NOT NULL);`, + `CREATE FUNCTION trigger_func() RETURNS trigger AS $$ +DECLARE + rec test; +BEGIN + rec := NEW; + IF rec.qty < 0 THEN + rec.qty := -rec.qty; + END IF; + NEW := rec; + RETURN NEW; +END; $$ LANGUAGE plpgsql;`, + `CREATE TRIGGER test_trigger BEFORE INSERT ON test FOR EACH ROW EXECUTE FUNCTION trigger_func();`, + }, + Assertions: []ScriptTestAssertion{ + { + Query: `INSERT INTO test VALUES (1, 'apple', 3, 2.5), (2, 'banana', -5, -1.2);`, + Expected: []sql.Row{}, + }, { + Query: `SELECT * FROM test;`, + Expected: []sql.Row{ + {1, "apple", 3, 2.5}, + {2, "banana", 5, -1.2}, + }, + }, + }, + }, }) }