Skip to content

Commit d1d898c

Browse files
Added partition key support to spring data single partition queries (Azure#22483)
* Added partition key support to spring data single partition queries * Spot bug fixes * Check style fixes * Updated logic for IN clause and ignoreCase * Added logging for partition key in spring repository queries
1 parent a4090e1 commit d1d898c

File tree

8 files changed

+175
-22
lines changed

8 files changed

+175
-22
lines changed

sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/integration/AddressRepositoryIT.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,27 @@ public void testFindByPartitionedCity() {
9696
assertThat(result.get(1).getCity()).isEqualTo(city);
9797
}
9898

99+
@Test
100+
public void testFindByPartitionedCityIn() {
101+
final String city = TEST_ADDRESS1_PARTITION1.getCity();
102+
final List<Address> result = TestUtils.toList(repository.findByCityIn(Lists.newArrayList(city)));
103+
104+
assertThat(result.size()).isEqualTo(2);
105+
assertThat(result.get(0).getCity()).isEqualTo(city);
106+
assertThat(result.get(1).getCity()).isEqualTo(city);
107+
}
108+
109+
@Test
110+
public void testFindByPostalCodeAndCityIn() {
111+
final String city = TEST_ADDRESS1_PARTITION1.getCity();
112+
final List<String> postalCodes = Lists.newArrayList(TEST_ADDRESS1_PARTITION1.getPostalCode(),
113+
TEST_ADDRESS2_PARTITION1.getPostalCode());
114+
final List<Address> result = TestUtils.toList(repository.findByPostalCodeInAndCity(postalCodes, city));
115+
116+
assertThat(result.size()).isEqualTo(2);
117+
assertThat(result).isEqualTo(Lists.newArrayList(TEST_ADDRESS1_PARTITION1, TEST_ADDRESS2_PARTITION1));
118+
}
119+
99120
@Test
100121
public void testFindByStreetOrCity() {
101122
final String city = TEST_ADDRESS1_PARTITION1.getCity();

sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/integration/ProjectRepositoryIT.java

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,10 @@ public void testFindByIn() {
389389
projects = TestUtils.toList(repository.findByCreatorIn(Arrays.asList(CREATOR_0, FAKE_CREATOR)));
390390

391391
assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_4));
392+
393+
projects = TestUtils.toList(repository.findByCreatorIn(Collections.singletonList(CREATOR_1)));
394+
395+
assertProjectListEquals(projects, Collections.singletonList(PROJECT_1));
392396
}
393397

394398
@Test
@@ -411,6 +415,24 @@ public void testFindByInWithAnd() {
411415
assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_1, PROJECT_2, PROJECT_4));
412416
}
413417

