Skip to content

Commit be1b822

Browse files
authored
Merge pull request #6 from cockroachdb/nvanbenschoten/impr
Clean up ORM initlization and extend test cases
2 parents a8244ef + 37cd12b commit be1b822

File tree

5 files changed

+253
-42
lines changed

5 files changed

+253
-42
lines changed

go/gorm/model/models.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,24 @@ package model
22

33
// Customer is a model in the "customers" table.
44
type Customer struct {
5-
ID int
6-
Name *string `gorm:"not null"`
5+
ID int `json:"id"`
6+
Name *string `json:"name" gorm:"not null"`
77
}
88

99
// Order is a model in the "orders" table.
1010
type Order struct {
11-
ID int
12-
Subtotal float64 `gorm:"type:decimal(18,2)"`
11+
ID int `json:"id"`
12+
Subtotal float64 `json:"subtotal" gorm:"type:decimal(18,2)"`
1313

14-
Customer Customer `gorm:"ForeignKey:CustomerID"`
14+
Customer Customer `json:"customer" gorm:"ForeignKey:CustomerID"`
1515
CustomerID int `json:"-"`
1616

17-
Products []Product `gorm:"many2many:order_products"`
17+
Products []Product `json:"products" gorm:"many2many:order_products"`
1818
}
1919

2020
// Product is a model in the "products" table.
2121
type Product struct {
22-
ID int
23-
Name *string `gorm:"not null;unique"`
24-
Price float64 `gorm:"type:decimal(18,2)"`
22+
ID int `json:"id"`
23+
Name *string `json:"name" gorm:"not null;unique"`
24+
Price float64 `json:"price" gorm:"type:decimal(18,2)"`
2525
}

