Skip to content

Commit 588735b

Browse files
committed
[#72940] linux-client: Add ShellRunner
ShellRunner tracks all shell sessions managed by the daemon. Signed-off-by: Łukasz Kędziora <lkedziora@antmicro.com>
1 parent d485c1a commit 588735b

File tree

2 files changed

+127
-0
lines changed

2 files changed

+127
-0
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package shell
2+
3+
import (
4+
"errors"
5+
"sync"
6+
7+
"golang.org/x/sync/semaphore"
8+
)
9+
10+
// Keeps track of shell processes spawned by the daemon.
11+
type ShellRunner struct {
12+
shellCount *semaphore.Weighted // Current available shell count.
13+
sessions map[string]*ShellSession // Dict of all currently running shells.
14+
mut sync.Mutex // Protects sessions dict.
15+
}
16+
17+
var (
18+
ErrSessionNotFound = errors.New("session not found")
19+
ErrSessionLimitReached = errors.New("session limit reached")
20+
)
21+
22+
// Create a new shell runner that is capable of spawning up to limit concurrent
23+
// shell processes.
24+
func NewShellRunner(limit int) (*ShellRunner, error) {
25+
sr := new(ShellRunner)
26+
sr.shellCount = semaphore.NewWeighted(int64(limit))
27+
sr.sessions = make(map[string]*ShellSession)
28+
return sr, nil
29+
}
30+
31+
// Spawn a new shell process and associate it with a given UUID.
32+
func (s *ShellRunner) Spawn(uuid string) (*ShellSession, error) {
33+
ok := s.shellCount.TryAcquire(1)
34+
if !ok {
35+
return nil, ErrSessionLimitReached
36+
}
37+
38+
session, err := NewShellSession(uuid)
39+
if err != nil {
40+
return nil, err
41+
}
42+
43+
s.mut.Lock()
44+
defer s.mut.Unlock()
45+
s.sessions[uuid] = session
46+
47+
return session, nil
48+
}
49+
50+
// Terminate a shell session that was previously spawned. The associated shell
51+
// process is terminated if it is still running.
52+
func (s *ShellRunner) Terminate(uuid string) error {
53+
s.mut.Lock()
54+
defer s.mut.Unlock()
55+
56+
v, ok := s.sessions[uuid]
57+
if !ok {
58+
return ErrSessionNotFound
59+
}
60+
v.Close()
61+
delete(s.sessions, uuid)
62+
s.shellCount.Release(1)
63+
64+
return nil
65+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package shell
2+
3+
import (
4+
"strconv"
5+
"testing"
6+
7+
"github.com/google/uuid"
8+
"github.com/stretchr/testify/assert"
9+
)
10+
11+
const (
12+
TestSessionUuid = "ABCDEF"
13+
)
14+
15+
func TestShellRunnerBasic(t *testing.T) {
16+
sr, err := NewShellRunner(1)
17+
assert.NoError(t, err)
18+
19+
_, err = sr.Spawn(TestSessionUuid)
20+
assert.NoError(t, err)
21+
}
22+
23+
func TestShellRunnerLimits(t *testing.T) {
24+
const (
25+
LimitToTest = 4
26+
)
27+
28+
sr, err := NewShellRunner(LimitToTest)
29+
assert.NoError(t, err)
30+
31+
for i := 0; i < LimitToTest; i += 1 {
32+
_, err = sr.Spawn(strconv.Itoa(i))
33+
assert.NoError(t, err)
34+
}
35+
36+
_, err = sr.Spawn(uuid.NewString())
37+
if assert.Error(t, err) {
38+
assert.Equal(t, ErrSessionLimitReached, err)
39+
}
40+
}
41+
42+
func TestShellRunnerLimitsIncrementedAfterTermination(t *testing.T) {
43+
sr, err := NewShellRunner(1)
44+
assert.NoError(t, err)
45+
46+
id1 := uuid.NewString()
47+
id2 := uuid.NewString()
48+
49+
_, err = sr.Spawn(id1)
50+
assert.NoError(t, err)
51+
52+
_, err = sr.Spawn(id2)
53+
if assert.Error(t, err) {
54+
assert.Equal(t, ErrSessionLimitReached, err)
55+
}
56+
57+
err = sr.Terminate(id1)
58+
assert.NoError(t, err)
59+
60+
_, err = sr.Spawn(id2)
61+
assert.NoError(t, err)
62+
}

0 commit comments

Comments
 (0)