Skip to content

Commit cf06db4

Browse files
committed
Add schema name to views for doltgres
1 parent ec0bc7c commit cf06db4

File tree

4 files changed

+96
-31
lines changed

4 files changed

+96
-31
lines changed

go/libraries/doltcore/doltdb/system_table.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,9 @@ const (
286286
// SchemasTablesSqlModeCol is the name of the column that stores the SQL_MODE string used when this fragment
287287
// was originally defined. Mode settings, such as ANSI_QUOTES, are needed to correctly parse the fragment.
288288
SchemasTablesSqlModeCol = "sql_mode"
289+
// SchemasTablesSchemaNameCol is the name of the column that stores the name of the schema that the fragment
290+
// is part of. Used by Doltgres only.
291+
SchemasTablesSchemaNameCol = "schema_name"
289292
)
290293

291294
const (

go/libraries/doltcore/schema/reserved_tags.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ const (
8080
DoltSchemasFragmentTag
8181
DoltSchemasExtraTag
8282
DoltSchemasSqlModeTag
83+
DoltSchemasSchemaNameTag
8384
)
8485

8586
// Tags for hidden columns in keyless rows

go/libraries/doltcore/sqle/database.go

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,6 +1758,7 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie
17581758
}
17591759
}
17601760

1761+
schemaName := db.schemaName
17611762
lwrViewName := strings.ToLower(viewName)
17621763
switch {
17631764
case strings.HasPrefix(lwrViewName, doltdb.DoltBlameViewPrefix):
@@ -1767,7 +1768,12 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie
17671768
if err != nil {
17681769
return sql.ViewDefinition{}, false, err
17691770
}
1770-
return sql.ViewDefinition{Name: viewName, TextDefinition: blameViewTextDef, CreateViewStatement: fmt.Sprintf("CREATE VIEW `%s` AS %s", viewName, blameViewTextDef)}, true, nil
1771+
return sql.ViewDefinition{
1772+
Name: viewName,
1773+
SchemaName: db.schemaName,
1774+
TextDefinition: blameViewTextDef,
1775+
CreateViewStatement: fmt.Sprintf("CREATE VIEW `%s` AS %s", viewName, blameViewTextDef)},
1776+
true, nil
17711777
}
17721778

17731779
schemasTableName := getDoltSchemasTableName()
@@ -1809,22 +1815,22 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie
18091815
}
18101816

18111817
if wrapper.backingTable == nil {
1812-
dbState.SessionCache().CacheViews(key, nil, db.schemaName)
1818+
dbState.SessionCache().CacheViews(key, nil, schemaName)
18131819
return sql.ViewDefinition{}, false, nil
18141820
}
18151821

1816-
views, viewDef, found, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, viewName)
1822+
views, viewDef, found, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, viewName, schemaName)
18171823
if err != nil {
18181824
return sql.ViewDefinition{}, false, err
18191825
}
18201826

18211827
// TODO: only cache views from a single schema here
1822-
dbState.SessionCache().CacheViews(key, views, db.schemaName)
1828+
dbState.SessionCache().CacheViews(key, views, schemaName)
18231829

18241830
return viewDef, found, nil
18251831
}
18261832

1827-
func getViewDefinitionFromSchemaFragmentsOfView(ctx *sql.Context, tbl *WritableDoltTable, viewName string) ([]sql.ViewDefinition, sql.ViewDefinition, bool, error) {
1833+
func getViewDefinitionFromSchemaFragmentsOfView(ctx *sql.Context, tbl *WritableDoltTable, viewName, schemaName string) ([]sql.ViewDefinition, sql.ViewDefinition, bool, error) {
18281834
fragments, err := getSchemaFragmentsOfType(ctx, tbl, viewFragment)
18291835
if err != nil {
18301836
return nil, sql.ViewDefinition{}, false, err
@@ -1843,14 +1849,15 @@ func getViewDefinitionFromSchemaFragmentsOfView(ctx *sql.Context, tbl *WritableD
18431849
}
18441850
} else {
18451851
views[i] = sql.ViewDefinition{
1846-
Name: fragments[i].name,
1852+
Name: fragments[i].name,
1853+
SchemaName: fragments[i].schemaName,
18471854
// TODO: need to define TextDefinition
18481855
CreateViewStatement: fragments[i].fragment,
18491856
SqlMode: fragment.sqlMode,
18501857
}
18511858
}
18521859

1853-
if strings.EqualFold(fragment.name, viewName) {
1860+
if strings.EqualFold(fragment.name, viewName) && strings.EqualFold(fragment.schemaName, schemaName) {
18541861
found = true
18551862
viewDef = views[i]
18561863
}
@@ -1878,7 +1885,7 @@ func (db Database) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) {
18781885
return nil, nil
18791886
}
18801887

1881-
views, _, _, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, "")
1888+
views, _, _, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, "", "")
18821889
if err != nil {
18831890
return nil, err
18841891
}
@@ -1974,7 +1981,7 @@ func (db Database) CreateTrigger(ctx *sql.Context, definition sql.TriggerDefinit
19741981
definition.Name,
19751982
definition.CreateStatement,
19761983
definition.CreatedAt,
1977-
fmt.Errorf("triggers `%s` already exists", definition.Name), //TODO: add a sql error and return that instead
1984+
fmt.Errorf("triggers `%s` already exists", definition.Name), // TODO: add a sql error and return that instead
19781985
)
19791986
}
19801987

