Skip to content

Commit 9aa50b6

Browse files
committed
Add schema name to views for doltgres
1 parent b9f909d commit 9aa50b6

File tree

4 files changed

+111
-31
lines changed

4 files changed

+111
-31
lines changed

go/libraries/doltcore/doltdb/system_table.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ const (
256256
// SchemasTablesSqlModeCol is the name of the column that stores the SQL_MODE string used when this fragment
257257
// was originally defined. Mode settings, such as ANSI_QUOTES, are needed to correctly parse the fragment.
258258
SchemasTablesSqlModeCol = "sql_mode"
259+
// SchemasTablesSchemaNameCol is the name of the column that stores the name of the schema that the fragment
260+
// is part of. Used by Doltgres only.
261+
SchemasTablesSchemaNameCol = "schema_name"
259262
)
260263

261264
const (

go/libraries/doltcore/schema/reserved_tags.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ const (
8080
DoltSchemasFragmentTag
8181
DoltSchemasExtraTag
8282
DoltSchemasSqlModeTag
83+
DoltSchemasSchemaNameTag
8384
)
8485

8586
// Tags for hidden columns in keyless rows

go/libraries/doltcore/sqle/database.go

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,6 +1729,7 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie
17291729
}
17301730
}
17311731

1732+
schemaName := db.schemaName
17321733
lwrViewName := strings.ToLower(viewName)
17331734
switch {
17341735
case strings.HasPrefix(lwrViewName, doltdb.DoltBlameViewPrefix):
@@ -1738,7 +1739,12 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie
17381739
if err != nil {
17391740
return sql.ViewDefinition{}, false, err
17401741
}
1741-
return sql.ViewDefinition{Name: viewName, TextDefinition: blameViewTextDef, CreateViewStatement: fmt.Sprintf("CREATE VIEW `%s` AS %s", viewName, blameViewTextDef)}, true, nil
1742+
return sql.ViewDefinition{
1743+
Name: viewName,
1744+
SchemaName: db.schemaName,
1745+
TextDefinition: blameViewTextDef,
1746+
CreateViewStatement: fmt.Sprintf("CREATE VIEW `%s` AS %s", viewName, blameViewTextDef)},
1747+
true, nil
17421748
}
17431749

17441750
schemasTableName := getDoltSchemasTableName()
@@ -1780,22 +1786,22 @@ func (db Database) GetViewDefinition(ctx *sql.Context, viewName string) (sql.Vie
17801786
}
17811787

17821788
if wrapper.backingTable == nil {
1783-
dbState.SessionCache().CacheViews(key, nil, db.schemaName)
1789+
dbState.SessionCache().CacheViews(key, nil, schemaName)
17841790
return sql.ViewDefinition{}, false, nil
17851791
}
17861792

1787-
views, viewDef, found, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, viewName)
1793+
views, viewDef, found, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, viewName, schemaName)
17881794
if err != nil {
17891795
return sql.ViewDefinition{}, false, err
17901796
}
17911797

17921798
// TODO: only cache views from a single schema here
1793-
dbState.SessionCache().CacheViews(key, views, db.schemaName)
1799+
dbState.SessionCache().CacheViews(key, views, schemaName)
17941800

17951801
return viewDef, found, nil
17961802
}
17971803

