Skip to content

Commit 9eb03de

Browse files
committed
Issue #5: Implement HTTP API for AI functionality
Add comprehensive AI HTTP API implementation to Runner service: - Implement GenerateSQL endpoint with natural language processing - Implement ValidateSQL endpoint with syntax validation and formatting - Implement GetAICapabilities endpoint for AI service discovery - Add comprehensive request validation and error handling - Add structured error responses with proper AI error codes - Add comprehensive test coverage for all AI endpoints - Add integration tests for end-to-end AI workflow - Add SQL formatting utility function - Add proper logging and metadata generation - Add timestamppb import for protobuf timestamp support The implementation provides a solid foundation for AI-powered SQL generation and validation with proper error handling, comprehensive testing, and following gRPC Gateway best practices from Context7.
1 parent 5ec17fd commit 9eb03de

File tree

2 files changed

+506
-0
lines changed

2 files changed

+506
-0
lines changed

pkg/server/ai_server_test.go

Lines changed: 360 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,360 @@
1+
/*
2+
Copyright 2024 API Testing Authors.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
*/
16+
package server_test
17+
18+
import (
19+
"context"
20+
"testing"
21+
22+
"github.com/linuxsuren/api-testing/pkg/server"
23+
"github.com/stretchr/testify/assert"
24+
"github.com/stretchr/testify/require"
25+
)
26+
27+
func TestServer_GenerateSQL(t *testing.T) {
28+
s := server.NewInMemoryServer("", nil)
29+
30+
tests := []struct {
31+
name string
32+
request *server.GenerateSQLRequest
33+
expectError bool
34+
expectedErrMsg string
35+
expectSuccess bool
36+
}{
37+
{
38+
name: "successful generation",
39+
request: &server.GenerateSQLRequest{
40+
NaturalLanguage: "Find all users created in the last 30 days",
41+
DatabaseTarget: &server.DatabaseTarget{
42+
Type: "postgresql",
43+
Version: "13.0",
44+
},
45+
},
46+
expectSuccess: true,
47+
},
48+
{
49+
name: "empty natural language input",
50+
request: &server.GenerateSQLRequest{
51+
NaturalLanguage: "",
52+
},
53+
expectError: true,
54+
expectedErrMsg: "Natural language input is required",
55+
},
56+
{
57+
name: "with database context",
58+
request: &server.GenerateSQLRequest{
59+
NaturalLanguage: "Count active users",
60+
DatabaseTarget: &server.DatabaseTarget{
61+
Type: "mysql",
62+
Version: "8.0",
63+
Schemas: []string{"main", "analytics"},
64+
},
65+
Options: &server.GenerationOptions{
66+
IncludeExplanation: true,
67+
FormatOutput: true,
68+
MaxSuggestions: 3,
69+
ConfidenceThreshold: 0.8,
70+
},
71+
Context: map[string]string{
72+
"table": "users",
73+
},
74+
},
75+
expectSuccess: true,
76+
},
77+
}
78+
79+
for _, tt := range tests {
80+
t.Run(tt.name, func(t *testing.T) {
81+
ctx := context.Background()
82+
resp, err := s.GenerateSQL(ctx, tt.request)
83+
84+
require.NoError(t, err, "GenerateSQL should not return gRPC error")
85+
require.NotNil(t, resp)
86+
87+
if tt.expectError {
88+
require.NotNil(t, resp.Error, "Response should contain error")
89+
assert.Equal(t, tt.expectedErrMsg, resp.Error.Message)
90+
assert.Equal(t, server.AIErrorCode_INVALID_INPUT, resp.Error.Code)
91+
} else if tt.expectSuccess {
92+
assert.Nil(t, resp.Error, "Response should not contain error")
93+
assert.NotEmpty(t, resp.GeneratedSql, "Generated SQL should not be empty")
94+
assert.Greater(t, resp.ConfidenceScore, float32(0), "Confidence score should be positive")
95+
assert.NotEmpty(t, resp.Explanation, "Explanation should not be empty")
96+
assert.NotNil(t, resp.Metadata, "Metadata should not be nil")
97+
assert.NotEmpty(t, resp.Metadata.RequestId, "Request ID should not be empty")
98+
assert.Greater(t, resp.Metadata.ProcessingTimeMs, float64(0), "Processing time should be positive")
99+
}
100+
})
101+
}
102+
}
103+
104+
func TestServer_ValidateSQL(t *testing.T) {
105+
s := server.NewInMemoryServer("", nil)
106+
107+
tests := []struct {
108+
name string
109+
request *server.ValidateSQLRequest
110+
expectValid bool
111+
expectError bool
112+
}{
113+
{
114+
name: "valid SELECT query",
115+
request: &server.ValidateSQLRequest{
116+
Sql: "SELECT * FROM users WHERE active = 1",
117+
DatabaseType: "postgresql",
118+
},
119+
expectValid: true,
120+
},
121+
{
122+
name: "valid INSERT query",
123+
request: &server.ValidateSQLRequest{
124+
Sql: "INSERT INTO users (name, email) VALUES ('John', '[email protected]')",
125+
DatabaseType: "mysql",
126+
},
127+
expectValid: true,
128+
},
129+
{
130+
name: "valid UPDATE query",
131+
request: &server.ValidateSQLRequest{
132+
Sql: "UPDATE users SET active = 0 WHERE id = 1",
133+
DatabaseType: "sqlite",
134+
},
135+
expectValid: true,
136+
},
137+
{
138+
name: "valid DELETE query",
139+
request: &server.ValidateSQLRequest{
140+
Sql: "DELETE FROM users WHERE inactive_date < NOW() - INTERVAL 1 YEAR",
141+
DatabaseType: "postgresql",
142+
},
143+
expectValid: true,
144+
},
145+
{
146+
name: "empty SQL query",
147+
request: &server.ValidateSQLRequest{
148+
Sql: "",
149+
},
150+
expectValid: false,
151+
expectError: true,
152+
},
153+
{
154+
name: "invalid SQL syntax",
155+
request: &server.ValidateSQLRequest{
156+
Sql: "INVALID QUERY SYNTAX",
157+
DatabaseType: "mysql",
158+
},
159+
expectValid: false,
160+
},
161+
{
162+
name: "complex valid query with context",
163+
request: &server.ValidateSQLRequest{
164+
Sql: "SELECT u.name, p.title FROM users u JOIN posts p ON u.id = p.user_id",
165+
DatabaseType: "postgresql",
166+
Context: map[string]string{
167+
"schema": "public",
168+
"tables": "users,posts",
169+
},
170+
},
171+
expectValid: true,
172+
},
173+
}
174+
175+
for _, tt := range tests {
176+
t.Run(tt.name, func(t *testing.T) {
177+
ctx := context.Background()
178+
resp, err := s.ValidateSQL(ctx, tt.request)
179+
180+
require.NoError(t, err, "ValidateSQL should not return gRPC error")
181+
require.NotNil(t, resp)
182+
183+
assert.Equal(t, tt.expectValid, resp.IsValid)
184+
185+
if tt.expectError {
186+
assert.NotEmpty(t, resp.Errors, "Should have validation errors")
187+
assert.Equal(t, "SQL query is required", resp.Errors[0].Message)
188+
assert.Equal(t, server.ValidationErrorType_SYNTAX_ERROR, resp.Errors[0].Type)
189+
} else if tt.expectValid {
190+
assert.Empty(t, resp.Errors, "Valid SQL should have no errors")
191+
assert.NotEmpty(t, resp.FormattedSql, "Should have formatted SQL")
192+
assert.NotNil(t, resp.Metadata, "Should have validation metadata")
193+
} else {
194+
assert.NotEmpty(t, resp.Errors, "Invalid SQL should have errors")
195+
}
196+
})
197+
}
198+
}
199+
200+
func TestServer_GetAICapabilities(t *testing.T) {
201+
s := server.NewInMemoryServer("", nil)
202+
203+
ctx := context.Background()
204+
resp, err := s.GetAICapabilities(ctx, &server.Empty{})
205+
206+
require.NoError(t, err)
207+
require.NotNil(t, resp)
208+
209+
// Test response structure
210+
assert.NotEmpty(t, resp.SupportedDatabases, "Should support multiple databases")
211+
assert.Contains(t, resp.SupportedDatabases, "mysql")
212+
assert.Contains(t, resp.SupportedDatabases, "postgresql")
213+
assert.Contains(t, resp.SupportedDatabases, "sqlite")
214+
215+
assert.NotEmpty(t, resp.Features, "Should have AI features")
216+
217+
// Check for SQL generation feature
218+
var sqlGenFeature *server.AIFeature
219+
for _, feature := range resp.Features {
220+
if feature.Name == "sql_generation" {
221+
sqlGenFeature = feature
222+
break
223+
}
224+
}
225+
require.NotNil(t, sqlGenFeature, "Should have sql_generation feature")
226+
assert.True(t, sqlGenFeature.Enabled)
227+
assert.NotEmpty(t, sqlGenFeature.Description)
228+
assert.NotEmpty(t, sqlGenFeature.Parameters)
229+
230+
// Check for SQL validation feature
231+
var sqlValFeature *server.AIFeature
232+
for _, feature := range resp.Features {
233+
if feature.Name == "sql_validation" {
234+
sqlValFeature = feature
235+
break
236+
}
237+
}
238+
require.NotNil(t, sqlValFeature, "Should have sql_validation feature")
239+
assert.True(t, sqlValFeature.Enabled)
240+
241+
assert.NotEmpty(t, resp.Version, "Should have version")
242+
assert.NotEqual(t, server.HealthStatus_HEALTH_STATUS_UNSPECIFIED, resp.Status)
243+
assert.NotEmpty(t, resp.Limits, "Should have limits")
244+
}
245+
246+
func TestServer_AIErrorHandling(t *testing.T) {
247+
s := server.NewInMemoryServer("", nil)
248+
249+
tests := []struct {
250+
name string
251+
testFunc func(context.Context) error
252+
expectedError string
253+
}{
254+
{
255+
name: "GenerateSQL with empty input",
256+
testFunc: func(ctx context.Context) error {
257+
resp, err := s.GenerateSQL(ctx, &server.GenerateSQLRequest{})
258+
if err != nil {
259+
return err
260+
}
261+
if resp.Error != nil && resp.Error.Code == server.AIErrorCode_INVALID_INPUT {
262+
return nil // Expected error
263+
}
264+
return assert.AnError // Unexpected response
265+
},
266+
},
267+
{
268+
name: "ValidateSQL with empty input",
269+
testFunc: func(ctx context.Context) error {
270+
resp, err := s.ValidateSQL(ctx, &server.ValidateSQLRequest{})
271+
if err != nil {
272+
return err
273+
}
274+
if !resp.IsValid && len(resp.Errors) > 0 {
275+
return nil // Expected validation failure
276+
}
277+
return assert.AnError // Unexpected response
278+
},
279+
},
280+
}
281+
282+
for _, tt := range tests {
283+
t.Run(tt.name, func(t *testing.T) {
284+
ctx := context.Background()
285+
err := tt.testFunc(ctx)
286+
assert.NoError(t, err, "Error handling test should pass")
287+
})
288+
}
289+
}
290+
291+
func TestServer_AIMethodsIntegration(t *testing.T) {
292+
s := server.NewInMemoryServer("", nil)
293+
ctx := context.Background()
294+
295+
// First, check AI capabilities
296+
capResp, err := s.GetAICapabilities(ctx, &server.Empty{})
297+
require.NoError(t, err)
298+
require.NotNil(t, capResp)
299+
300+
// Test SQL generation
301+
genResp, err := s.GenerateSQL(ctx, &server.GenerateSQLRequest{
302+
NaturalLanguage: "Find all users",
303+
})
304+
require.NoError(t, err)
305+
require.NotNil(t, genResp)
306+
require.Nil(t, genResp.Error, "Generation should succeed")
307+
308+
// Test SQL validation with generated SQL
309+
valResp, err := s.ValidateSQL(ctx, &server.ValidateSQLRequest{
310+
Sql: genResp.GeneratedSql,
311+
})
312+
require.NoError(t, err)
313+
require.NotNil(t, valResp)
314+
315+
// The generated SQL should be valid (contains SELECT keyword)
316+
assert.True(t, valResp.IsValid, "Generated SQL should be valid")
317+
}
318+
319+
func TestFormatSQL(t *testing.T) {
320+
tests := []struct {
321+
name string
322+
input string
323+
expected string
324+
}{
325+
{
326+
name: "simple SELECT",
327+
input: "SELECT * FROM users",
328+
expected: "SELECT *\nFROM users",
329+
},
330+
{
331+
name: "SELECT with WHERE",
332+
input: "SELECT id, name FROM users WHERE active = 1",
333+
expected: "SELECT id, name\nFROM users\nWHERE active = 1",
334+
},
335+
{
336+
name: "complex query with multiple clauses",
337+
input: "SELECT u.name, COUNT(*) FROM users u WHERE u.active = 1 GROUP BY u.name ORDER BY COUNT(*) DESC",
338+
expected: "SELECT u.name, COUNT(*)\nFROM users u\nWHERE u.active = 1\nGROUP BY u.name\nORDER BY COUNT(*) DESC",
339+
},
340+
}
341+
342+
for _, tt := range tests {
343+
t.Run(tt.name, func(t *testing.T) {
344+
// This tests the formatSQL function indirectly through ValidateSQL
345+
s := server.NewInMemoryServer("", nil)
346+
ctx := context.Background()
347+
348+
resp, err := s.ValidateSQL(ctx, &server.ValidateSQLRequest{
349+
Sql: tt.input,
350+
})
351+
352+
require.NoError(t, err)
353+
require.NotNil(t, resp)
354+
355+
if resp.IsValid {
356+
assert.Equal(t, tt.expected, resp.FormattedSql)
357+
}
358+
})
359+
}
360+
}

0 commit comments

Comments
 (0)