@@ -2,6 +2,7 @@ package sqlite
22
33import (
44 "context"
5+ "database/sql"
56 "encoding/json"
67 "fmt"
78 "strings"
@@ -1093,14 +1094,50 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) {
10931094 },
10941095 },
10951096 },
1097+
1098+ {
1099+ description : "PersistDefaultChannel" ,
1100+ fields : fields {
1101+ bundles : []* registry.Bundle {
1102+ newBundle (t , "csv-a" , "pkg-0" , []string {"a" }, newUnstructuredCSV (t , "csv-a" , "" )),
1103+ newBundle (t , "csv-b" , "pkg-0" , []string {"b" }, newUnstructuredCSV (t , "csv-b" , "" )),
1104+ },
1105+ pkgs : []registry.PackageManifest {
1106+ {
1107+ PackageName : "pkg-0" ,
1108+ Channels : []registry.PackageChannel {
1109+ {
1110+ Name : "a" ,
1111+ CurrentCSVName : "csv-a" ,
1112+ },
1113+ {
1114+ Name : "b" ,
1115+ CurrentCSVName : "csv-b" ,
1116+ },
1117+ },
1118+ DefaultChannelName : "a" ,
1119+ },
1120+ },
1121+ },
1122+ args : args {
1123+ bundle : "csv-a" ,
1124+ pkg : "pkg-0" ,
1125+ },
1126+ expected : expected {
1127+ err : nil ,
1128+ bundles : map [string ]struct {}{
1129+ "pkg-0/b/csv-b" : {},
1130+ },
1131+ },
1132+ },
10961133 }
10971134 for _ , tt := range tests {
10981135 t .Run (tt .description , func (t * testing.T ) {
10991136 db , cleanup := CreateTestDb (t )
11001137 defer cleanup ()
11011138 store , err := NewSQLLiteLoader (db )
11021139 require .NoError (t , err )
1103- err = store .Migrate (context .TODO ())
1140+ err = store .Migrate (context .Background ())
11041141 require .NoError (t , err )
11051142
11061143 for _ , bundle := range tt .fields .bundles {
@@ -1112,6 +1149,21 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) {
11121149 // Throw away any errors loading packages (not testing this)
11131150 store .AddPackageChannels (pkg )
11141151 }
1152+
1153+ getDefaultChannel := func (pkg string ) sql.NullString {
1154+ // get defaultChannel before delete
1155+ rows , err := db .QueryContext (context .Background (), `SELECT default_channel FROM package WHERE name = ?` , pkg )
1156+ require .NoError (t , err )
1157+ defer rows .Close ()
1158+ var defaultChannel sql.NullString
1159+ for rows .Next () {
1160+ require .NoError (t , rows .Scan (& defaultChannel ))
1161+ break
1162+ }
1163+ return defaultChannel
1164+ }
1165+ oldDefaultChannel := getDefaultChannel (tt .args .pkg )
1166+
11151167 err = store .(registry.HeadOverwriter ).RemoveOverwrittenChannelHead (tt .args .pkg , tt .args .bundle )
11161168 if tt .expected .err != nil {
11171169 require .EqualError (t , err , tt .expected .err .Error ())
@@ -1121,7 +1173,7 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) {
11211173
11221174 querier := NewSQLLiteQuerierFromDb (db )
11231175
1124- bundles , err := querier .ListBundles (context .TODO ())
1176+ bundles , err := querier .ListBundles (context .Background ())
11251177 require .NoError (t , err )
11261178
11271179 var extra []string
@@ -1141,6 +1193,9 @@ func TestRemoveOverwrittenChannelHead(t *testing.T) {
11411193 t .Errorf ("unexpected bundles found: %v" , extra )
11421194 }
11431195
1196+ // should preserve defaultChannel entry in package table
1197+ currentDefaultChannel := getDefaultChannel (tt .args .pkg )
1198+ require .Equal (t , oldDefaultChannel , currentDefaultChannel )
11441199 })
11451200 }
11461201}
0 commit comments