1798-
func getViewDefinitionFromSchemaFragmentsOfView(ctx *sql.Context, tbl *WritableDoltTable, viewName string) ([]sql.ViewDefinition, sql.ViewDefinition, bool, error) {
1804+
func getViewDefinitionFromSchemaFragmentsOfView(ctx *sql.Context, tbl *WritableDoltTable, viewName, schemaName string) ([]sql.ViewDefinition, sql.ViewDefinition, bool, error) {
17991805
fragments, err := getSchemaFragmentsOfType(ctx, tbl, viewFragment)
18001806
if err != nil {
18011807
return nil, sql.ViewDefinition{}, false, err
@@ -1814,14 +1820,15 @@ func getViewDefinitionFromSchemaFragmentsOfView(ctx *sql.Context, tbl *WritableD
18141820
}
18151821
} else {
18161822
views[i] = sql.ViewDefinition{
1817-
Name: fragments[i].name,
1823+
Name: fragments[i].name,
1824+
SchemaName: fragments[i].schemaName,
18181825
// TODO: need to define TextDefinition
18191826
CreateViewStatement: fragments[i].fragment,
18201827
SqlMode: fragment.sqlMode,
18211828
}
18221829
}
18231830

1824-
if strings.EqualFold(fragment.name, viewName) {
1831+
if strings.EqualFold(fragment.name, viewName) && strings.EqualFold(fragment.schemaName, schemaName) {
18251832
found = true
18261833
viewDef = views[i]
18271834
}
@@ -1849,7 +1856,7 @@ func (db Database) AllViews(ctx *sql.Context) ([]sql.ViewDefinition, error) {
18491856
return nil, nil
18501857
}
18511858

1852-
views, _, _, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, "")
1859+
views, _, _, err := getViewDefinitionFromSchemaFragmentsOfView(ctx, wrapper.backingTable, "", "")
18531860
if err != nil {
18541861
return nil, err
18551862
}
@@ -1945,7 +1952,7 @@ func (db Database) CreateTrigger(ctx *sql.Context, definition sql.TriggerDefinit
19451952
definition.Name,
19461953
definition.CreateStatement,
19471954
definition.CreatedAt,
1948-
fmt.Errorf("triggers `%s` already exists", definition.Name), //TODO: add a sql error and return that instead
1955+
fmt.Errorf("triggers `%s` already exists", definition.Name), // TODO: add a sql error and return that instead
19491956
)
19501957
}
19511958

@@ -2207,7 +2214,7 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
22072214
return err
22082215
}
22092216

2210-
_, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name)
2217+
_, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name, db.schemaName)
22112218
if err != nil {
22122219
return err
22132220
}
@@ -2234,14 +2241,34 @@ func (db Database) addFragToSchemasTable(ctx *sql.Context, fragType, name, defin
22342241

22352242
sqlMode := sql.LoadSqlMode(ctx)
22362243

2237-
return inserter.Insert(ctx, sql.Row{fragType, name, definition, extraJSON, sqlMode.String()})
2244+
row := sql.Row{fragType, name, definition, extraJSON, sqlMode.String()}
2245+
2246+
// Include schema_name column for doltgres
2247+
if resolve.UseSearchPath && tbl.Schema().Contains(doltdb.SchemasTablesSchemaNameCol, tbl.Name()) {
2248+
if db.schemaName == "" {
2249+
root, err := db.GetRoot(ctx)
2250+
if err != nil {
2251+
return err
2252+
}
2253+
schemaName, err := resolve.FirstExistingSchemaOnSearchPath(ctx, root)
2254+
if err != nil {
2255+
return err
2256+
}
2257+
db.schemaName = schemaName
2258+
}
2259+
2260+
row = sql.Row{fragType, name, db.schemaName, definition, extraJSON, sqlMode.String()}
2261+
}
2262+
2263+
return inserter.Insert(ctx, row)
22382264
}
22392265

22402266
func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name string, missingErr error) error {
22412267
if err := dsess.CheckAccessForDb(ctx, db, branch_control.Permissions_Write); err != nil {
22422268
return err
22432269
}
22442270

2271+
schemaName := db.schemaName
22452272
if resolve.UseSearchPath {
22462273
db.schemaName = "dolt"
22472274
}
@@ -2260,7 +2287,7 @@ func (db Database) dropFragFromSchemasTable(ctx *sql.Context, fragType, name str
22602287
}
22612288

22622289
tbl := swrapper.backingTable
2263-
row, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name)
2290+
row, exists, err := fragFromSchemasTable(ctx, tbl, fragType, name, schemaName)
22642291
if err != nil {
22652292
return err
22662293
}

go/libraries/doltcore/sqle/schema_table.go

Lines changed: 68 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,34 @@ func (st *SchemaTable) String() string {
5555
return doltdb.GetSchemasTableName()
5656
}
5757