@@ -2236,7 +2243,7 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
22362243
return err
22372244
}
22382245

2239-
_, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name)
2246+
_, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name, db.schemaName)
22402247
if err != nil {
22412248
return err
22422249
}
@@ -2263,14 +2270,34 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
22632270

22642271
sqlMode := sql.LoadSqlMode(ctx)
22652272

2266-
return inserter.Insert(ctx, sql.Row{fragType, name, definition, extraJSON, sqlMode.String()})
2273+
row := sql.Row{fragType, name, definition, extraJSON, sqlMode.String()}
2274+
2275+
// Include schema_name column for doltgres
2276+
if resolve.UseSearchPath && tbl.Schema().Contains(doltdb.SchemasTablesSchemaNameCol, tbl.Name()) {
2277+
if db.schemaName == "" {
2278+
root, err := db.GetRoot(ctx)
2279+
if err != nil {
2280+
return err
2281+
}
2282+
schemaName, err := resolve.FirstExistingSchemaOnSearchPath(ctx, root)
2283+
if err != nil {
2284+
return err
2285+
}
2286+
db.schemaName = schemaName
2287+
}
2288+
2289+
row = sql.Row{fragType, name, db.schemaName, definition, extraJSON, sqlMode.String()}
2290+
}
2291+
2292+
return inserter.Insert(ctx, row)
22672293
}
22682294

22692295
func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name string, missingErr error) error {
22702296
if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil {
22712297
return err
22722298
}
22732299

2300+
schemaName := db.schemaName
22742301
if resolve.UseSearchPath {
22752302
db.schemaName = "dolt"
22762303
}
@@ -2288,14 +2315,26 @@ func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name str
22882315
return missingErr
22892316
}
22902317