go/gorm/server.go

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ func NewServer(db *gorm.DB) *Server {
2222

2323
// RegisterRouter registers a router onto the Server.
2424
func (s *Server) RegisterRouter(router *httprouter.Router) {
25+
router.GET("/ping", s.ping)
26+
2527
router.GET("/customer", s.getCustomers)
2628
router.POST("/customer", s.createCustomer)
2729
router.GET("/customer/:customerID", s.getCustomer)
@@ -42,6 +44,10 @@ func (s *Server) RegisterRouter(router *httprouter.Router) {
4244
router.POST("/order/:orderID/product", s.addProductToOrder)
4345
}
4446

47+
func (s *Server) ping(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
48+
writeTextResult(w, "pong")
49+
}
50+
4551
func (s *Server) getCustomers(w http.ResponseWriter, r *http.Request, _ httprouter.Params) {
4652
var customers []model.Customer
4753
if err := s.db.Find(&customers).Error; err != nil {
@@ -176,6 +182,26 @@ func (s *Server) createOrder(w http.ResponseWriter, r *http.Request, ps httprout
176182
return
177183
}
178184

185+
if order.Customer.ID == 0 {
186+
http.Error(w, "must specify user", http.StatusBadRequest)
187+
return
188+
}
189+
if err := s.db.Find(&order.Customer, order.Customer.ID).Error; err != nil {
190+
http.Error(w, err.Error(), errToStatusCode(err))
191+
return
192+
}
193+
194+
for i, product := range order.Products {
195+
if product.ID == 0 {
196+
http.Error(w, "must specify a product ID", http.StatusBadRequest)
197+
return
198+
}
199+
if err := s.db.Find(&order.Products[i], product.ID).Error; err != nil {
200+
http.Error(w, err.Error(), errToStatusCode(err))
201+
return
202+
}
203+
}
204+
179205
if err := s.db.Create(&order).Error; err != nil {
180206
http.Error(w, err.Error(), errToStatusCode(err))
181207
} else {
@@ -260,14 +286,14 @@ func (s *Server) addProductToOrder(w http.ResponseWriter, r *http.Request, ps ht
260286
}
261287

262288
func writeTextResult(w http.ResponseWriter, res string) {
263-
w.WriteHeader(http.StatusOK)
264289
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
290+
w.WriteHeader(http.StatusOK)
265291
fmt.Fprintln(w, res)
266292
}
267293

268294
func writeJSONResult(w http.ResponseWriter, res interface{}) {
269-
w.WriteHeader(http.StatusOK)
270295
w.Header().Set("Content-Type", "application/json; charset=UTF-8")
296+
w.WriteHeader(http.StatusOK)
271297
if err := json.NewEncoder(w).Encode(res); err != nil {
272298
panic(err)
273299
}

testing/api_handler.go

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
const (
1212
applicationAddr = "http://localhost:6543"
1313

14+
pingPath = applicationAddr + "/ping"
1415
customersPath = applicationAddr + "/customer"
1516
ordersPath = applicationAddr + "/order"
1617
productsPath = applicationAddr + "/product"
@@ -23,19 +24,31 @@ const (
2324
// across all ORMs.
2425
type apiHandler struct{}
2526

27+
func (apiHandler) ping() error {
28+
_, err := http.Get(pingPath)
29+
return err
30+
}
31+
2632
func (apiHandler) queryCustomers() ([]model.Customer, error) {
2733
var customers []model.Customer
2834
if err := getJSON(customersPath, &customers); err != nil {
2935
return nil, err
3036
}
31-
return cleanCustomers(customers), nil
37+
return customers, nil
3238
}
3339
func (apiHandler) queryProducts() ([]model.Product, error) {
3440
var products []model.Product
3541
if err := getJSON(productsPath, &products); err != nil {
3642
return nil, err
3743
}
38-
return cleanProducts(products), nil
44+
return products, nil
45+
}
46+
func (apiHandler) queryOrders() ([]model.Order, error) {
47+
var orders []model.Order
48+
if err := getJSON(ordersPath, &orders); err != nil {
49+
return nil, err
50+
}
51+
return orders, nil
3952
}
4053

4154
func (apiHandler) createCustomer(name string) error {
@@ -46,6 +59,14 @@ func (apiHandler) createProduct(name string, price float64) error {
4659
product := model.Product{Name: &name, Price: price}
4760
return postJSONData(productsPath, product)
4861
}
62+
func (apiHandler) createOrder(customerID, productID int, subtotal float64) error {
63+
order := model.Order{
64+
Customer: model.Customer{ID: customerID},
65+
Products: []model.Product{{ID: productID}},
66+
Subtotal: subtotal,
67+
}
68+
return postJSONData(ordersPath, order)
69+
}
4970

5071
func getJSON(path string, result interface{}) error {
5172
resp, err := http.Get(path)
@@ -80,3 +101,12 @@ func cleanProducts(products []model.Product) []model.Product {
80101
}
81102
return products
82103
}
104+
func cleanOrders(orders []model.Order) []model.Order {
105+
for i := range orders {
106+
orders[i].ID = 0
107+
orders[i].Customer = model.Customer{}
108+
orders[i].CustomerID = 0
109+
orders[i].Products = nil
110+
}
111+
return orders
112+
}

testing/main_test.go

Lines changed: 98 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ import (
1818
_ "github.com/lib/pq"
1919
)
2020

21-
// application represents a single instance of a application running an ORM and
21+
// application represents a single instance of an application running an ORM and
2222
// exposing an HTTP REST API.
2323
type application struct {
2424
language string
@@ -33,6 +33,7 @@ func (app application) dbName() string {
3333
return fmt.Sprintf("company_%s", app.orm)
3434
}
3535

36+
// initTestDatabase launches a test database as a subprocess.
3637
func initTestDatabase(t *testing.T, app application) (*sql.DB, *url.URL, func()) {
3738
ts, err := testserver.NewTestServer()
3839
if err != nil {
@@ -56,21 +57,29 @@ func initTestDatabase(t *testing.T, app application) (*sql.DB, *url.URL, func())
5657

5758
ts.WaitForInit(db)
5859

60+
// Create the database if it does not exist.
5961
if _, err := db.Exec("CREATE DATABASE IF NOT EXISTS " + app.dbName()); err != nil {
6062
t.Fatal(err)
6163
}
62-
if _, err := db.Exec("SET DATABASE = " + app.dbName()); err != nil {
64+
65+
// Connect to the database again, now with the database in the URL.
66+
url.Path = app.dbName()
67+
db, err = sql.Open("postgres", url.String())
68+
if err != nil {
6369
t.Fatal(err)
6470
}
65-
url.Path = app.dbName()
6671

6772
return db, url, func() {
6873
_ = db.Close()
6974
ts.Stop()
7075
}
7176
}
7277

73-
func initORMApp(t *testing.T, app application, dbURL *url.URL) func() {
78+
type killFunc func()
79+
type restartFunc func() (killFunc, restartFunc)
80+
81+
// initORMApp launches an ORM application as a subprocess.
82+
func initORMApp(t *testing.T, app application, dbURL *url.URL) (killFunc, restartFunc) {
7483
addrFlag := fmt.Sprintf("ADDR=%s", dbURL.String())
7584
args := []string{"make", "start", "-C", app.dir(), addrFlag}
7685

@@ -79,27 +88,59 @@ func initORMApp(t *testing.T, app application, dbURL *url.URL) func() {
7988
// make will launch the application in a child process, and this is the most
8089
// straightforward way to kill all ancestors.
8190
cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
91+
killed := false
8292
killCmd := func() {
83-
syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
93+
if !killed {
94+
syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL)
95+
}
96+
killed = true
8497
}
8598

8699
// Set up stderr so we can later verify that it's clean.
87100
stderr := new(bytes.Buffer)
88101
cmd.Stderr = stderr
89102

90103
if err := cmd.Start(); err != nil {
104+
killCmd()
91105
t.Fatal(err)
92106
}
93107
if cmd.Process != nil {
94108
log.Printf("process %d started: %s", cmd.Process.Pid, strings.Join(args, " "))
95109
}
96110

97-
time.Sleep(3 * time.Second)
111+
if err := waitForInit(); err != nil {
112+
killCmd()
113+
t.Fatalf("error waiting for http server initialization: %v stderr=%s", err, stderr.String())
114+
}
98115
if s := stderr.String(); len(s) > 0 {
116+
killCmd()
99117
t.Fatalf("stderr=%s", s)
100118
}
101119

102-
return killCmd
120+
restartCmd := func() (killFunc, restartFunc) {
121+
killCmd()
122+
return initORMApp(t, app, dbURL)
123+
}
124+
125+
return killCmd, restartCmd
126+
}
127+
128+
// waitForInit retries until a connection is successfully established.
129+
func waitForInit() error {
130+
const maxWait = 15 * time.Second
131+
const waitDelay = 250 * time.Millisecond
132+
const maxWaitLoops = int(maxWait / waitDelay)
133+
134+
var err error
135+
var api apiHandler
136+
for i := 0; i < maxWaitLoops; i++ {
137+
if err = api.ping(); err == nil {
138+
return err
139+
}
140+
log.Printf("waitForInit: %v", err)
141+
time.Sleep(waitDelay)
142+
}
143+
return err
103144
}
104145

105146
func testORM(t *testing.T, language, orm string) {
@@ -111,7 +152,7 @@ func testORM(t *testing.T, language, orm string) {
111152
db, dbURL, stopDB := initTestDatabase(t, app)
112153
defer stopDB()
113154

114-
stopApp := initORMApp(t, app, dbURL)
155+
stopApp, restartApp := initORMApp(t, app, dbURL)
115156
defer stopApp()
116157

117158
td := testDriver{
@@ -120,27 +161,59 @@ func testORM(t *testing.T, language, orm string) {
120161
}
121162

122163
// Test that the correct tables were generated.
123-
t.Run("TestGeneratedTables", td.TestGeneratedTables)
164+
t.Run("GeneratedTables", td.TestGeneratedTables)
124165

125166
// Test that the correct columns in those tables were generated.
126-
t.Run("TestGeneratedCustomersTableColumns", td.TestGeneratedCustomersTableColumns)
127-
t.Run("TestGeneratedOrdersTableColumns", td.TestGeneratedOrdersTableColumns)
128-
t.Run("TestGeneratedProductsTableColumns", td.TestGeneratedProductsTableColumns)
129-
t.Run("TestGeneratedOrderProductsTableColumns", td.TestGeneratedOrderProductsTableColumns)
167+
t.Run("GeneratedColumns", parallelTestGroup{
168+
"CustomersTable": td.TestGeneratedCustomersTableColumns,
169+
"ProductsTable": td.TestGeneratedProductsTableColumns,
170+
"OrdersTable": td.TestGeneratedOrdersTableColumns,
171+
"OrderProductsTable": td.TestGeneratedOrderProductsTableColumns,
172+
}.T)
130173

131174
// Test that the tables begin empty.
132-
t.Run("TestOrdersTableEmpty", td.TestOrdersTableEmpty)
133-
t.Run("TestProductsTableEmpty", td.TestProductsTableEmpty)
134-
t.Run("TestCustomersEmpty", td.TestCustomersEmpty)
135-
t.Run("TestOrderProductsTableEmpty", td.TestOrderProductsTableEmpty)
136-
137-
// Test the creation of objects.
138-
t.Run("TestRetrieveCustomerBeforeCreation", td.TestRetrieveCustomerBeforeCreation)
139-
t.Run("TestRetrieveProductBeforeCreation", td.TestRetrieveProductBeforeCreation)
140-
t.Run("TestCreateCustomer", td.TestCreateCustomer)
141-
t.Run("TestCreateProduct", td.TestCreateProduct)
142-
t.Run("TestRetrieveCustomerAfterCreation", td.TestRetrieveCustomerAfterCreation)
143-
t.Run("TestRetrieveProductAfterCreation", td.TestRetrieveProductAfterCreation)
175+
t.Run("EmptyTables", parallelTestGroup{
176+
"CustomersTable": td.TestCustomersEmpty,
177+
"ProductsTable": td.TestProductsTableEmpty,
178+
"OrdersTable": td.TestOrdersTableEmpty,
179+
"OrderProductsTable": td.TestOrderProductsTableEmpty,
180+
}.T)
181+
182+
// Test that the API returns empty sets for each collection.
183+
t.Run("RetrieveFromAPIBeforeCreation", parallelTestGroup{
184+
"Customers": td.TestRetrieveCustomersBeforeCreation,
185+
"Products": td.TestRetrieveProductsBeforeCreation,
186+
"Orders": td.TestRetrieveOrdersBeforeCreation,
187+
}.T)
188+
189+
// Test the creation of initial objects.
190+
t.Run("CreateCustomer", td.TestCreateCustomer)
191+
t.Run("CreateProduct", td.TestCreateProduct)
192+
193+
// Test that the API returns what we just created.
194+
t.Run("RetrieveFromAPIAfterInitialCreation", parallelTestGroup{
195+
"Customers": td.TestRetrieveCustomerAfterCreation,
196+
"Products": td.TestRetrieveProductAfterCreation,
197+
}.T)
198+
199+
// Test the creation of dependent objects.
200+
t.Run("CreateOrder", td.TestCreateOrder)
201+
202+
// Test that the API returns what we just created.
203+
t.Run("RetrieveFromAPIAfterDependentCreation", parallelTestGroup{
204+
"Order": td.TestRetrieveProductAfterCreation,
205+
}.T)
206+
207+
// Restart the application.
208+
stopApp, restartApp = restartApp()
209+
defer stopApp()
210+
211+
// Test that the API still returns all created objects.
212+
t.Run("RetrieveFromAPIAfterRestart", parallelTestGroup{
213+
"Customers": td.TestRetrieveCustomerAfterCreation,
214+
"Products": td.TestRetrieveProductAfterCreation,
215+
"Order": td.TestRetrieveProductAfterCreation,
216+
}.T)
144217
}
145218

146219
func TestGORM(t *testing.T) {

0 commit comments

Comments
 (0)