58+
// GetSchemasSchema returns the schema of the dolt_schemas system table. This is used
59+
// by Doltgres to update the dolt_schemas schema with an additional schema_name column.
60+
var GetSchemasSchema = SchemaTableSchema
61+
62+
// func getDoltSchemasSchema(backingTable *WritableDoltTable) sql.Schema {
63+
// if backingTable == nil {
64+
// // No backing table; return a current schema.
65+
// return GetSchemasSchema().Schema
66+
// }
67+
68+
// if !backingTable.Schema().Contains(doltdb.SchemasTablesExtraCol, doltdb.GetSchemasTableName()) {
69+
// // No Extra column; return an ancient schema.
70+
// return SchemaTableAncientSqlSchema()
71+
// }
72+
73+
// if !backingTable.Schema().Contains(doltdb.SchemasTablesSqlModeCol, doltdb.GetSchemasTableName()) {
74+
// // No SQL_MODE column; return an old schema.
75+
// return SchemaTableV1SqlSchema()
76+
// }
77+
78+
// return GetSchemasSchema().Schema
79+
// }
80+
5881
func (st *SchemaTable) Schema() sql.Schema {
82+
currentSchema := toSqlSchemaTableSchema(GetSchemasSchema())
5983
if st.backingTable == nil {
6084
// No backing table; return a current schema.
61-
return SchemaTableSqlSchema().Schema
85+
return currentSchema.Schema
6286
}
6387

6488
if !st.backingTable.Schema().Contains(doltdb.SchemasTablesExtraCol, doltdb.GetSchemasTableName()) {
@@ -71,7 +95,7 @@ func (st *SchemaTable) Schema() sql.Schema {
7195
return SchemaTableV1SqlSchema()
7296
}
7397

74-
return SchemaTableSqlSchema().Schema
98+
return currentSchema.Schema
7599
}
76100

77101
func (st *SchemaTable) Collation() sql.CollationID {
@@ -127,14 +151,22 @@ var _ sql.IndexAddressableTable = (*SchemaTable)(nil)
127151
var _ sql.UpdatableTable = (*SchemaTable)(nil)
128152
var _ WritableDoltTableWrapper = (*SchemaTable)(nil)
129153

130-
func SchemaTableSqlSchema() sql.PrimaryKeySchema {
131-
sqlSchema, err := sqlutil.FromDoltSchema("", doltdb.GetSchemasTableName(), SchemaTableSchema())
154+
func toSqlSchemaTableSchema(sch schema.Schema) sql.PrimaryKeySchema {
155+
sqlSchema, err := sqlutil.FromDoltSchema("", doltdb.GetSchemasTableName(), sch)
132156
if err != nil {
133157
panic(err) // should never happen
134158
}
135159
return sqlSchema
136160
}
137161

162+
// func SchemaTableSqlSchema() sql.PrimaryKeySchema {
163+
// sqlSchema, err := sqlutil.FromDoltSchema("", doltdb.GetSchemasTableName(), SchemaTableSchema())
164+
// if err != nil {
165+
// panic(err) // should never happen
166+
// }
167+
// return sqlSchema
168+
// }
169+
138170
func mustNewColWithTypeInfo(name string, tag uint64, typeInfo typeinfo.TypeInfo, partOfPK bool, defaultVal string, autoIncrement bool, comment string, constraints ...schema.ColConstraint) schema.Column {
139171
col, err := schema.NewColumnWithTypeInfo(name, tag, typeInfo, partOfPK, defaultVal, autoIncrement, comment, constraints...)
140172
if err != nil {
@@ -250,7 +282,7 @@ func getOrCreateDoltSchemasTable(ctx *sql.Context, db Database) (retTbl *Writabl
250282
}
251283

252284
// Create new empty table
253-
err = db.createDoltTable(ctx, tname, root, SchemaTableSchema())
285+
err = db.createDoltTable(ctx, tname, root, GetSchemasSchema())
254286
if err != nil {
255287
return nil, err
256288
}
@@ -367,8 +399,8 @@ func migrateOldSchemasTableToNew(ctx *sql.Context, db Database, schemasTable *Wr
367399
}
368400

369401
// fragFromSchemasTable returns the row with the given schema fragment if it exists.
370-
func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType string, name string) (r sql.Row, found bool, rerr error) {
371-
fragType, name = strings.ToLower(fragType), strings.ToLower(name)
402+
func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType, name, schemaName string) (r sql.Row, found bool, rerr error) {
403+
fragType, name, schemaName = strings.ToLower(fragType), strings.ToLower(name), strings.ToLower(schemaName)
372404

373405
// This performs a full table scan in the worst case, but it's only used when adding or dropping a trigger or view
374406
iter, err := SqlTableToRowIter(ctx, tbl.DoltTable, nil)
@@ -387,6 +419,7 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
387419
// need to get the column indexes from the current schema
388420
nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
389421
typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
422+
schemaNameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesSchemaNameCol)
390423

391424
for {
392425
sqlRow, err := iter.Next(ctx)
@@ -397,8 +430,13 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
397430
return nil, false, err
398431
}
399432

433+
sqlRowSchemaName := ""
434+
if schemaNameIdx >= 0 {
435+
sqlRowSchemaName = sqlRow[schemaNameIdx].(string)
436+
}
437+
400438
// These columns are case insensitive, make sure to do a case-insensitive comparison
401-
if strings.EqualFold(sqlRow[typeIdx].(string), fragType) && strings.EqualFold(sqlRow[nameIdx].(string), name) {
439+
if strings.EqualFold(sqlRow[typeIdx].(string), fragType) && strings.EqualFold(sqlRow[nameIdx].(string), name) && strings.EqualFold(sqlRowSchemaName, schemaName) {
402440
return sqlRow, true, nil
403441
}
404442
}
@@ -407,9 +445,10 @@ func fragFromSchemasTable(ctx *sql.Context, tbl *WritableDoltTable, fragType str
407445
}
408446

409447
type schemaFragment struct {
410-
name string
411-
fragment string
412-
created time.Time
448+
name string
449+
schemaName string
450+
fragment string
451+
created time.Time
413452
// sqlMode indicates the SQL_MODE that was used when this schema fragment was initially parsed. SQL_MODE settings
414453
// such as ANSI_QUOTES control customized parsing behavior needed for some schema fragments.
415454
sqlMode string
@@ -424,6 +463,7 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
424463
// The dolt_schemas table has undergone various changes over time and multiple possible schemas for it exist, so we
425464
// need to get the column indexes from the current schema
426465
nameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesNameCol)
466+
schemaNameIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesSchemaNameCol)
427467
typeIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesTypeCol)
428468
fragmentIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesFragmentCol)
429469
extraIdx := tbl.sqlSchema().IndexOfColName(doltdb.SchemasTablesExtraCol)
@@ -463,13 +503,21 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
463503
sqlModeString = defaultSqlMode
464504
}
465505

