Skip to content

Commit 1d8054a

Browse files
author
Alvaro Muñoz
committed
Refactor installCmd in cmd/install.go
1 parent 4ffacf0 commit 1d8054a

File tree

1 file changed

+55
-36
lines changed

1 file changed

+55
-36
lines changed

cmd/install.go

Lines changed: 55 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,34 @@
11
package cmd
22

33
import (
4-
"io"
54
"errors"
65
"fmt"
7-
"io/ioutil"
6+
"io"
87
"log"
98
"os"
109
"path/filepath"
1110
"strings"
1211

13-
"github.com/spf13/cobra"
14-
"github.com/GitHubSecurityLab/gh-qldb/utils"
12+
"github.com/GitHubSecurityLab/gh-qldb/utils"
13+
"github.com/spf13/cobra"
1514
)
1615

1716
var installCmd = &cobra.Command{
18-
Use: "install",
19-
Short: "Install a local CodeQL database in the QLDB directory",
20-
Long: `Install a local CodeQL database in the QLDB directory`,
21-
Run: func(cmd *cobra.Command, args []string) {
22-
install(nwoFlag, dbPathFlag, removeFlag)
23-
},
24-
}
17+
Use: "install",
18+
Short: "Install a local CodeQL database in the QLDB directory",
19+
Long: `Install a local CodeQL database in the QLDB directory`,
20+
Run: func(cmd *cobra.Command, args []string) {
21+
install(nwoFlag, dbPathFlag, removeFlag)
22+
},
23+
}
2524

2625
func init() {
27-
rootCmd.AddCommand(installCmd)
28-
installCmd.Flags().StringVarP(&nwoFlag, "nwo", "n", "", "The NWO to associate the database to.")
29-
installCmd.Flags().StringVarP(&dbPathFlag, "database", "d", "", "The path to the database to install.")
30-
installCmd.Flags().BoolVarP(&removeFlag, "remove", "r", false, "Remove the database after installing it.")
31-
installCmd.MarkFlagRequired("nwo")
32-
installCmd.MarkFlagRequired("database")
26+
rootCmd.AddCommand(installCmd)
27+
installCmd.Flags().StringVarP(&nwoFlag, "nwo", "n", "", "The NWO to associate the database to.")
28+
installCmd.Flags().StringVarP(&dbPathFlag, "database", "d", "", "The path to the database to install.")
29+
installCmd.Flags().BoolVarP(&removeFlag, "remove", "r", false, "Remove the database after installing it.")
30+
installCmd.MarkFlagRequired("nwo")
31+
installCmd.MarkFlagRequired("database")
3332
}
3433

