Skip to content

Commit 2262504

Browse files
authored
Merge pull request #55 from linuxkit/shard-option
add sharding ability
2 parents 2351267 + 70433ba commit 2262504

File tree

8 files changed

+313
-60
lines changed

8 files changed

+313
-60
lines changed

cmd/list.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,18 @@ var listCmd = &cobra.Command{
3131
}
3232

3333
func init() {
34+
flags := listCmd.Flags()
35+
// shardPattern is 1-based (1/10, 3/10, 10/10) rather than normal computer 0-based (0/9, 2/9, 9/9), because it is easier for
36+
// humans to understand when calling the CLI.
37+
flags.StringVarP(&shardPattern, "shard", "s", "", "which shard to run, in form of 'N/M' where N is the shard number and M is the total number of shards, smallest shard number is 1")
3438
RootCmd.AddCommand(listCmd)
3539
}
3640

3741
func list(_ *cobra.Command, args []string) error {
42+
shard, totalShards, err := parseShardPattern(shardPattern)
43+
if err != nil {
44+
return err
45+
}
3846
pattern, err := local.ValidatePattern(args)
3947
if err != nil {
4048
return err
@@ -45,6 +53,11 @@ func list(_ *cobra.Command, args []string) error {
4553
if err != nil {
4654
return err
4755
}
56+
if totalShards > 0 {
57+
if err := p.SetShard(shard, totalShards); err != nil {
58+
return err
59+
}
60+
}
4861

4962
w := new(tabwriter.Writer)
5063
w.Init(os.Stdout, 0, 8, 0, '\t', 0)

cmd/run.go

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,12 @@ var (
7171
)
7272

7373
var (
74-
resultDir string
75-
id string
76-
symlink bool
77-
extra bool
78-
parallel bool
74+
resultDir string
75+
id string
76+
symlink bool
77+
extra bool
78+
parallel bool
79+
shardPattern string
7980
)
8081

8182
var runCmd = &cobra.Command{
@@ -90,10 +91,17 @@ func init() {
9091
flags.StringVarP(&id, "id", "", "", "ID for this test run")
9192
flags.BoolVarP(&extra, "extra", "x", false, "Add extra debug info to log files")
9293
flags.BoolVarP(&parallel, "parallel", "p", false, "Run multiple tests in parallel")
94+
// shardPattern is 1-based (1/10, 3/10, 10/10) rather than normal computer 0-based (0/9, 2/9, 9/9), because it is easier for
95+
// humans to understand when calling the CLI.
96+
flags.StringVarP(&shardPattern, "shard", "s", "", "which shard to run, in form of 'N/M' where N is the shard number and M is the total number of shards, smallest shard number is 1")
9397
RootCmd.AddCommand(runCmd)
9498
}
9599

96100
func run(cmd *cobra.Command, args []string) error {
101+
shard, totalShards, err := parseShardPattern(shardPattern)
102+
if err != nil {
103+
return err
104+
}
97105
pattern, err := local.ValidatePattern(args)
98106
if err != nil {
99107
return err
@@ -106,6 +114,11 @@ func run(cmd *cobra.Command, args []string) error {
106114
if err != nil {
107115
return err
108116
}
117+
if totalShards > 0 {
118+
if err := p.SetShard(shard, totalShards); err != nil {
119+
return err
120+
}
121+
}
109122

110123
var labelList []string
111124
for k := range runConfig.Labels {
@@ -309,3 +322,23 @@ func setupResultsDirectory(id string, link bool) (string, error) {
309322

310323
return baseDir, nil
311324
}
325+
326+
func parseShardPattern(pattern string) (shard int, total int, err error) {
327+
if pattern == "" {
328+
return 0, 0, nil
329+
}
330+
parts := strings.SplitN(pattern, "/", 2)
331+
if len(parts) != 2 {
332+
return 0, 0, fmt.Errorf("invalid shard pattern: %s", pattern)
333+
}
334+
if shard, err = strconv.Atoi(parts[0]); err != nil {
335+
return 0, 0, fmt.Errorf("invalid shard pattern: %s", pattern)
336+
}
337+
if total, err = strconv.Atoi(parts[1]); err != nil {
338+
return 0, 0, fmt.Errorf("invalid shard pattern: %s", pattern)
339+
}
340+
if shard < 1 || total < 1 || shard > total {
341+
return 0, 0, fmt.Errorf("invalid shard pattern: %s", pattern)
342+
}
343+
return shard, total, nil
344+
}

local/group.go

Lines changed: 88 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -9,32 +9,8 @@ import (
99
"strings"
1010
"sync"
1111
"time"
12-
13-
"github.com/linuxkit/rtf/logger"
1412
)
1513

16-
// NewProject creates a new top-level Group at the provided path
17-
func NewProject(path string) (*Group, error) {
18-
if !filepath.IsAbs(path) {
19-
var err error
20-
path, err = filepath.Abs(path)
21-
if err != nil {
22-
return nil, err
23-
}
24-
}
25-
g := &Group{Parent: nil, Path: path}
26-
return g, nil
27-
}
28-
29-
// InitNewProject creates a new Group, and calls Init() on it
30-
func InitNewProject(path string) (*Group, error) {
31-
group, err := NewProject(path)
32-
if err != nil {
33-
return group, err
34-
}
35-
return group, group.Init()
36-
}
37-
3814
// NewGroup creates a new Group with the given parent and path
3915
func NewGroup(parent *Group, path string) (*Group, error) {
4016
g := &Group{Parent: parent, Path: path, PreTestPath: parent.PreTestPath, PostTestPath: parent.PostTestPath}
@@ -151,44 +127,78 @@ func (g *Group) List(config RunConfig) []Info {
151127
return infos
152128
}
153129

130+
// Gather gathers all runnable child groups and tests
131+
func (g *Group) Gather(config RunConfig, count int) ([]TestContainer, int) {
132+
sort.Sort(ByOrder(g.Children))
133+
134+
if !g.willRun(config) {
135+
return nil, 0
136+
}
137+
containers := []TestContainer{}
138+
var subCount int
139+
140+
if g.GroupFilePath != "" {
141+
containers = append(containers, GroupCommand{Name: g.Name(), FilePath: g.GroupFilePath, Path: g.Path, Type: "init"})
142+
}
143+
144+
for _, c := range g.Children {
145+
lst, childCount := c.Gather(config, count+subCount)
146+
// if the child had no runnable tests, skip it entirely so that its group init/deinit commands are not added
147+
if childCount == 0 {
148+
continue
149+
}
150+
containers = append(containers, lst...)
151+
subCount += childCount
152+
}
153+
154+
if g.GroupFilePath != "" {
155+
containers = append(containers, GroupCommand{Name: g.Name(), FilePath: g.GroupFilePath, Path: g.Path, Type: "deinit"})
156+
}
157+
158+
return containers, subCount
159+
}
160+
154161
// Run will run all child groups and tests
155162
func (g *Group) Run(config RunConfig) ([]Result, error) {
156163
var results []Result
157-
sort.Sort(ByOrder(g.Children))
158164

159-
if !g.willRun(config) {
165+
// This gathers all of the individual tests and group init/deinit commands
166+
// all the way down, leading to a flat list we can execute, rather than recursion.
167+
// That should make it easier to break into shards.
168+
count := 0
169+
runnables, _ := g.Gather(config, count)
170+
if len(runnables) == 0 {
160171
return []Result{{TestResult: Skip,
161172
Name: g.Name(),
162173
StartTime: time.Now(),
163174
EndTime: time.Now(),
164175
}}, nil
165176
}
166177

167-
if g.GroupFilePath != "" {
168-
config.Logger.Log(logger.LevelInfo, fmt.Sprintf("%s::ginit()", g.Name()))
169-
res, err := executeScript(g.GroupFilePath, g.Path, "", []string{"init"}, config)
170-
if err != nil {
171-
return results, err
172-
}
173-
if res.TestResult != Pass {
174-
return results, fmt.Errorf("Error running %s", g.GroupFilePath+":init")
175-
}
176-
}
177-
178178
if config.Parallel {
179179
var wg sync.WaitGroup
180180
resCh := make(chan []Result, len(g.Children))
181181
errCh := make(chan error, len(g.Children))
182182

183-
for _, c := range g.Children {
183+
for _, c := range runnables {
184184
wg.Add(1)
185185
go func(c TestContainer, cf RunConfig) {
186186
defer wg.Done()
187-
res, err := c.Run(cf)
188-
if err != nil {
189-
errCh <- err
187+
var isTest bool
188+
if _, ok := c.(*Test); ok {
189+
isTest = true
190+
}
191+
// Run() if one of the following is true: it is not a test; there are no start/end boundaries; the test is within the boundaries
192+
if !isTest || (config.start == 0 && config.count == 0) || (count >= config.start && config.start+config.count > count) {
193+
res, err := c.Run(cf)
194+
if err != nil {
195+
errCh <- err
196+
}
197+
resCh <- res
198+
}
199+
if isTest {
200+
count++
190201
}
191-
resCh <- res
192202
}(c, config)
193203
}
194204

@@ -208,25 +218,15 @@ func (g *Group) Run(config RunConfig) ([]Result, error) {
208218
results = append(results, res...)
209219
}
210220
} else {
211-
for _, c := range g.Children {
212-
res, err := c.Run(config)
221+
for _, r := range runnables {
222+
res, err := r.Run(config)
213223
if err != nil {
214224
return results, err
215225
}
216226
results = append(results, res...)
217227
}
218228
}
219229

220-
if g.GroupFilePath != "" {
221-
config.Logger.Log(logger.LevelInfo, fmt.Sprintf("%s::gdeinit()", g.Name()))
222-
res, err := executeScript(g.GroupFilePath, g.Path, "", []string{"deinit"}, config)
223-
if err != nil {
224-
return results, err
225-
}
226-
if res.TestResult != Pass {
227-
return results, fmt.Errorf("Error running %s", g.GroupFilePath+":deinit")
228-
}
229-
}
230230
return results, nil
231231
}
232232

@@ -247,3 +247,37 @@ func (g *Group) willRun(config RunConfig) bool {
247247

248248
return strings.HasPrefix(config.TestPattern, g.Name()) || strings.HasPrefix(g.Name(), config.TestPattern)
249249
}
250+
251+
func min(x, y int) int {
252+
if x < y {
253+
return x
254+
}
255+
return y
256+
}
257+
258+
// calculateShard calculate the start and end of a slice to use in a given shard of a given total
259+
// slice and a given number of shards.
260+
// We split by the total list size. If it is uneven, e.g. 22 elements into 10 shards,
261+
// then the first shards will be rounded up until the remainder can be split evenly.
262+
//
263+
// e.g.
264+
// 22 elements into 10 shards will be 3, 3, 2, 2, 2, 2, 2, 2, 2, 2
265+
// 29 elements into 10 shards will be 3, 3, 3, 3, 3, 3, 3, 3, 3, 2
266+
// 30 elements into 10 shards will be 3, 3, 3, 3, 3, 3, 3, 3, 3, 3
267+
// 8 elements into 5 shards will be 2, 2, 2, 1, 1
268+
//
269+
// The important thing is consistency among runs using the same set of parameters
270+
// so that you can reliably get the same subset each time.
271+
func calculateShard(size, shard, totalShards int) (start, count int) {
272+
elmsPerShard := size / totalShards
273+
remainder := size % totalShards
274+
before := (shard - 1) * elmsPerShard
275+
if remainder > 0 {
276+
before += min(remainder, shard-1)
277+
}
278+
count = elmsPerShard
279+
if remainder >= shard {
280+
count++
281+
}
282+
return before, count
283+
}

local/group_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package local
2+
3+
import (
4+
"testing"
5+
)
6+
7+
func TestCalculateShard(t *testing.T) {
8+
tests := []struct {
9+
size int
10+
shard int
11+
shards int
12+
start int
13+
count int
14+
}{
15+
{22, 1, 10, 0, 3},
16+
{22, 2, 10, 3, 3},
17+
{22, 3, 10, 6, 2},
18+
{29, 10, 10, 27, 2},
19+
{2, 1, 6, 0, 1}, // more shards than elements
20+
{2, 2, 6, 1, 1}, // more shards than elements
21+
{2, 3, 6, 2, 0}, // more shards than elements
22+
}
23+
for _, tt := range tests {
24+
t.Run("", func(t *testing.T) {
25+
if gotStart, gotCount := calculateShard(tt.size, tt.shard, tt.shards); gotStart != tt.start || gotCount != tt.count {
26+
t.Errorf("calculateShard() = %v, %v, want %v, %v", gotStart, gotCount, tt.start, tt.count)
27+
}
28+
})
29+
}
30+
31+
}

local/groupcommand.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package local
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/linuxkit/rtf/logger"
7+
)
8+
9+
// Run satisfies the TestContainer interface.
10+
// Run the group init or deinit command.
11+
func (g GroupCommand) Run(config RunConfig) ([]Result, error) {
12+
config.Logger.Log(logger.LevelInfo, fmt.Sprintf("%s::%s()", g.Name, g.Type))
13+
res, err := executeScript(g.FilePath, g.Path, "", []string{g.Type}, config)
14+
if err != nil {
15+
return nil, err
16+
}
17+
if res.TestResult != Pass {
18+
return nil, fmt.Errorf("error running %s:%s", g.FilePath, g.Type)
19+
}
20+
return []Result{res}, nil
21+
}
22+
23+
// List satisfies the TestContainer interface
24+
func (g GroupCommand) List(config RunConfig) []Info {
25+
info := Info{
26+
Name: g.Name,
27+
}
28+
29+
return []Info{info}
30+
}
31+
32+
// Order returns the group command's order: 0 for "init", 1 for any other type
33+
func (g GroupCommand) Order() int {
34+
if g.Type == "init" {
35+
return 0
36+
}
37+
return 1
38+
}
39+
40+
// Gather satisfies the TestContainer interface
41+
func (g GroupCommand) Gather(config RunConfig, count int) ([]TestContainer, int) {
42+
return []TestContainer{&g}, 0
43+
}

0 commit comments

Comments
 (0)