506+
schemaNameString := ""
507+
if schemaNameIdx >= 0 {
508+
if s, ok := sqlRow[schemaNameIdx].(string); ok {
509+
schemaNameString = s
510+
}
511+
}
512+
466513
// For older tables, use 1 as the trigger creation time
467514
if extraIdx < 0 || sqlRow[extraIdx] == nil {
468515
frags = append(frags, schemaFragment{
469-
name: sqlRow[nameIdx].(string),
470-
fragment: sqlRow[fragmentIdx].(string),
471-
created: time.Unix(1, 0).UTC(), // TablePlus editor thinks 0 is out of range
472-
sqlMode: sqlModeString,
516+
name: sqlRow[nameIdx].(string),
517+
schemaName: schemaNameString,
518+
fragment: sqlRow[fragmentIdx].(string),
519+
created: time.Unix(1, 0).UTC(), // TablePlus editor thinks 0 is out of range
520+
sqlMode: sqlModeString,
473521
})
474522
continue
475523
}
@@ -481,10 +529,11 @@ func getSchemaFragmentsOfType(ctx *sql.Context, tbl *WritableDoltTable, fragType
481529
}
482530

483531
frags = append(frags, schemaFragment{
484-
name: sqlRow[nameIdx].(string),
485-
fragment: sqlRow[fragmentIdx].(string),
486-
created: time.Unix(createdTime, 0).UTC(),
487-
sqlMode: sqlModeString,
532+
name: sqlRow[nameIdx].(string),
533+
schemaName: schemaNameString,
534+
fragment: sqlRow[fragmentIdx].(string),
535+
created: time.Unix(createdTime, 0).UTC(),
536+
sqlMode: sqlModeString,
488537
})
489538
}
490539

0 commit comments

Comments
 (0)