1+ /*
2+ Copyright 2024 API Testing Authors.
3+
4+ Licensed under the Apache License, Version 2.0 (the "License");
5+ you may not use this file except in compliance with the License.
6+ You may obtain a copy of the License at
7+
8+ http://www.apache.org/licenses/LICENSE-2.0
9+
10+ Unless required by applicable law or agreed to in writing, software
11+ distributed under the License is distributed on an "AS IS" BASIS,
12+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+ See the License for the specific language governing permissions and
14+ limitations under the License.
15+ */
16+ package server_test
17+
18+ import (
19+ "context"
20+ "testing"
21+
22+ "github.com/linuxsuren/api-testing/pkg/server"
23+ "github.com/stretchr/testify/assert"
24+ "github.com/stretchr/testify/require"
25+ )
26+
27+ func TestServer_GenerateSQL (t * testing.T ) {
28+ s := server .NewInMemoryServer ("" , nil )
29+
30+ tests := []struct {
31+ name string
32+ request * server.GenerateSQLRequest
33+ expectError bool
34+ expectedErrMsg string
35+ expectSuccess bool
36+ }{
37+ {
38+ name : "successful generation" ,
39+ request : & server.GenerateSQLRequest {
40+ NaturalLanguage : "Find all users created in the last 30 days" ,
41+ DatabaseTarget : & server.DatabaseTarget {
42+ Type : "postgresql" ,
43+ Version : "13.0" ,
44+ },
45+ },
46+ expectSuccess : true ,
47+ },
48+ {
49+ name : "empty natural language input" ,
50+ request : & server.GenerateSQLRequest {
51+ NaturalLanguage : "" ,
52+ },
53+ expectError : true ,
54+ expectedErrMsg : "Natural language input is required" ,
55+ },
56+ {
57+ name : "with database context" ,
58+ request : & server.GenerateSQLRequest {
59+ NaturalLanguage : "Count active users" ,
60+ DatabaseTarget : & server.DatabaseTarget {
61+ Type : "mysql" ,
62+ Version : "8.0" ,
63+ Schemas : []string {"main" , "analytics" },
64+ },
65+ Options : & server.GenerationOptions {
66+ IncludeExplanation : true ,
67+ FormatOutput : true ,
68+ MaxSuggestions : 3 ,
69+ ConfidenceThreshold : 0.8 ,
70+ },
71+ Context : map [string ]string {
72+ "table" : "users" ,
73+ },
74+ },
75+ expectSuccess : true ,
76+ },
77+ }
78+
79+ for _ , tt := range tests {
80+ t .Run (tt .name , func (t * testing.T ) {
81+ ctx := context .Background ()
82+ resp , err := s .GenerateSQL (ctx , tt .request )
83+
84+ require .NoError (t , err , "GenerateSQL should not return gRPC error" )
85+ require .NotNil (t , resp )
86+
87+ if tt .expectError {
88+ require .NotNil (t , resp .Error , "Response should contain error" )
89+ assert .Equal (t , tt .expectedErrMsg , resp .Error .Message )
90+ assert .Equal (t , server .AIErrorCode_INVALID_INPUT , resp .Error .Code )
91+ } else if tt .expectSuccess {
92+ assert .Nil (t , resp .Error , "Response should not contain error" )
93+ assert .NotEmpty (t , resp .GeneratedSql , "Generated SQL should not be empty" )
94+ assert .Greater (t , resp .ConfidenceScore , float32 (0 ), "Confidence score should be positive" )
95+ assert .NotEmpty (t , resp .Explanation , "Explanation should not be empty" )
96+ assert .NotNil (t , resp .Metadata , "Metadata should not be nil" )
97+ assert .NotEmpty (t , resp .Metadata .RequestId , "Request ID should not be empty" )
98+ assert .Greater (t , resp .Metadata .ProcessingTimeMs , float64 (0 ), "Processing time should be positive" )
99+ }
100+ })
101+ }
102+ }
103+
104+ func TestServer_ValidateSQL (t * testing.T ) {
105+ s := server .NewInMemoryServer ("" , nil )
106+
107+ tests := []struct {
108+ name string
109+ request * server.ValidateSQLRequest
110+ expectValid bool
111+ expectError bool
112+ }{
113+ {
114+ name : "valid SELECT query" ,
115+ request : & server.ValidateSQLRequest {
116+ Sql : "SELECT * FROM users WHERE active = 1" ,
117+ DatabaseType : "postgresql" ,
118+ },
119+ expectValid : true ,
120+ },
121+ {
122+ name : "valid INSERT query" ,
123+ request : & server.ValidateSQLRequest {
124+ Sql :
"INSERT INTO users (name, email) VALUES ('John', '[email protected] ')" ,
125+ DatabaseType : "mysql" ,
126+ },
127+ expectValid : true ,
128+ },
129+ {
130+ name : "valid UPDATE query" ,
131+ request : & server.ValidateSQLRequest {
132+ Sql : "UPDATE users SET active = 0 WHERE id = 1" ,
133+ DatabaseType : "sqlite" ,
134+ },
135+ expectValid : true ,
136+ },
137+ {
138+ name : "valid DELETE query" ,
139+ request : & server.ValidateSQLRequest {
140+ Sql : "DELETE FROM users WHERE inactive_date < NOW() - INTERVAL 1 YEAR" ,
141+ DatabaseType : "postgresql" ,
142+ },
143+ expectValid : true ,
144+ },
145+ {
146+ name : "empty SQL query" ,
147+ request : & server.ValidateSQLRequest {
148+ Sql : "" ,
149+ },
150+ expectValid : false ,
151+ expectError : true ,
152+ },
153+ {
154+ name : "invalid SQL syntax" ,
155+ request : & server.ValidateSQLRequest {
156+ Sql : "INVALID QUERY SYNTAX" ,
157+ DatabaseType : "mysql" ,
158+ },
159+ expectValid : false ,
160+ },
161+ {
162+ name : "complex valid query with context" ,
163+ request : & server.ValidateSQLRequest {
164+ Sql : "SELECT u.name, p.title FROM users u JOIN posts p ON u.id = p.user_id" ,
165+ DatabaseType : "postgresql" ,
166+ Context : map [string ]string {
167+ "schema" : "public" ,
168+ "tables" : "users,posts" ,
169+ },
170+ },
171+ expectValid : true ,
172+ },
173+ }
174+
175+ for _ , tt := range tests {
176+ t .Run (tt .name , func (t * testing.T ) {
177+ ctx := context .Background ()
178+ resp , err := s .ValidateSQL (ctx , tt .request )
179+
180+ require .NoError (t , err , "ValidateSQL should not return gRPC error" )
181+ require .NotNil (t , resp )
182+
183+ assert .Equal (t , tt .expectValid , resp .IsValid )
184+
185+ if tt .expectError {
186+ assert .NotEmpty (t , resp .Errors , "Should have validation errors" )
187+ assert .Equal (t , "SQL query is required" , resp .Errors [0 ].Message )
188+ assert .Equal (t , server .ValidationErrorType_SYNTAX_ERROR , resp .Errors [0 ].Type )
189+ } else if tt .expectValid {
190+ assert .Empty (t , resp .Errors , "Valid SQL should have no errors" )
191+ assert .NotEmpty (t , resp .FormattedSql , "Should have formatted SQL" )
192+ assert .NotNil (t , resp .Metadata , "Should have validation metadata" )
193+ } else {
194+ assert .NotEmpty (t , resp .Errors , "Invalid SQL should have errors" )
195+ }
196+ })
197+ }
198+ }
199+
200+ func TestServer_GetAICapabilities (t * testing.T ) {
201+ s := server .NewInMemoryServer ("" , nil )
202+
203+ ctx := context .Background ()
204+ resp , err := s .GetAICapabilities (ctx , & server.Empty {})
205+
206+ require .NoError (t , err )
207+ require .NotNil (t , resp )
208+
209+ // Test response structure
210+ assert .NotEmpty (t , resp .SupportedDatabases , "Should support multiple databases" )
211+ assert .Contains (t , resp .SupportedDatabases , "mysql" )
212+ assert .Contains (t , resp .SupportedDatabases , "postgresql" )
213+ assert .Contains (t , resp .SupportedDatabases , "sqlite" )
214+
215+ assert .NotEmpty (t , resp .Features , "Should have AI features" )
216+
217+ // Check for SQL generation feature
218+ var sqlGenFeature * server.AIFeature
219+ for _ , feature := range resp .Features {
220+ if feature .Name == "sql_generation" {
221+ sqlGenFeature = feature
222+ break
223+ }
224+ }
225+ require .NotNil (t , sqlGenFeature , "Should have sql_generation feature" )
226+ assert .True (t , sqlGenFeature .Enabled )
227+ assert .NotEmpty (t , sqlGenFeature .Description )
228+ assert .NotEmpty (t , sqlGenFeature .Parameters )
229+
230+ // Check for SQL validation feature
231+ var sqlValFeature * server.AIFeature
232+ for _ , feature := range resp .Features {
233+ if feature .Name == "sql_validation" {
234+ sqlValFeature = feature
235+ break
236+ }
237+ }
238+ require .NotNil (t , sqlValFeature , "Should have sql_validation feature" )
239+ assert .True (t , sqlValFeature .Enabled )
240+
241+ assert .NotEmpty (t , resp .Version , "Should have version" )
242+ assert .NotEqual (t , server .HealthStatus_HEALTH_STATUS_UNSPECIFIED , resp .Status )
243+ assert .NotEmpty (t , resp .Limits , "Should have limits" )
244+ }
245+
246+ func TestServer_AIErrorHandling (t * testing.T ) {
247+ s := server .NewInMemoryServer ("" , nil )
248+
249+ tests := []struct {
250+ name string
251+ testFunc func (context.Context ) error
252+ expectedError string
253+ }{
254+ {
255+ name : "GenerateSQL with empty input" ,
256+ testFunc : func (ctx context.Context ) error {
257+ resp , err := s .GenerateSQL (ctx , & server.GenerateSQLRequest {})
258+ if err != nil {
259+ return err
260+ }
261+ if resp .Error != nil && resp .Error .Code == server .AIErrorCode_INVALID_INPUT {
262+ return nil // Expected error
263+ }
264+ return assert .AnError // Unexpected response
265+ },
266+ },
267+ {
268+ name : "ValidateSQL with empty input" ,
269+ testFunc : func (ctx context.Context ) error {
270+ resp , err := s .ValidateSQL (ctx , & server.ValidateSQLRequest {})
271+ if err != nil {
272+ return err
273+ }
274+ if ! resp .IsValid && len (resp .Errors ) > 0 {
275+ return nil // Expected validation failure
276+ }
277+ return assert .AnError // Unexpected response
278+ },
279+ },
280+ }
281+
282+ for _ , tt := range tests {
283+ t .Run (tt .name , func (t * testing.T ) {
284+ ctx := context .Background ()
285+ err := tt .testFunc (ctx )
286+ assert .NoError (t , err , "Error handling test should pass" )
287+ })
288+ }
289+ }
290+
291+ func TestServer_AIMethodsIntegration (t * testing.T ) {
292+ s := server .NewInMemoryServer ("" , nil )
293+ ctx := context .Background ()
294+
295+ // First, check AI capabilities
296+ capResp , err := s .GetAICapabilities (ctx , & server.Empty {})
297+ require .NoError (t , err )
298+ require .NotNil (t , capResp )
299+
300+ // Test SQL generation
301+ genResp , err := s .GenerateSQL (ctx , & server.GenerateSQLRequest {
302+ NaturalLanguage : "Find all users" ,
303+ })
304+ require .NoError (t , err )
305+ require .NotNil (t , genResp )
306+ require .Nil (t , genResp .Error , "Generation should succeed" )
307+
308+ // Test SQL validation with generated SQL
309+ valResp , err := s .ValidateSQL (ctx , & server.ValidateSQLRequest {
310+ Sql : genResp .GeneratedSql ,
311+ })
312+ require .NoError (t , err )
313+ require .NotNil (t , valResp )
314+
315+ // The generated SQL should be valid (contains SELECT keyword)
316+ assert .True (t , valResp .IsValid , "Generated SQL should be valid" )
317+ }
318+
319+ func TestFormatSQL (t * testing.T ) {
320+ tests := []struct {
321+ name string
322+ input string
323+ expected string
324+ }{
325+ {
326+ name : "simple SELECT" ,
327+ input : "SELECT * FROM users" ,
328+ expected : "SELECT *\n FROM users" ,
329+ },
330+ {
331+ name : "SELECT with WHERE" ,
332+ input : "SELECT id, name FROM users WHERE active = 1" ,
333+ expected : "SELECT id, name\n FROM users\n WHERE active = 1" ,
334+ },
335+ {
336+ name : "complex query with multiple clauses" ,
337+ input : "SELECT u.name, COUNT(*) FROM users u WHERE u.active = 1 GROUP BY u.name ORDER BY COUNT(*) DESC" ,
338+ expected : "SELECT u.name, COUNT(*)\n FROM users u\n WHERE u.active = 1\n GROUP BY u.name\n ORDER BY COUNT(*) DESC" ,
339+ },
340+ }
341+
342+ for _ , tt := range tests {
343+ t .Run (tt .name , func (t * testing.T ) {
344+ // This tests the formatSQL function indirectly through ValidateSQL
345+ s := server .NewInMemoryServer ("" , nil )
346+ ctx := context .Background ()
347+
348+ resp , err := s .ValidateSQL (ctx , & server.ValidateSQLRequest {
349+ Sql : tt .input ,
350+ })
351+
352+ require .NoError (t , err )
353+ require .NotNil (t , resp )
354+
355+ if resp .IsValid {
356+ assert .Equal (t , tt .expected , resp .FormattedSql )
357+ }
358+ })
359+ }
360+ }
0 commit comments