@@ -18,6 +18,7 @@ package connection
1818
1919import (
2020 "context"
21+ "fmt"
2122 "io/ioutil"
2223 "net"
2324 "os"
@@ -33,6 +34,8 @@ import (
3334
3435 "github.com/stretchr/testify/assert"
3536 "github.com/stretchr/testify/require"
37+
38+ "github.com/container-storage-interface/spec/lib/go/csi"
3639)
3740
3841func tmpDir (t * testing.T ) string {
@@ -48,11 +51,14 @@ const (
4851// startServer creates a gRPC server without any registered services.
4952// The returned address can be used to connect to it. The cleanup
5053// function stops it. It can be called multiple times.
51- func startServer (t * testing.T , tmp string ) (string , func ()) {
54+ func startServer (t * testing.T , tmp string , identity csi. IdentityServer ) (string , func ()) {
5255 addr := path .Join (tmp , serverSock )
5356 listener , err := net .Listen ("unix" , addr )
5457 require .NoError (t , err , "listening on %s" , addr )
5558 server := grpc .NewServer ()
59+ if identity != nil {
60+ csi .RegisterIdentityServer (server , identity )
61+ }
5662 var wg sync.WaitGroup
5763 wg .Add (1 )
5864 go func () {
@@ -73,7 +79,7 @@ func startServer(t *testing.T, tmp string) (string, func()) {
7379func TestConnect (t * testing.T ) {
7480 tmp := tmpDir (t )
7581 defer os .RemoveAll (tmp )
76- addr , stopServer := startServer (t , tmp )
82+ addr , stopServer := startServer (t , tmp , nil )
7783 defer stopServer ()
7884
7985 conn , err := Connect (addr )
@@ -88,7 +94,7 @@ func TestConnect(t *testing.T) {
8894func TestConnectUnix (t * testing.T ) {
8995 tmp := tmpDir (t )
9096 defer os .RemoveAll (tmp )
91- addr , stopServer := startServer (t , tmp )
97+ addr , stopServer := startServer (t , tmp , nil )
9298 defer stopServer ()
9399
94100 conn , err := Connect ("unix:///" + addr )
@@ -129,7 +135,7 @@ func TestWaitForServer(t *testing.T) {
129135 t .Logf ("sleeping %s before starting server" , delay )
130136 time .Sleep (delay )
131137 startTimeServer = time .Now ()
132- _ , stopServer = startServer (t , tmp )
138+ _ , stopServer = startServer (t , tmp , nil )
133139 }()
134140 conn , err := Connect (path .Join (tmp , serverSock ))
135141 if assert .NoError (t , err , "connect via absolute path" ) {
@@ -163,7 +169,7 @@ func TestTimout(t *testing.T) {
163169func TestReconnect (t * testing.T ) {
164170 tmp := tmpDir (t )
165171 defer os .RemoveAll (tmp )
166- addr , stopServer := startServer (t , tmp )
172+ addr , stopServer := startServer (t , tmp , nil )
167173 defer func () {
168174 stopServer ()
169175 }()
@@ -190,7 +196,7 @@ func TestReconnect(t *testing.T) {
190196 }
191197
192198 // No reconnection either when the server comes back.
193- _ , stopServer = startServer (t , tmp )
199+ _ , stopServer = startServer (t , tmp , nil )
194200 // We need to give gRPC some time. It does not attempt to reconnect
195201 // immediately. If we send the method call too soon, the test passes
196202 // even though a later method call will go through again.
@@ -208,7 +214,7 @@ func TestReconnect(t *testing.T) {
208214func TestDisconnect (t * testing.T ) {
209215 tmp := tmpDir (t )
210216 defer os .RemoveAll (tmp )
211- addr , stopServer := startServer (t , tmp )
217+ addr , stopServer := startServer (t , tmp , nil )
212218 defer func () {
213219 stopServer ()
214220 }()
@@ -239,7 +245,7 @@ func TestDisconnect(t *testing.T) {
239245 }
240246
241247 // No reconnection either when the server comes back.
242- _ , stopServer = startServer (t , tmp )
248+ _ , stopServer = startServer (t , tmp , nil )
243249 // We need to give gRPC some time. It does not attempt to reconnect
244250 // immediately. If we send the method call too soon, the test passes
245251 // even though a later method call will go through again.
@@ -259,7 +265,7 @@ func TestDisconnect(t *testing.T) {
259265func TestExplicitReconnect (t * testing.T ) {
260266 tmp := tmpDir (t )
261267 defer os .RemoveAll (tmp )
262- addr , stopServer := startServer (t , tmp )
268+ addr , stopServer := startServer (t , tmp , nil )
263269 defer func () {
264270 stopServer ()
265271 }()
@@ -290,7 +296,7 @@ func TestExplicitReconnect(t *testing.T) {
290296 }
291297
292298 // No reconnection either when the server comes back.
293- _ , stopServer = startServer (t , tmp )
299+ _ , stopServer = startServer (t , tmp , nil )
294300 // We need to give gRPC some time. It does not attempt to reconnect
295301 // immediately. If we send the method call too soon, the test passes
296302 // even though a later method call will go through again.
@@ -306,3 +312,87 @@ func TestExplicitReconnect(t *testing.T) {
306312 assert .Equal (t , 1 , reconnectCount , "connection loss callback should be called once" )
307313 }
308314}
315+
316+ func TestGetDriverName (t * testing.T ) {
317+ tests := []struct {
318+ name string
319+ output * csi.GetPluginInfoResponse
320+ injectError bool
321+ expectError bool
322+ }{
323+ {
324+ name : "success" ,
325+ output : & csi.GetPluginInfoResponse {
326+ Name : "csi/example" ,
327+ VendorVersion : "0.2.0" ,
328+ Manifest : map [string ]string {
329+ "hello" : "world" ,
330+ },
331+ },
332+ expectError : false ,
333+ },
334+ {
335+ name : "gRPC error" ,
336+ output : nil ,
337+ injectError : true ,
338+ expectError : true ,
339+ },
340+ {
341+ name : "empty name" ,
342+ output : & csi.GetPluginInfoResponse {
343+ Name : "" ,
344+ },
345+ expectError : true ,
346+ },
347+ }
348+
349+ for _ , test := range tests {
350+ t .Run (test .name , func (t * testing.T ) {
351+ out := test .output
352+ var injectedErr error
353+ if test .injectError {
354+ injectedErr = fmt .Errorf ("mock error" )
355+ }
356+
357+ tmp := tmpDir (t )
358+ defer os .RemoveAll (tmp )
359+ identity := & identityServer {out , injectedErr }
360+ addr , stopServer := startServer (t , tmp , identity )
361+ defer func () {
362+ stopServer ()
363+ }()
364+
365+ conn , err := Connect (addr )
366+
367+ name , err := GetDriverName (context .Background (), conn )
368+ if test .expectError && err == nil {
369+ t .Errorf ("test %q: Expected error, got none" , test .name )
370+ }
371+ if ! test .expectError && err != nil {
372+ t .Errorf ("test %q: got error: %v" , test .name , err )
373+ }
374+ if err == nil && name != "csi/example" {
375+ t .Errorf ("got unexpected name: %q" , name )
376+ }
377+ })
378+ }
379+ }
380+
381+ type identityServer struct {
382+ response * csi.GetPluginInfoResponse
383+ err error
384+ }
385+
386+ var _ csi.IdentityServer = & identityServer {}
387+
388+ func (i * identityServer ) GetPluginCapabilities (context.Context , * csi.GetPluginCapabilitiesRequest ) (* csi.GetPluginCapabilitiesResponse , error ) {
389+ return nil , fmt .Errorf ("Not implemented" )
390+ }
391+
392+ func (i * identityServer ) GetPluginInfo (context.Context , * csi.GetPluginInfoRequest ) (* csi.GetPluginInfoResponse , error ) {
393+ return i .response , i .err
394+ }
395+
396+ func (i * identityServer ) Probe (context.Context , * csi.ProbeRequest ) (* csi.ProbeResponse , error ) {
397+ return nil , fmt .Errorf ("Not implemented" )
398+ }
0 commit comments