Skip to content

Commit dce2c5b

Browse files
Query assessment integration in app code collector (#1162)
* minor changes * base changes for query assessment - app code collector * minor changes * minor fixes * minor change * addressing comments * minor change
1 parent 45e64f7 commit dce2c5b

14 files changed

+754
-401
lines changed

assessment/assessment_engine.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ func performAppAssessment(ctx context.Context, collectors assessmentCollectors)
230230
}
231231

232232
logger.Log.Info("starting app assessment...")
233-
codeAssessment, err := collectors.appAssessmentCollector.AnalyzeProject(ctx)
233+
codeAssessment, _, err := collectors.appAssessmentCollector.AnalyzeProject(ctx)
234234

235235
if err != nil {
236236
logger.Log.Error("error analyzing project", zap.Error(err))

assessment/collectors/app_code_collector.go

Lines changed: 101 additions & 53 deletions
Large diffs are not rendered by default.

assessment/collectors/app_code_collector_test.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,10 @@ func TestExtractPublicMethodSignatures(t *testing.T) {
147147
assert.Error(t, err)
148148

149149
missingKeyJSON := `{"other_key": "value"}`
150-
_, err = summarizer.extractPublicMethodSignatures(missingKeyJSON)
151-
assert.Error(t, err)
150+
signatures, err = summarizer.extractPublicMethodSignatures(missingKeyJSON)
151+
assert.NoError(t, err)
152+
assert.NotNil(t, signatures)
153+
assert.Len(t, signatures, 0)
152154
}
153155

154156
func TestFormatQuestionsAndSearchResults(t *testing.T) {
@@ -157,13 +159,17 @@ func TestFormatQuestionsAndSearchResults(t *testing.T) {
157159
{"Use Connection A.", "Use Connection B."},
158160
{"Use Write-Op C."},
159161
}
162+
querySearchResults := [][]string{
163+
{},
164+
{},
165+
}
160166

161-
formatted := formatQuestionsAndSearchResults(questions, searchResults)
167+
formatted := formatQuestionsAndSearchResults(questions, searchResults, querySearchResults)
162168
assert.Contains(t, formatted, "* **Question 1:** How to connect?")
163-
assert.Contains(t, formatted, "* **Potential Solution 1:** Use Connection A.")
164-
assert.Contains(t, formatted, "* **Potential Solution 2:** Use Connection B.")
169+
assert.Contains(t, formatted, "* **Potential Code Solution 1:** Use Connection A.")
170+
assert.Contains(t, formatted, "* **Potential Code Solution 2:** Use Connection B.")
165171
assert.Contains(t, formatted, "* **Question 2:** How to write?")
166-
assert.Contains(t, formatted, "* **Potential Solution 1:** Use Write-Op C.")
172+
assert.Contains(t, formatted, "* **Potential Code Solution 1:** Use Write-Op C.")
167173
}
168174

169175
func TestAnalyzeFileDependencies(t *testing.T) {

assessment/collectors/embeddings/generate_embedding.go

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,9 @@ var javaMysqlMigrationConcept []byte
3838
//go:embed vertx_concept_examples.json
3939
var vertxMysqlMigrationConcept []byte
4040

41+
//go:embed mysql_query_examples.json
42+
var mysqlQueryExamples []byte
43+
4144
//go:embed hibernate_concept_examples.json
4245
var hibernateMysqlMigrationConcept []byte
4346

@@ -60,21 +63,7 @@ type PredictionClientInterface interface {
6063
Close() error
6164
}
6265

63-
func createEmbededTextsFromFile(project, location, sourceTargetFramework string) ([]MySqlMigrationConcept, error) {
64-
ctx := context.Background()
65-
apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
66-
model := "text-embedding-preview-0815"
67-
68-
client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
69-
if err != nil {
70-
return nil, err
71-
}
72-
defer client.Close()
73-
74-
return createEmbededTextsWithClient(ctx, client, project, location, model, sourceTargetFramework)
75-
}
76-
77-
func createEmbededTextsWithClient(ctx context.Context, client PredictionClientInterface, project, location, model, sourceTargetFramework string) ([]MySqlMigrationConcept, error) {
66+
func createCodeSampleEmbeddings(ctx context.Context, client PredictionClientInterface, project, location, model, sourceTargetFramework string) ([]MySqlMigrationConcept, error) {
7867
var data []byte
7968
switch sourceTargetFramework {
8069
case "go-sql-driver/mysql_go-sql-spanner":
@@ -93,7 +82,18 @@ func createEmbededTextsWithClient(ctx context.Context, client PredictionClientIn
9382
if err := json.Unmarshal(data, &concepts); err != nil {
9483
return nil, err
9584
}
85+
return attachEmbeddings(ctx, client, project, location, model, concepts)
86+
}
9687

88+
func createQuerySampleEmbeddings(ctx context.Context, client PredictionClientInterface, project, location, model string) ([]MySqlMigrationConcept, error) {
89+
var queryExamples []MySqlMigrationConcept
90+
if err := json.Unmarshal(mysqlQueryExamples, &queryExamples); err != nil {
91+
return nil, fmt.Errorf("failed to parse MySQL query examples JSON: %w", err)
92+
}
93+
return attachEmbeddings(ctx, client, project, location, model, queryExamples)
94+
}
95+
96+
func attachEmbeddings(ctx context.Context, client PredictionClientInterface, project, location, model string, concepts []MySqlMigrationConcept) ([]MySqlMigrationConcept, error) {
9797
var instances []*structpb.Value
9898
for _, c := range concepts {
9999
instances = append(instances, structpb.NewStructValue(&structpb.Struct{
@@ -116,11 +116,9 @@ func createEmbededTextsWithClient(ctx context.Context, client PredictionClientIn
116116

117117
for i, prediction := range resp.Predictions {
118118
values := prediction.GetStructValue().GetFields()["embeddings"].GetStructValue().GetFields()["values"].GetListValue().GetValues()
119-
120119
if values == nil {
121120
continue
122121
}
123-
124122
embedding := make([]float32, len(values))
125123
for j, v := range values {
126124
if v == nil {
@@ -132,3 +130,16 @@ func createEmbededTextsWithClient(ctx context.Context, client PredictionClientIn
132130
}
133131
return concepts, nil
134132
}
133+
134+
// Helper to create a new Vertex AI Prediction client and return context, client, and model
135+
func newAIPredictionClient(location string) (context.Context, PredictionClientInterface, string, error) {
136+
ctx := context.Background()
137+
apiEndpoint := fmt.Sprintf("%s-aiplatform.googleapis.com:443", location)
138+
model := "text-embedding-preview-0815"
139+
140+
client, err := aiplatform.NewPredictionClient(ctx, option.WithEndpoint(apiEndpoint))
141+
if err != nil {
142+
return nil, nil, "", err
143+
}
144+
return ctx, client, model, nil
145+
}

assessment/collectors/embeddings/generate_embedding_test.go

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ func (f *fakeClient) Close() error {
7272
return nil
7373
}
7474

75-
func TestCreateEmbededTextsWithClient(t *testing.T) {
75+
func TestCreateCodeSampleEmbeddings(t *testing.T) {
7676
ctx := context.Background()
7777

7878
goMysqlMigrationConcept = []byte(`[
@@ -87,14 +87,14 @@ func TestCreateEmbededTextsWithClient(t *testing.T) {
8787
]`)
8888

8989
client := &fakeClient{}
90-
concepts, err := createEmbededTextsWithClient(ctx, client, "test-proj", "us-central1", "mock-model", "go-sql-driver/mysql_go-sql-spanner")
90+
concepts, err := createCodeSampleEmbeddings(ctx, client, "test-proj", "us-central1", "mock-model", "go-sql-driver/mysql_go-sql-spanner")
9191

9292
assert.NoError(t, err)
9393
assert.Len(t, concepts, 1)
9494
assert.InDeltaSlice(t, []float32{0.1, 0.2, 0.3}, concepts[0].Embedding, 0.001)
9595
}
9696

97-
func TestCreateEmbededTextsWithClientJava(t *testing.T) {
97+
func TestCreateCodeSampleEmbeddingsJava(t *testing.T) {
9898
ctx := context.Background()
9999

100100
javaMysqlMigrationConcept = []byte(`[
@@ -109,41 +109,41 @@ func TestCreateEmbededTextsWithClientJava(t *testing.T) {
109109
]`)
110110

111111
client := &fakeClient{}
112-
concepts, err := createEmbededTextsWithClient(ctx, client, "test-proj", "us-central1", "mock-model", "jdbc_jdbc")
112+
concepts, err := createCodeSampleEmbeddings(ctx, client, "test-proj", "us-central1", "mock-model", "jdbc_jdbc")
113113

114114
assert.NoError(t, err)
115115
assert.Len(t, concepts, 1)
116116
assert.InDeltaSlice(t, []float32{0.1, 0.2, 0.3}, concepts[0].Embedding, 0.001)
117117
}
118118

119-
func TestCreateEmbededTextsWithClient_UnsupportedLanguage(t *testing.T) {
119+
func TestCreateCodeSampleEmbeddings_UnsupportedLanguage(t *testing.T) {
120120
ctx := context.Background()
121121
client := &fakeClient{}
122122

123-
concepts, err := createEmbededTextsWithClient(ctx, client, "test-proj", "us-central1", "mock-model", "python")
123+
concepts, err := createCodeSampleEmbeddings(ctx, client, "test-proj", "us-central1", "mock-model", "python")
124124

125125
assert.Nil(t, concepts)
126126
assert.Error(t, err)
127127
assert.Contains(t, err.Error(), "unsupported sourceTargetFramework")
128128
}
129-
func TestCreateEmbededTextsWithClient_PredictError(t *testing.T) {
129+
func TestCreateCodeSampleEmbeddings_PredictError(t *testing.T) {
130130
ctx := context.Background()
131131
client := &fakeClient{predictErr: errors.New("predict failure")}
132132

133-
_, err := createEmbededTextsWithClient(ctx, client, "test-proj", "us-central1", "mock-model", "go-sql-driver/mysql_go-sql-spanner")
133+
_, err := createCodeSampleEmbeddings(ctx, client, "test-proj", "us-central1", "mock-model", "go-sql-driver/mysql_go-sql-spanner")
134134
assert.Error(t, err)
135135
assert.Contains(t, err.Error(), "predict failure")
136136
}
137137

138-
func TestCreateEmbededTextsWithClient_InvalidJSON(t *testing.T) {
138+
func TestCreateCodeSampleEmbeddings_InvalidJSON(t *testing.T) {
139139
ctx := context.Background()
140140
// Temporarily assign invalid JSON
141141
oldGoConcept := goMysqlMigrationConcept
142142
goMysqlMigrationConcept = []byte("invalid json")
143143
defer func() { goMysqlMigrationConcept = oldGoConcept }()
144144

145145
client := &fakeClient{}
146-
_, err := createEmbededTextsWithClient(ctx, client, "test-proj", "us-central1", "mock-model", "go-sql-driver/mysql_go-sql-spanner")
146+
_, err := createCodeSampleEmbeddings(ctx, client, "test-proj", "us-central1", "mock-model", "go-sql-driver/mysql_go-sql-spanner")
147147
assert.Error(t, err)
148148
assert.Contains(t, err.Error(), "invalid character")
149149
}
@@ -154,3 +154,25 @@ func TestFakeClient_CloseCalled(t *testing.T) {
154154
assert.NoError(t, err)
155155
assert.True(t, client.closeCalled)
156156
}
157+
158+
func TestCreateQueryExampleEmbeddingsWithClient(t *testing.T) {
159+
oldMysqlQueryExamples := mysqlQueryExamples
160+
defer func() { mysqlQueryExamples = oldMysqlQueryExamples }()
161+
mysqlQueryExamples = []byte(`[
162+
{
163+
"id": "1",
164+
"example": "SELECT * FROM employees",
165+
"rewrite": {
166+
"theory": "simple select",
167+
"options": [{"mysql_code": "SELECT * FROM employees", "spanner_code": "SELECT * FROM employees"}]
168+
}
169+
}
170+
]`)
171+
ctx := context.Background()
172+
client := &fakeClient{}
173+
concepts, err := createQuerySampleEmbeddings(ctx, client, "test-proj", "us-central1", "mock-model")
174+
assert.NoError(t, err)
175+
assert.Len(t, concepts, 1)
176+
assert.Equal(t, "1", concepts[0].ID)
177+
assert.InDeltaSlice(t, []float32{0.1, 0.2, 0.3}, concepts[0].Embedding, 0.001)
178+
}

0 commit comments

Comments
 (0)