2318+
if resolve.UseSearchPath && schemaName == "" {
2319+
root, err := db.GetRoot(ctx)
2320+
if err != nil {
2321+
return err
2322+
}
2323+
schemaName, err = resolve.FirstExistingSchemaOnSearchPath(ctx, root)
2324+
if err != nil {
2325+
return err
2326+
}
2327+
}
2328+
22912329
tbl := swrapper.backingTable
2292-
row, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name)
2330+
row, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name, schemaName)
22932331
if err != nil {
22942332
return err
22952333
}
22962334
if !exists {
22972335
return missingErr
22982336
}
2337+
22992338
deleter := tbl.Deleter(ctx)
23002339
err = deleter.Delete(ctx, row)
23012340
if err != nil {

go/libraries/doltcore/sqle/schema_table.go

Lines changed: 41 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,15 @@ func (st *SchemaTable) String() string {
5555
return doltdb.GetSchemasTableName()
5656
}
5757

58+
// GetSchemasSchema returns the schema of the dolt_schemas system table. This is used
59+
// by Doltgres to update the dolt_schemas schema with an additional schema_name column.
60+
var GetSchemasSchema = SchemaTableSchema
61+
5862
func (st *SchemaTable) Schema() sql.Schema {
63+
currentSchema := toSqlSchemaTableSchema(GetSchemasSchema())
5964
if st.backingTable == nil {
6065
// No backing table; return a current schema.
61-
return SchemaTableSqlSchema().Schema
66+
return currentSchema.Schema
6267
}
6368

6469
if !st.backingTable.Schema().Contains(doltdb.SchemasTablesExtraCol, doltdb.GetSchemasTableName()) {
@@ -71,7 +76,7 @@ func (st *SchemaTable) Schema() sql.Schema {
7176
return SchemaTableV1SqlSchema()
7277
}
7378

74-
return SchemaTableSqlSchema().Schema
79+
return currentSchema.Schema
7580
}
7681

7782
func (st *SchemaTable) Collation() sql.CollationID {
@@ -127,8 +132,8 @@ var _ sql.IndexAddressableTable = (*SchemaTable)(nil)
127132
var _ sql.UpdatableTable = (*SchemaTable)(nil)
128133
var _ WritableDoltTableWrapper = (*SchemaTable)(nil)
129134

130-
func SchemaTableSqlSchema() sql.PrimaryKeySchema {
131-
sqlSchema, err := sqlutil.FromDoltSchema("", doltdb.GetSchemasTableName(), SchemaTableSchema())
135+
func toSqlSchemaTableSchema(sch schema.Schema) sql.PrimaryKeySchema {
136+
sqlSchema, err := sqlutil.FromDoltSchema("", doltdb.GetSchemasTableName(), sch)
132137
if err != nil {
133138
panic(err) // should never happen
134139
}
@@ -250,7 +255,7 @@ func getOrCreateDoltSchemasTable(ctx *sql.Context, db Database) (retTbl *Writabl
250255
}
251256

252257
// Create new empty table
253-
err = db.createDoltTable(ctx, tname, root, SchemaTableSchema())
258+
err = db.createDoltTable(ctx, tname, root, GetSchemasSchema())
254259
if err != nil {
255260
return nil, err
256261
}
@@ -367,8 +372,8 @@ func migrateOldSchemasTableToNew(ctx *sql.Context, db Database, schemasTable *Wr
367372
}
368373

369374
// fragFromSchemasTable returns the row with the given schema fragment if it exists.
370-
func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType string, name string) (r sql.Row, found bool, rerr error) {
371-
fragType, name = strings.ToLower(fragType), strings.ToLower(name)
375+
func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType, name, schemaName string) (r sql.Row, found bool, rerr error) {
376+
fragType, name, schemaName = strings.ToLower(fragType), strings.ToLower(name), strings.ToLower(schemaName)
372377

373378
// This performs a full table scan in the worst case, but it's only used when adding or dropping a trigger or view
374379
iter, err := SqlTableToRowIter(ctx, tbl.DoltTable, nil)
@@ -387,6 +392,7 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
387392
// need to get the column indexes from the current schema
388393
nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
389394
typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
395+
schemaNameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesSchemaNameCol)
390396

391397
for {
392398
sqlRow, err := iter.Next(ctx)
@@ -397,8 +403,13 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
397403
return nil, false, err
398404
}
399405

406+
sqlRowSchemaName := ""
407+
if schemaNameIdx >= 0 {
408+
sqlRowSchemaName = sqlRow[schemaNameIdx].(string)
409+
}
410+
400411
// These columns are case insensitive, make sure to do a case-insensitive comparison
401-
if strings.EqualFold(sqlRow[typeIdx].(string), fragType) && strings.EqualFold(sqlRow[nameIdx].(string), name) {
412+
if strings.EqualFold(sqlRow[typeIdx].(string), fragType) && strings.EqualFold(sqlRow[nameIdx].(string), name) && strings.EqualFold(sqlRowSchemaName, schemaName) {
402413
return sqlRow, true, nil
403414
}
404415
}
@@ -407,9 +418,10 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
407418
}
408419

409420
type schemaFragment struct {
410-
name string
411-
fragment string
412-
created time.Time
421+
name string
422+
schemaName string
423+
fragment string
424+
created time.Time
413425
// sqlMode indicates the SQL_MODE that was used when this schema fragment was initially parsed. SQL_MODE settings
414426
// such as ANSI_QUOTES control customized parsing behavior needed for some schema fragments.
415427
sqlMode string
@@ -424,6 +436,7 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
424436
// The dolt_schemas table has undergone various changes over time and multiple possible schemas for it exist, so we
425437
// need to get the column indexes from the current schema
426438
nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
439+
schemaNameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesSchemaNameCol)
427440
typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
428441
fragmentIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesFragmentCol)
429442
extraIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesExtraCol)
@@ -463,13 +476,21 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
463476
sqlModeString = defaultSqlMode
464477
}
465478

479+
schemaNameString := ""
480+
if schemaNameIdx >= 0 {
481+
if s, ok := sqlRow[schemaNameIdx].(string); ok {
482+
schemaNameString = s
483+
}
484+
}
485+
466486
// For older tables, use 1 as the trigger creation time
467487
if extraIdx < 0 || sqlRow[extraIdx] == nil {
468488
frags = append(frags, schemaFragment{
469-
name: sqlRow[nameIdx].(string),
470-
fragment: sqlRow[fragmentIdx].(string),
471-
created: time.Unix(1, 0).UTC(), // TablePlus editor thinks 0 is out of range
472-
sqlMode: sqlModeString,
489+
name: sqlRow[nameIdx].(string),
490+
schemaName: schemaNameString,
491+
fragment: sqlRow[fragmentIdx].(string),
492+
created: time.Unix(1, 0).UTC(), // TablePlus editor thinks 0 is out of range
493+
sqlMode: sqlModeString,
473494
})
474495
continue
475496
}
@@ -481,10 +502,11 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
481502
}
482503

483504
frags = append(frags, schemaFragment{
484-
name: sqlRow[nameIdx].(string),
485-
fragment: sqlRow[fragmentIdx].(string),
486-
created: time.Unix(createdTime, 0).UTC(),
487-
sqlMode: sqlModeString,
505+
name: sqlRow[nameIdx].(string),
506+
schemaName: schemaNameString,
507+
fragment: sqlRow[fragmentIdx].(string),
508+
created: time.Unix(createdTime, 0).UTC(),
509+
sqlMode: sqlModeString,
488510
})
489511
}
490512

0 commit comments

Comments
 (0)