418+
@Test
419+
public void testFindByInWithOr() {
420+
List<Project> projects = TestUtils.toList(repository.findByCreatorInOrStarCount(Arrays.asList(CREATOR_0,
421+
CREATOR_1), STAR_COUNT_2));
422+
423+
assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_4, PROJECT_1, PROJECT_2));
424+
425+
projects = TestUtils.toList(repository.findByCreatorInOrStarCount(Collections.singletonList(CREATOR_1),
426+
STAR_COUNT_2));
427+
428+
assertProjectListEquals(projects, Arrays.asList(PROJECT_1, PROJECT_2));
429+
430+
projects = TestUtils.toList(repository.findByCreatorInOrStarCount(Collections.singletonList(CREATOR_0),
431+
STAR_COUNT_0));
432+
433+
assertProjectListEquals(projects, Arrays.asList(PROJECT_0, PROJECT_4));
434+
}
435+
414436
@Test
415437
public void testFindByNotIn() {
416438
List<Project> projects = TestUtils.toList(repository.findByCreatorNotIn(

sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/repository/AddressRepository.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,12 @@ public interface AddressRepository extends CosmosRepository<Address, String> {
2424

2525
Iterable<Address> findByCity(String city);
2626

27+
Iterable<Address> findByCityIn(List<String> cities);
28+
2729
Iterable<Address> findByPostalCode(String postalCode);
2830

31+
Iterable<Address> findByPostalCodeInAndCity(List<String> postalCodes, String city);
32+
2933
Iterable<Address> findByStreetOrCity(String street, String city);
3034

3135
@Query("select * from a where a.city = @city")

sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/repository/ProjectRepository.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ Iterable<Project> findByNameOrCreatorAndForkCountOrStarCount(String name, String
7171

7272
Iterable<Project> findByCreatorInAndStarCountIn(Collection<String> creators, Collection<Long> starCounts);
7373

74+
Iterable<Project> findByCreatorInOrStarCount(Collection<String> creators, Long starCount);
75+
7476
Iterable<Project> findByCreatorNotIn(Collection<String> creators);
7577

7678
Iterable<Project> findByCreatorInAndStarCountNotIn(Collection<String> creators, Collection<Long> starCounts);

sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/CosmosTemplate.java

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
import java.util.Collections;
5656
import java.util.Iterator;
5757
import java.util.List;
58+
import java.util.Optional;
5859
import java.util.UUID;
5960
import java.util.stream.Collectors;
6061

@@ -632,7 +633,8 @@ public <T> Iterable<T> delete(@NonNull CosmosQuery query, @NonNull Class<T> doma
632633
Assert.notNull(domainType, "domainType should not be null.");
633634
Assert.hasText(containerName, "container should not be null, empty or only whitespaces");
634635

635-
final List<JsonNode> results = findItems(query, containerName);
636+
final List<JsonNode> results = findItemsAsFlux(query, containerName, domainType).collectList().block();
637+
assert results != null;
636638
return results.stream()
637639
.map(item -> deleteItem(item, containerName, domainType))
638640
.collect(Collectors.toList());
@@ -654,39 +656,49 @@ public <T> Page<T> runPaginationQuery(SqlQuerySpec querySpec, Pageable pageable,
654656
final String containerName = getContainerName(domainType);
655657
final SqlQuerySpec sortedQuerySpec = NativeQueryGenerator.getInstance().generateSortedQuery(querySpec, pageable.getSort());
656658
final SqlQuerySpec countQuerySpec = NativeQueryGenerator.getInstance().generateCountQuery(querySpec);
657-
return paginationQuery(sortedQuerySpec, countQuerySpec, pageable, pageable.getSort(), returnType, containerName);
659+
return paginationQuery(sortedQuerySpec, countQuerySpec, pageable,
660+
pageable.getSort(), returnType, containerName, Optional.empty());
658661
}
659662

660663
@Override
661664
public <T> Page<T> paginationQuery(CosmosQuery query, Class<T> domainType, String containerName) {
662665
final SqlQuerySpec querySpec = new FindQuerySpecGenerator().generateCosmos(query);
663666
final SqlQuerySpec countQuerySpec = new CountQueryGenerator().generateCosmos(query);
664-
return paginationQuery(querySpec, countQuerySpec, query.getPageable(), query.getSort(), domainType, containerName);
667+
Optional<Object> partitionKeyValue = query.getPartitionKeyValue(domainType);
668+
return paginationQuery(querySpec, countQuerySpec, query.getPageable(),
669+
query.getSort(), domainType, containerName, partitionKeyValue);
665670
}
666671

667672
@Override
668673
public <T> Slice<T> sliceQuery(CosmosQuery query, Class<T> domainType, String containerName) {
669674
final SqlQuerySpec querySpec = new FindQuerySpecGenerator().generateCosmos(query);
670-
return sliceQuery(querySpec, query.getPageable(), query.getSort(), domainType, containerName);
675+
Optional<Object> partitionKeyValue = query.getPartitionKeyValue(domainType);
676+
return sliceQuery(querySpec, query.getPageable(), query.getSort(), domainType, containerName, partitionKeyValue);
671677
}
672678

673679
private <T> Page<T> paginationQuery(SqlQuerySpec querySpec, SqlQuerySpec countQuerySpec,
674680
Pageable pageable, Sort sort,
675-
Class<T> returnType, String containerName) {
676-
Slice<T> response = sliceQuery(querySpec, pageable, sort, returnType, containerName);
681+
Class<T> returnType, String containerName,
682+
Optional<Object> partitionKeyValue) {
683+
Slice<T> response = sliceQuery(querySpec, pageable, sort, returnType, containerName, partitionKeyValue);
677684
final long total = getCountValue(countQuerySpec, containerName);
678685
return new CosmosPageImpl<>(response.getContent(), response.getPageable(), total);
679686
}
680687

681688
private <T> Slice<T> sliceQuery(SqlQuerySpec querySpec,
682689
Pageable pageable, Sort sort,
683-
Class<T> returnType, String containerName) {
690+
Class<T> returnType, String containerName,
691+
Optional<Object> partitionKeyValue) {
684692
Assert.isTrue(pageable.getPageSize() > 0,
685693
"pageable should have page size larger than 0");
686694
Assert.hasText(containerName, "container should not be null, empty or only whitespaces");
687695

688696
final CosmosQueryRequestOptions cosmosQueryRequestOptions = new CosmosQueryRequestOptions();
689697
cosmosQueryRequestOptions.setQueryMetricsEnabled(this.queryMetricsEnabled);
698+
partitionKeyValue.ifPresent(o -> {
699+
LOGGER.debug("Setting partition key {}", o);
700+
cosmosQueryRequestOptions.setPartitionKey(new PartitionKey(o));
701+
});
690702

691703
CosmosAsyncContainer container =
692704
cosmosAsyncClient.getDatabase(this.databaseName).getContainer(containerName);
@@ -836,11 +848,17 @@ private Flux<FeedResponse<JsonNode>> executeQuery(SqlQuerySpec sqlQuerySpec,
836848
.byPage();
837849
}
838850

839-
private Flux<JsonNode> findItemsAsFlux(@NonNull CosmosQuery query,
840-
@NonNull String containerName) {
851+
private <T> Flux<JsonNode> findItemsAsFlux(@NonNull CosmosQuery query,
852+
@NonNull String containerName,
853+
@NonNull Class<T> domainType) {
841854
final SqlQuerySpec sqlQuerySpec = new FindQuerySpecGenerator().generateCosmos(query);
842855
final CosmosQueryRequestOptions cosmosQueryRequestOptions = new CosmosQueryRequestOptions();
843856
cosmosQueryRequestOptions.setQueryMetricsEnabled(this.queryMetricsEnabled);
857+
Optional<Object> partitionKeyValue = query.getPartitionKeyValue(domainType);
858+
partitionKeyValue.ifPresent(o -> {
859+
LOGGER.debug("Setting partition key {}", o);
860+
cosmosQueryRequestOptions.setPartitionKey(new PartitionKey(o));
861+
});
844862

845863
return cosmosAsyncClient
846864
.getDatabase(this.databaseName)
@@ -879,17 +897,10 @@ private Flux<JsonNode> getJsonNodeFluxFromQuerySpec(
879897
CosmosExceptionUtils.exceptionHandler("Failed to find items", throwable));
880898
}
881899

882-
private List<JsonNode> findItems(@NonNull CosmosQuery query,
883-
@NonNull String containerName) {
884-
return findItemsAsFlux(query, containerName)
885-
.collectList()
886-
.block();
887-
}
888-
889900
private <T> Iterable<T> findItems(@NonNull CosmosQuery query,
890901
@NonNull String containerName,
891902
@NonNull Class<T> domainType) {
892-
return findItemsAsFlux(query, containerName)
903+
return findItemsAsFlux(query, containerName, domainType)
893904
.map(jsonNode -> toDomainObject(domainType, jsonNode))
894905
.toIterable();
895906
}

sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/ReactiveCosmosTemplate.java

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
import com.azure.spring.data.cosmos.exception.CosmosExceptionUtils;
2929
import com.azure.spring.data.cosmos.repository.support.CosmosEntityInformation;
3030
import com.fasterxml.jackson.databind.JsonNode;
31+
import org.slf4j.Logger;
32+
import org.slf4j.LoggerFactory;
3133
import org.springframework.beans.BeansException;
3234
import org.springframework.context.ApplicationContext;
3335
import org.springframework.context.ApplicationContextAware;
@@ -40,6 +42,7 @@
4042
import reactor.core.publisher.Mono;
4143
import reactor.core.scheduler.Schedulers;
4244

45+
import java.util.Optional;
4346
import java.util.UUID;
4447

4548
/**
@@ -48,6 +51,8 @@
4851
@SuppressWarnings("unchecked")
4952
public class ReactiveCosmosTemplate implements ReactiveCosmosOperations, ApplicationContextAware {
5053

54+
private static final Logger LOGGER = LoggerFactory.getLogger(ReactiveCosmosTemplate.class);
55+
5156
private final MappingCosmosConverter mappingCosmosConverter;
5257
private final String databaseName;
5358
private final ResponseDiagnosticsProcessor responseDiagnosticsProcessor;
@@ -529,7 +534,7 @@ public <T> Flux<T> delete(CosmosQuery query, Class<T> domainType, String contain
529534
Assert.notNull(domainType, "domainType should not be null.");
530535
Assert.hasText(containerName, "container name should not be null, empty or only whitespaces");
531536

532-
final Flux<JsonNode> results = findItems(query, containerName);
537+
final Flux<JsonNode> results = findItems(query, containerName, domainType);
533538

534539
return results.flatMap(d -> deleteItem(d, containerName, domainType));
535540
}
@@ -544,7 +549,7 @@ public <T> Flux<T> delete(CosmosQuery query, Class<T> domainType, String contain
544549
*/
545550
@Override
546551
public <T> Flux<T> find(CosmosQuery query, Class<T> domainType, String containerName) {
547-
return findItems(query, containerName)
552+
return findItems(query, containerName, domainType)
548553
.map(cosmosItemProperties -> toDomainObject(domainType, cosmosItemProperties));
549554
}
550555

@@ -710,11 +715,17 @@ private void markAuditedIfConfigured(Object object) {
710715
}
711716
}
712717

713-
private Flux<JsonNode> findItems(@NonNull CosmosQuery query,
714-
@NonNull String containerName) {
718+
private <T> Flux<JsonNode> findItems(@NonNull CosmosQuery query,
719+
@NonNull String containerName,
720+
@NonNull Class<T> domainType) {
715721
final SqlQuerySpec sqlQuerySpec = new FindQuerySpecGenerator().generateCosmos(query);
716722
final CosmosQueryRequestOptions cosmosQueryRequestOptions = new CosmosQueryRequestOptions();
717723
cosmosQueryRequestOptions.setQueryMetricsEnabled(this.queryMetricsEnabled);
724+
Optional<Object> partitionKeyValue = query.getPartitionKeyValue(domainType);
725+
partitionKeyValue.ifPresent(o -> {
726+
LOGGER.debug("Setting partition key {}", o);
727+
cosmosQueryRequestOptions.setPartitionKey(new PartitionKey(o));
728+
});
718729

719730
return cosmosAsyncClient
720731
.getDatabase(this.databaseName)

sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/core/query/CosmosQuery.java

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,15 @@
22
// Licensed under the MIT License.
33
package com.azure.spring.data.cosmos.core.query;
44

5+
import com.azure.spring.data.cosmos.repository.support.CosmosEntityInformation;
56
import org.springframework.data.domain.Pageable;
67
import org.springframework.data.domain.Sort;
8+
import org.springframework.data.repository.query.parser.Part;
79
import org.springframework.lang.NonNull;
810
import org.springframework.util.Assert;
911

12+
import java.util.Collection;
13+
import java.util.Collections;
1014
import java.util.List;
1115
import java.util.Optional;
1216

@@ -111,14 +115,37 @@ private boolean isCrossPartitionQuery(@NonNull String keyName) {
111115

112116
final Optional<Criteria> criteria = this.getSubjectCriteria(this.criteria, keyName);
113117

114-
return criteria.map(criteria1 -> criteria1.getType() != CriteriaType.IS_EQUAL).orElse(true);
118+
return criteria.map(criteria1 -> {
119+
// If there is equal criteria, then it is a single partition query
120+
if (isEqualCriteria(criteria1)) {
121+
return false;
122+
}
123+
// IN is a special case, where we want to first check if the partition key is used with IN clause
124+
if (criteria1.getType() == CriteriaType.IN && criteria1.getSubjectValues().size() == 1) {
125+
@SuppressWarnings("unchecked")
126+
Collection<Object> collection = (Collection<Object>) criteria1.getSubjectValues().get(0);
127+
// IN query types can have multiple values,
128+
// so we are checking the internal collection of the criteria
129+
return collection.size() != 1;
130+
}
131+
return !hasKeywordAnd();
132+
}).orElse(true);
115133
}
116134

117135
private boolean hasKeywordOr() {
118136
// If there is OR keyword in DocumentQuery, the top node of Criteria must be OR type.
119137
return this.criteria.getType() == CriteriaType.OR;
120138
}
121139

140+
private boolean hasKeywordAnd() {
141+
// If there is AND keyword in DocumentQuery, the top node of Criteria must be AND type.
142+
return this.criteria.getType() == CriteriaType.AND;
143+
}
144+
145+
private boolean isEqualCriteria(Criteria criteria) {
146+
return criteria.getType() == CriteriaType.IS_EQUAL;
147+
}
148+
122149
/**
123150
* Indicate if DocumentQuery should enable cross partition query.
124151
*
@@ -136,6 +163,57 @@ public boolean isCrossPartitionQuery(@NonNull List<String> partitionKeys) {
136163
.orElse(hasKeywordOr());
137164
}
138165

166+
/**
167+
* Returns true if this criteria or sub-criteria has partition key field present as one of the subjects.
168+
* @param partitionKeyFieldName partition key field name
169+
* @return returns true if this criteria or sub criteria has partition key field present as one of the subjects.
170+
*/
171+
public boolean hasPartitionKeyCriteria(@NonNull String partitionKeyFieldName) {
172+
if (partitionKeyFieldName.isEmpty()) {
173+
return false;
174+
}
175+
176+
final Optional<Criteria> criteria = this.getSubjectCriteria(this.criteria, partitionKeyFieldName);
177+
return criteria.isPresent();
178+
}
179+
180+
/**
181+
* Returns partition key value based on the criteria.
182+
* @param domainType domain type
183+
* @param <T> entity class type
184+
* @return Optional of partition key value
185+
*/
186+
public <T> Optional<Object> getPartitionKeyValue(@NonNull Class<T> domainType) {
187+
CosmosEntityInformation<?, ?> instance = CosmosEntityInformation.getInstance(domainType);
188+
String partitionKeyFieldName = instance.getPartitionKeyFieldName();
189+
if (partitionKeyFieldName == null
190+
|| partitionKeyFieldName.isEmpty()
191+
|| isCrossPartitionQuery(Collections.singletonList(partitionKeyFieldName))) {
192+
return Optional.empty();
193+
}
194+
195+
final Optional<Criteria> criteria = this.getSubjectCriteria(this.criteria, partitionKeyFieldName);
196+
return criteria.map(criteria1 -> {
197+
// If the criteria has ignoreCase, then we cannot set the partition key
198+
// because of case sensitivity of partition key
199+
if (!criteria1.getIgnoreCase().equals(Part.IgnoreCaseType.NEVER)) {
200+
return null;
201+
}
202+
if (criteria1.getType() == CriteriaType.IN && criteria1.getSubjectValues().size() == 1) {
203+
@SuppressWarnings("unchecked")
204+
Collection<Object> collection = (Collection<Object>) criteria1.getSubjectValues().get(0);
205+
// IN query types can have multiple values,
206+
// so we are checking the internal collection of the criteria
207+
if (collection.size() == 1) {
208+
return collection.iterator().next();
209+
} else {
210+
return null;
211+
}
212+
}
213+
return criteria1.getSubjectValues().get(0);
214+
});
215+
}
216+
139217
/**
140218
* To get criteria by type
141219
*

sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/repository/support/CosmosEntityInformation.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,10 @@ public Object getPartitionKeyFieldValue(T entity) {
252252
return partitionKeyField == null ? null : ReflectionUtils.getField(partitionKeyField, entity);
253253
}
254254

255+
public String getPartitionKeyFieldName() {
256+
return partitionKeyField == null ? null : partitionKeyField.getName();
257+
}
258+
255259
/**
256260
* Check if auto creating container is allowed
257261
*

0 commit comments

Comments
 (0)