Skip to content

Commit 3754d53

Browse files
EFRS-1333: Added tests
1 parent 130ff41 commit 3754d53

File tree

7 files changed

+292
-5
lines changed

7 files changed

+292
-5
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
package com.exadel.frs.core.trainservice.dto;
22

3-
public abstract class EmbeddingsProcessResponse {
3+
public interface EmbeddingsProcessResponse {
44

55
}

java/api/src/main/java/com/exadel/frs/core/trainservice/dto/EmbeddingsRecognitionProcessResponse.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
@Setter
1111
@NoArgsConstructor
1212
@AllArgsConstructor
13-
public class EmbeddingsRecognitionProcessResponse extends EmbeddingsProcessResponse {
13+
public class EmbeddingsRecognitionProcessResponse implements EmbeddingsProcessResponse {
1414

1515
private List<EmbeddingRecognitionProcessResult> results;
1616
}

java/api/src/main/java/com/exadel/frs/core/trainservice/dto/EmbeddingsVerificationProcessResponse.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
@Setter
1111
@NoArgsConstructor
1212
@AllArgsConstructor
13-
public class EmbeddingsVerificationProcessResponse extends EmbeddingsProcessResponse {
13+
public class EmbeddingsVerificationProcessResponse implements EmbeddingsProcessResponse {
1414

1515
private List<EmbeddingVerificationProcessResult> results;
1616
}

java/api/src/test/java/com/exadel/frs/core/trainservice/DbHelper.java

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,13 @@ public class DbHelper {
4444
ImgRepository imgRepository;
4545

4646
public Model insertModel() {
47-
final String apiKey = UUID.randomUUID().toString();
47+
return insertModel(ModelType.RECOGNITION);
48+
}
4849

50+
public Model insertModel(ModelType type) {
51+
var apiKey = UUID.randomUUID().toString();
4952
var app = appRepository.save(makeApp(apiKey));
50-
return modelRepository.save(makeModel(apiKey, ModelType.RECOGNITION, app));
53+
return modelRepository.save(makeModel(apiKey, type, app));
5154
}
5255

5356
public Subject insertSubject(Model model, String subjectName) {
Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package com.exadel.frs.core.trainservice.service;
2+
3+
import static com.exadel.frs.core.trainservice.system.global.Constants.PREDICTION_COUNT;
4+
import static org.assertj.core.api.Assertions.assertThat;
5+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
6+
import static org.mockito.ArgumentMatchers.any;
7+
import static org.mockito.ArgumentMatchers.anyInt;
8+
import static org.mockito.Mockito.when;
9+
import com.exadel.frs.commonservice.exception.IncorrectPredictionCountException;
10+
import com.exadel.frs.commonservice.repository.EmbeddingRepository;
11+
import com.exadel.frs.commonservice.sdk.faces.FacesApiClient;
12+
import com.exadel.frs.core.trainservice.DbHelper;
13+
import com.exadel.frs.core.trainservice.EmbeddedPostgreSQLTest;
14+
import com.exadel.frs.core.trainservice.component.FaceClassifierPredictor;
15+
import com.exadel.frs.core.trainservice.dto.ProcessEmbeddingsParams;
16+
import com.exadel.frs.core.trainservice.mapper.FacesMapper;
17+
import com.exadel.frs.core.trainservice.repository.AppRepository;
18+
import com.exadel.frs.core.trainservice.validation.ImageExtensionValidator;
19+
import java.util.Collections;
20+
import java.util.List;
21+
import org.apache.commons.lang3.tuple.Pair;
22+
import org.junit.jupiter.api.BeforeEach;
23+
import org.junit.jupiter.api.Test;
24+
import org.junit.jupiter.params.ParameterizedTest;
25+
import org.junit.jupiter.params.provider.ValueSource;
26+
import org.springframework.beans.factory.annotation.Autowired;
27+
import org.springframework.boot.test.mock.mockito.MockBean;
28+
import org.springframework.transaction.annotation.Transactional;
29+
30+
@Transactional
31+
class EmbeddingsRecognizeProcessServiceImplTest extends EmbeddedPostgreSQLTest {
32+
33+
@Autowired
34+
private DbHelper dbHelper;
35+
36+
@Autowired
37+
private ImageExtensionValidator imageExtensionValidator;
38+
39+
@Autowired
40+
private FacesMapper facesMapper;
41+
42+
@Autowired
43+
private AppRepository appRepository;
44+
45+
@Autowired
46+
private EmbeddingRepository embeddingRepository;
47+
48+
@MockBean
49+
private FacesApiClient facesApiClient;
50+
51+
@MockBean
52+
private FaceClassifierPredictor predictor;
53+
54+
@Autowired
55+
private EmbeddingsRecognizeProcessServiceImpl recognizeProcessService;
56+
57+
@BeforeEach
58+
void cleanUp() {
59+
appRepository.deleteAll();
60+
appRepository.flush();
61+
}
62+
63+
@Test
64+
void processEmbeddings_TheInputEmbeddingExistsInTheDatabase_ShouldReturnCompleteSimilarity() {
65+
var model = dbHelper.insertModel();
66+
var subject = dbHelper.insertSubject(model, "subject");
67+
var embedding = dbHelper.insertEmbeddingNoImg(subject);
68+
69+
var params = ProcessEmbeddingsParams.builder()
70+
.apiKey(model.getApiKey())
71+
.embeddings(new double[][]{embedding.getEmbedding()})
72+
.additionalParams(Collections.singletonMap(PREDICTION_COUNT, 1))
73+
.build();
74+
75+
when(predictor.predict(any(), any(), anyInt())).thenReturn(List.of(Pair.of(1.0, "subject")));
76+
assertThat(embeddingRepository.findAll()).containsOnly(embedding);
77+
78+
var results = recognizeProcessService.processEmbeddings(params).getResults();
79+
80+
assertThat(embeddingRepository.findAll()).containsOnly(embedding);
81+
assertThat(results).isNotEmpty().hasSize(1);
82+
83+
var result = results.get(0);
84+
85+
assertThat(result.getEmbedding()).isEqualTo(embedding.getEmbedding());
86+
assertThat(result.getResults()).isNotEmpty().hasSize(1);
87+
assertThat(result.getResults().get(0).getSimilarity()).isEqualTo(1.0F);
88+
assertThat(result.getResults().get(0).getSubject()).isEqualTo("subject");
89+
}
90+
91+
@Test
92+
void processEmbeddings_TheInputEmbeddingDoesNotExistInTheDatabase_ShouldNotReturnCompleteSimilarity() {
93+
var model = dbHelper.insertModel();
94+
var subject = dbHelper.insertSubject(model, "subject");
95+
var embedding = dbHelper.insertEmbeddingNoImg(subject);
96+
97+
var params = ProcessEmbeddingsParams.builder()
98+
.apiKey(model.getApiKey())
99+
.embeddings(new double[][]{new double[]{7.3, 8.4, 9.5}})
100+
.additionalParams(Collections.singletonMap(PREDICTION_COUNT, 1))
101+
.build();
102+
103+
when(predictor.predict(any(), any(), anyInt())).thenReturn(List.of(Pair.of(0.0, "subject")));
104+
assertThat(embeddingRepository.findAll()).containsOnly(embedding);
105+
106+
var results = recognizeProcessService.processEmbeddings(params).getResults();
107+
108+
assertThat(embeddingRepository.findAll()).containsOnly(embedding);
109+
assertThat(results).isNotEmpty().hasSize(1);
110+
111+
var result = results.get(0);
112+
113+
assertThat(result.getEmbedding()).isNotEqualTo(embedding.getEmbedding());
114+
assertThat(result.getResults()).isNotEmpty().hasSize(1);
115+
assertThat(result.getResults().get(0).getSimilarity()).isEqualTo(0.0F);
116+
assertThat(result.getResults().get(0).getSubject()).isEqualTo("subject");
117+
}
118+
119+
@ParameterizedTest
120+
@ValueSource(ints = {0, -2})
121+
void processEmbeddings_PredictionCountIsIncorrect_ShouldThrowIncorrectPredictionCountException(int predictionCount) {
122+
var params = ProcessEmbeddingsParams.builder()
123+
.additionalParams(Collections.singletonMap(PREDICTION_COUNT, predictionCount))
124+
.build();
125+
126+
assertThatThrownBy(() -> recognizeProcessService.processEmbeddings(params))
127+
.isInstanceOf(IncorrectPredictionCountException.class);
128+
}
129+
130+
@Test
131+
void processEmbeddings_PredictionCountIsNull_ShouldThrowIncorrectPredictionCountException() {
132+
var params = ProcessEmbeddingsParams.builder()
133+
.additionalParams(Collections.singletonMap(PREDICTION_COUNT, null))
134+
.build();
135+
136+
assertThatThrownBy(() -> recognizeProcessService.processEmbeddings(params))
137+
.isInstanceOf(IncorrectPredictionCountException.class);
138+
}
139+
}
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
package com.exadel.frs.core.trainservice.service;
2+
3+
import static org.assertj.core.api.Assertions.assertThat;
4+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
5+
import static org.mockito.Mockito.when;
6+
import com.exadel.frs.commonservice.exception.WrongEmbeddingCountException;
7+
import com.exadel.frs.commonservice.sdk.faces.FacesApiClient;
8+
import com.exadel.frs.core.trainservice.DbHelper;
9+
import com.exadel.frs.core.trainservice.EmbeddedPostgreSQLTest;
10+
import com.exadel.frs.core.trainservice.component.FaceClassifierPredictor;
11+
import com.exadel.frs.core.trainservice.dto.ProcessEmbeddingsParams;
12+
import com.exadel.frs.core.trainservice.mapper.FacesMapper;
13+
import com.exadel.frs.core.trainservice.repository.AppRepository;
14+
import com.exadel.frs.core.trainservice.validation.ImageExtensionValidator;
15+
import org.apache.commons.lang3.ArrayUtils;
16+
import org.junit.jupiter.api.BeforeEach;
17+
import org.junit.jupiter.api.Test;
18+
import org.springframework.beans.factory.annotation.Autowired;
19+
import org.springframework.boot.test.mock.mockito.MockBean;
20+
import org.springframework.transaction.annotation.Transactional;
21+
22+
@Transactional
23+
class EmbeddingsVerificationProcessServiceImplTest extends EmbeddedPostgreSQLTest {
24+
25+
@Autowired
26+
private DbHelper dbHelper;
27+
28+
@Autowired
29+
private ImageExtensionValidator imageExtensionValidator;
30+
31+
@Autowired
32+
private FacesMapper facesMapper;
33+
34+
@Autowired
35+
private AppRepository appRepository;
36+
37+
@MockBean
38+
private FaceClassifierPredictor predictor;
39+
40+
@MockBean
41+
private FacesApiClient facesApiClient;
42+
43+
@Autowired
44+
private EmbeddingsVerificationProcessServiceImpl verificationProcessService;
45+
46+
@BeforeEach
47+
void cleanUp() {
48+
appRepository.deleteAll();
49+
appRepository.flush();
50+
}
51+
52+
@Test
53+
void processEmbeddings_ThereAreTwoEmbeddingsInTheDatabase_ShouldReturnTwoSimilarityResultInSortedOrder() {
54+
var source = new double[]{1.0, 2.0, 3.0};
55+
var targets = new double[][]{
56+
new double[]{4.0, 5.0, 6.0},
57+
new double[]{7.0, 8.0, 9.0}
58+
};
59+
var similarities = new double[]{0.3, 0.5};
60+
var params = buildParams(source, targets);
61+
62+
when(predictor.verify(source, targets)).thenReturn(similarities);
63+
64+
var results = verificationProcessService.processEmbeddings(params).getResults();
65+
66+
assertThat(results).isNotEmpty().hasSize(2);
67+
68+
var result1 = results.get(0);
69+
var result2 = results.get(1);
70+
71+
assertThat(result1.getSimilarity()).isEqualTo(0.5F);
72+
assertThat(result2.getSimilarity()).isEqualTo(0.3F);
73+
assertThat(result1.getEmbedding()).isEqualTo(targets[1]);
74+
assertThat(result2.getEmbedding()).isEqualTo(targets[0]);
75+
}
76+
77+
@Test
78+
void processEmbeddings_TooFewTargets_ShouldThrowWrongEmbeddingCountException() {
79+
var source = new double[]{1.0, 2.0, 3.0};
80+
var targets = new double[0][];
81+
var params = buildParams(source, targets);
82+
83+
assertThatThrownBy(() -> verificationProcessService.processEmbeddings(params))
84+
.isInstanceOf(WrongEmbeddingCountException.class);
85+
}
86+
87+
@Test
88+
void processEmbeddings_EmbeddingsAreNull_ShouldThrowWrongEmbeddingCountException() {
89+
var params = ProcessEmbeddingsParams.builder().build();
90+
91+
assertThatThrownBy(() -> verificationProcessService.processEmbeddings(params))
92+
.isInstanceOf(WrongEmbeddingCountException.class);
93+
}
94+
95+
private ProcessEmbeddingsParams buildParams(double[] source, double[][] targets) {
96+
return ProcessEmbeddingsParams.builder()
97+
.embeddings(ArrayUtils.insert(0, targets, source))
98+
.build();
99+
}
100+
}

java/api/src/test/java/com/exadel/frs/core/trainservice/service/SubjectServiceTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import static com.exadel.frs.core.trainservice.system.global.Constants.IMAGE_ID;
2222
import static org.assertj.core.api.Assertions.assertThat;
2323
import static org.assertj.core.api.Assertions.assertThatThrownBy;
24+
import static org.junit.jupiter.api.Assertions.assertEquals;
2425
import static org.junit.jupiter.api.Assertions.assertThrows;
2526
import static org.mockito.ArgumentMatchers.any;
2627
import static org.mockito.Mockito.eq;
@@ -34,6 +35,7 @@
3435
import com.exadel.frs.commonservice.entity.Subject;
3536
import com.exadel.frs.commonservice.exception.IncorrectImageIdException;
3637
import com.exadel.frs.commonservice.exception.TooManyFacesException;
38+
import com.exadel.frs.commonservice.exception.WrongEmbeddingCountException;
3739
import com.exadel.frs.commonservice.sdk.faces.FacesApiClient;
3840
import com.exadel.frs.commonservice.sdk.faces.feign.dto.FacesBox;
3941
import com.exadel.frs.commonservice.sdk.faces.feign.dto.FindFacesResponse;
@@ -44,6 +46,8 @@
4446
import com.exadel.frs.core.trainservice.component.FaceClassifierPredictor;
4547
import com.exadel.frs.core.trainservice.component.classifiers.EuclideanDistanceClassifier;
4648
import com.exadel.frs.core.trainservice.dao.SubjectDao;
49+
import com.exadel.frs.core.trainservice.dto.EmbeddingVerificationProcessResult;
50+
import com.exadel.frs.core.trainservice.dto.ProcessEmbeddingsParams;
4751
import com.exadel.frs.core.trainservice.dto.ProcessImageParams;
4852
import java.io.IOException;
4953
import java.util.Map;
@@ -249,6 +253,47 @@ void testVerifyFaces(boolean status) {
249253
}
250254
}
251255

256+
@Test
257+
void verifyEmbedding_ThereAreTwoTargetsAndOneSourceInTheDatabase_ShouldReturnTwoSimilarityResultsInSortedOrder() {
258+
var targets = new double[][]{
259+
new double[]{1, 2, 3},
260+
new double[]{4, 5, 6}
261+
};
262+
var sourceId = UUID.randomUUID();
263+
var apiKey = UUID.randomUUID().toString();
264+
265+
var params = ProcessEmbeddingsParams.builder()
266+
.apiKey(apiKey)
267+
.embeddings(targets)
268+
.additionalParams(Map.of(IMAGE_ID, sourceId))
269+
.build();
270+
271+
when(classifierPredictor.verify(apiKey, targets[0], sourceId)).thenReturn(0.5);
272+
when(classifierPredictor.verify(apiKey, targets[1], sourceId)).thenReturn(1.0);
273+
274+
var results = subjectService.verifyEmbedding(params).getResults();
275+
276+
assertThat(results).isNotEmpty().hasSize(2);
277+
278+
var result1 = results.get(0);
279+
var result2 = results.get(1);
280+
281+
assertThat(result1.getSimilarity()).isEqualTo(1.0F);
282+
assertThat(result2.getSimilarity()).isEqualTo(0.5F);
283+
assertThat(result1.getEmbedding()).isEqualTo(targets[1]);
284+
assertThat(result2.getEmbedding()).isEqualTo(targets[0]);
285+
}
286+
287+
@Test
288+
void verifyEmbedding_ThereAreNoTargets_ShouldThrowWrongEmbeddingCountException() {
289+
var params = ProcessEmbeddingsParams.builder()
290+
.embeddings(new double[][]{})
291+
.build();
292+
293+
assertThatThrownBy(() -> subjectService.verifyEmbedding(params))
294+
.isInstanceOf(WrongEmbeddingCountException.class);
295+
}
296+
252297
@ParameterizedTest
253298
@ValueSource(booleans = {true, false})
254299
void testInvalidImageIdException(boolean status){

0 commit comments

Comments
 (0)