diff --git a/tftp_test.go b/tftp_test.go index ced1fa6..63bc172 100644 --- a/tftp_test.go +++ b/tftp_test.go @@ -1058,7 +1058,7 @@ func testShutdownDuringTransfer(t *testing.T, singlePort bool) { } case <-time.After(5 * time.Second): t.Error("client did not finish in time") - } + } } func TestSetLocalAddr(t *testing.T) { @@ -1137,3 +1137,97 @@ func TestSetLocalAddr(t *testing.T) { } } } + +func TestMultipleClientsSimultaneously(t *testing.T) { + modes := []struct { + name string + singlePort bool + }{ + {"SinglePort", true}, + {"Regular", false}, + } + + for _, mode := range modes { + t.Run(mode.name, func(t *testing.T) { + s, c := makeTestServer(mode.singlePort) + defer s.Shutdown() + + var wg sync.WaitGroup + numClients := 10 + downloadsPerClient := 10 // Each client will download its file this many times + fileSize := 5 * 1024 * 1024 // 5MB files + filesData := make([][]byte, numClients) + errChan := make(chan error, numClients*downloadsPerClient) + + t.Logf("Creating %d files of %d MB each", numClients, fileSize/1024/1024) + + // Create different data for each client + for i := 0; i < numClients; i++ { + filesData[i] = make([]byte, fileSize) + rand.Read(filesData[i]) + } + + // Setup test files by uploading them to the server + t.Log("Uploading test files to server...") + for i := 0; i < numClients; i++ { + filename := fmt.Sprintf("test%d.bin", i) + sender, err := c.Send(filename, "octet") + if err != nil { + t.Fatalf("failed to start upload of test file %s: %v", filename, err) + } + _, err = sender.ReadFrom(bytes.NewReader(filesData[i])) + if err != nil { + t.Fatalf("failed to upload test file %s: %v", filename, err) + } + t.Logf("Uploaded %s (%d MB)", filename, fileSize/1024/1024) + } + + t.Logf("Starting %d clients, each downloading %d times", numClients, downloadsPerClient) + + // Start multiple clients simultaneously, but each client's downloads are sequential + for i := 0; i < numClients; i++ { + wg.Add(1) + go func(clientNum int) { + defer wg.Done() + + filename := fmt.Sprintf("test%d.bin", clientNum) + + // Each client performs its downloads sequentially + for j := 0; j < downloadsPerClient; j++ { + t.Logf("Client %d started download %d", clientNum, j) + + r, err := c.Receive(filename, "octet") + if err != nil { + errChan <- fmt.Errorf("client %d (download %d) failed to receive: %v", clientNum, j, err) + return + } + + var buf bytes.Buffer + _, err = r.WriteTo(&buf) + if err != nil { + errChan <- fmt.Errorf("client %d (download %d) failed to read data: %v", clientNum, j, err) + return + } + + if !bytes.Equal(buf.Bytes(), filesData[clientNum]) { + errChan <- fmt.Errorf("client %d (download %d) received incorrect data", clientNum, j) + return + } + + t.Logf("Client %d completed download %d", clientNum, j) + } + t.Logf("Client %d completed all %d downloads", clientNum, downloadsPerClient) + }(i) + } + + // Wait for all goroutines to finish + wg.Wait() + close(errChan) + + // Check for any errors from the goroutines + for err := range errChan { + t.Error(err) + } + }) + } +}