3534
func install(nwo string, dbPath string, remove bool) {
@@ -42,13 +41,14 @@ func install(nwo string, dbPath string, remove bool) {
4241
log.Fatal(errors.New("DB path does not exist"))
4342
}
4443
if fileinfo.IsDir() {
44+
fmt.Printf("Validating %s DB\n", dbPath)
4545
err := utils.ValidateDB(dbPath)
4646
if err != nil {
4747
fmt.Println("DB is not valid")
4848
}
4949
// Compress DB
5050
zipfilename := filepath.Join(os.TempDir(), "qldb.zip")
51-
fmt.Println("Zipping DB to", zipfilename)
51+
fmt.Println("Compressing DB to", zipfilename)
5252
if err := utils.ZipDirectory(zipfilename, dbPath); err != nil {
5353
log.Fatal(err)
5454
}
@@ -61,19 +61,24 @@ func install(nwo string, dbPath string, remove bool) {
6161
}
6262

6363
zipPath = dbPath
64-
// Unzip to temporary directory
65-
tmpdir, _ := ioutil.TempDir("", "qldb")
64+
// Unzip to a temporary directory
65+
tmpdir, _ := os.MkdirTemp("", "qldb")
66+
6667
_, err := utils.Unzip(dbPath, tmpdir)
6768
if err != nil {
6869
log.Fatal(err)
6970
}
70-
files, err := ioutil.ReadDir(tmpdir)
71+
72+
// Read all files in the tmpdir directory using os.ReadDir
73+
dirEntries, err := os.ReadDir(tmpdir)
7174
if err != nil {
7275
log.Fatal(err)
7376
}
74-
if len(files) == 1 {
75-
tmpdir = filepath.Join(tmpdir, files[0].Name())
77+
if len(dirEntries) == 1 {
78+
// if there is one directory in the tmpdir, use that as the tmpdir
79+
tmpdir = filepath.Join(tmpdir, dirEntries[0].Name())
7680
}
81+
fmt.Printf("Validating %s DB\n", tmpdir)
7782
err = utils.ValidateDB(tmpdir)
7883
if err != nil {
7984
fmt.Println("DB is not valid")
@@ -86,44 +91,58 @@ func install(nwo string, dbPath string, remove bool) {
8691
log.Fatal(err)
8792
}
8893
defer zipFile.Close()
89-
zipBytes, err := ioutil.ReadAll(zipFile)
94+
zipBytes, err := io.ReadAll(zipFile)
9095
if err != nil {
9196
log.Fatal(err)
9297
}
9398
commitSha, primaryLanguage, err := utils.ExtractDBInfo(zipBytes)
99+
shortCommitSha := commitSha[:8]
100+
fmt.Println("Commit SHA:", commitSha)
101+
fmt.Println("Short Commit SHA:", shortCommitSha)
102+
fmt.Println("Primary language:", primaryLanguage)
94103

95104
// Destination path
96-
dir := filepath.Join(utils.GetPath(nwo), primaryLanguage)
97-
filename := fmt.Sprintf("%s.zip", commitSha)
98-
path := filepath.Join(dir, filename)
99-
fmt.Println("Installing DB to", path)
105+
filename := fmt.Sprintf("%s-%s.zip", primaryLanguage, shortCommitSha)
106+
destPath := filepath.Join(utils.GetPath(nwo), filename)
107+
fmt.Println("Installing DB to", destPath)
100108

101109
// Check if the DB is already installed
102-
if _, err := os.Stat(path); errors.Is(err, os.ErrNotExist) {
103-
// Copy DB to the right place
104-
srcFile, err := os.Open(zipPath)
110+
if _, err := os.Stat(destPath); errors.Is(err, os.ErrNotExist) {
111+
112+
// Create the directory if it doesn't exist
113+
err = os.MkdirAll(filepath.Dir(destPath), 0755)
105114
if err != nil {
106115
log.Fatal(err)
116+
return
107117
}
108-
defer srcFile.Close()
109-
err = os.MkdirAll(filepath.Dir(path), 0755)
118+
119+
// Copy file from zipPath to destPath
120+
srcFile, err := os.Open(zipPath)
110121
if err != nil {
111122
log.Fatal(err)
123+
return
112124
}
113-
destFile, err := os.Create(path)
125+
defer srcFile.Close()
126+
127+
destFile, err := os.Create(destPath)
114128
if err != nil {
115129
log.Fatal(err)
130+
return
116131
}
117132
defer destFile.Close()
118-
fmt.Println("Copying DB to", path)
119-
_, err = io.Copy(srcFile, destFile) // check first var for number of bytes copied
133+
134+
bytes, err := io.Copy(destFile, srcFile)
135+
fmt.Println(fmt.Sprintf("Copied %d bytes", bytes))
136+
120137
if err != nil {
121138
log.Fatal(err)
122139
}
123140
err = destFile.Sync()
124141
if err != nil {
125142
log.Fatal(err)
126143
}
144+
} else {
145+
fmt.Println("DB already installed for same commit")
127146
}
128147
// Remove DB from the current location if -r flag is set
129148
if remove {

0 commit comments

Comments
 (0)