Skip to content

Commit c0d9bde

Browse files
authored
Enhancing CosmosTemplate to Support Multi-Tenancy at a DB Level (Azure#32516)
* Proof of concept that we can write to two databases from the same session. * Improving the changes to CosmosTemplate and the test case. * Moving default setNameAndCreateDatabase() logic into CosmosTemplate. * Improving unit test. * Changing function name to be a more accurate description of the functionality. * Updating changelog * Removing unused imports. * Code cleanup. * Refactoring CosmosTemplate to now store the CosmosFactory on the template. With this updated CosmosFactory so that it can be extended to achieve Multi-Tenancy at the database level. The test case was updated also. * Updating changelog. * Making the requested updates in the PR. Adding CosmosFactory to ReactiveCosmosTemplate and adding sample to ReadMe. * Making updates for PR comments. * Fixing updates to unit test. * Fixing readme * Adding file needed for readme. * Fixing snippet for readme. * Fixing snippet for readme. * Updating readme. * Adding javadoc. * Fixing unit test. * Testing. * Testing breaking out setup to be before unit test runs. * Renaming file. * Adding new test config for MultiTenantDB test. * Adding cleanup to unit test.
1 parent b04b9ee commit c0d9bde

File tree

10 files changed

+405
-64
lines changed

10 files changed

+405
-64
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.azure.spring.data.cosmos.core;
5+
6+
import com.azure.cosmos.CosmosAsyncClient;
7+
import com.azure.spring.data.cosmos.CosmosFactory;
8+
9+
/**
10+
* Example for extending CosmosFactory for Mutli-Tenancy at the database level
11+
*/
12+
public class MultiTenantDBCosmosFactory extends CosmosFactory {
13+
14+
public String manuallySetDatabaseName;
15+
16+
/**
17+
* Validate config and initialization
18+
*
19+
* @param cosmosAsyncClient cosmosAsyncClient
20+
* @param databaseName databaseName
21+
*/
22+
public MultiTenantDBCosmosFactory(CosmosAsyncClient cosmosAsyncClient, String databaseName) {
23+
super(cosmosAsyncClient, databaseName);
24+
25+
this.manuallySetDatabaseName = databaseName;
26+
}
27+
28+
@Override
29+
public String getDatabaseName() {
30+
return this.manuallySetDatabaseName;
31+
}
32+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
package com.azure.spring.data.cosmos.core;
5+
6+
import com.azure.cosmos.CosmosAsyncClient;
7+
import com.azure.cosmos.CosmosAsyncDatabase;
8+
import com.azure.cosmos.CosmosClientBuilder;
9+
import com.azure.cosmos.CosmosException;
10+
import com.azure.cosmos.models.PartitionKey;
11+
import com.azure.spring.data.cosmos.CosmosFactory;
12+
import com.azure.spring.data.cosmos.IntegrationTestCollectionManager;
13+
import com.azure.spring.data.cosmos.config.CosmosConfig;
14+
import com.azure.spring.data.cosmos.core.convert.MappingCosmosConverter;
15+
import com.azure.spring.data.cosmos.core.mapping.CosmosMappingContext;
16+
import com.azure.spring.data.cosmos.domain.Person;
17+
import com.azure.spring.data.cosmos.repository.MultiTenantTestRepositoryConfig;
18+
import com.azure.spring.data.cosmos.repository.support.CosmosEntityInformation;
19+
import org.junit.Assert;
20+
import org.junit.Before;
21+
import org.junit.ClassRule;
22+
import org.junit.Test;
23+
import org.junit.runner.RunWith;
24+
import org.springframework.beans.factory.annotation.Autowired;
25+
import org.springframework.boot.autoconfigure.domain.EntityScanner;
26+
import org.springframework.context.ApplicationContext;
27+
import org.springframework.data.annotation.Persistent;
28+
import org.springframework.test.context.ContextConfiguration;
29+
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
30+
31+
import java.util.ArrayList;
32+
import java.util.List;
33+
34+
import static com.azure.spring.data.cosmos.common.TestConstants.ADDRESSES;
35+
import static com.azure.spring.data.cosmos.common.TestConstants.AGE;
36+
import static com.azure.spring.data.cosmos.common.TestConstants.FIRST_NAME;
37+
import static com.azure.spring.data.cosmos.common.TestConstants.HOBBIES;
38+
import static com.azure.spring.data.cosmos.common.TestConstants.ID_1;
39+
import static com.azure.spring.data.cosmos.common.TestConstants.ID_2;
40+
import static com.azure.spring.data.cosmos.common.TestConstants.LAST_NAME;
41+
import static com.azure.spring.data.cosmos.common.TestConstants.PASSPORT_IDS_BY_COUNTRY;
42+
import static org.assertj.core.api.Assertions.assertThat;
43+
import static org.junit.Assert.assertEquals;
44+
45+
@RunWith(SpringJUnit4ClassRunner.class)
46+
@ContextConfiguration(classes = MultiTenantTestRepositoryConfig.class)
47+
public class MultiTenantDBCosmosFactoryIT {
48+
49+
private final String testDB1 = "Database1";
50+
private final String testDB2 = "Database2";
51+
52+
private final Person TEST_PERSON_1 = new Person(ID_1, FIRST_NAME, LAST_NAME, HOBBIES, ADDRESSES, AGE, PASSPORT_IDS_BY_COUNTRY);
53+
private final Person TEST_PERSON_2 = new Person(ID_2, FIRST_NAME, LAST_NAME, HOBBIES, ADDRESSES, AGE, PASSPORT_IDS_BY_COUNTRY);
54+
55+
@ClassRule
56+
public static final IntegrationTestCollectionManager collectionManager = new IntegrationTestCollectionManager();
57+
58+
@Autowired
59+
private ApplicationContext applicationContext;
60+
@Autowired
61+
private CosmosConfig cosmosConfig;
62+
@Autowired
63+
private CosmosClientBuilder cosmosClientBuilder;
64+
65+
private MultiTenantDBCosmosFactory cosmosFactory;
66+
private CosmosTemplate cosmosTemplate;
67+
private CosmosAsyncClient client;
68+
private CosmosEntityInformation<Person, String> personInfo;
69+
70+
@Before
71+
public void setUp() throws ClassNotFoundException {
72+
/// Setup
73+
client = CosmosFactory.createCosmosAsyncClient(cosmosClientBuilder);
74+
cosmosFactory = new MultiTenantDBCosmosFactory(client, testDB1);
75+
final CosmosMappingContext mappingContext = new CosmosMappingContext();
76+
77+
try {
78+
mappingContext.setInitialEntitySet(new EntityScanner(this.applicationContext).scan(Persistent.class));
79+
} catch (Exception e) {
80+
Assert.fail();
81+
}
82+
83+
final MappingCosmosConverter cosmosConverter = new MappingCosmosConverter(mappingContext, null);
84+
cosmosTemplate = new CosmosTemplate(cosmosFactory, cosmosConfig, cosmosConverter, null);
85+
personInfo = new CosmosEntityInformation<>(Person.class);
86+
}
87+
88+
@Test
89+
public void testGetDatabaseFunctionality() {
90+
// Create DB1 and add TEST_PERSON_1 to it
91+
cosmosTemplate.createContainerIfNotExists(personInfo);
92+
cosmosTemplate.deleteAll(personInfo.getContainerName(), Person.class);
93+
assertThat(cosmosFactory.getDatabaseName()).isEqualTo(testDB1);
94+
cosmosTemplate.insert(TEST_PERSON_1, new PartitionKey(personInfo.getPartitionKeyFieldValue(TEST_PERSON_1)));
95+
96+
// Create DB2 and add TEST_PERSON_2 to it
97+
cosmosFactory.manuallySetDatabaseName = testDB2;
98+
cosmosTemplate.createContainerIfNotExists(personInfo);
99+
cosmosTemplate.deleteAll(personInfo.getContainerName(), Person.class);
100+
assertThat(cosmosFactory.getDatabaseName()).isEqualTo(testDB2);
101+
cosmosTemplate.insert(TEST_PERSON_2, new PartitionKey(personInfo.getPartitionKeyFieldValue(TEST_PERSON_2)));
102+
103+
// Check that DB2 has the correct contents
104+
List<Person> expectedResultsDB2 = new ArrayList<>();
105+
expectedResultsDB2.add(TEST_PERSON_2);
106+
Iterable<Person> iterableDB2 = cosmosTemplate.findAll(personInfo.getContainerName(), Person.class);
107+
List<Person> resultDB2 = new ArrayList<>();
108+
iterableDB2.forEach(resultDB2::add);
109+
Assert.assertEquals(expectedResultsDB2, resultDB2);
110+
111+
// Check that DB1 has the correct contents
112+
cosmosFactory.manuallySetDatabaseName = testDB1;
113+
List<Person> expectedResultsDB1 = new ArrayList<>();
114+
expectedResultsDB1.add(TEST_PERSON_1);
115+
Iterable<Person> iterableDB1 = cosmosTemplate.findAll(personInfo.getContainerName(), Person.class);
116+
List<Person> resultDB1 = new ArrayList<>();
117+
iterableDB1.forEach(resultDB1::add);
118+
Assert.assertEquals(expectedResultsDB1, resultDB1);
119+
120+
//Cleanup
121+
deleteDatabaseIfExists(testDB1);
122+
deleteDatabaseIfExists(testDB2);
123+
}
124+
125+
private void deleteDatabaseIfExists(String dbName) {
126+
CosmosAsyncDatabase database = client.getDatabase(dbName);
127+
try {
128+
database.delete().block();
129+
} catch (CosmosException ex) {
130+
assertEquals(ex.getStatusCode(), 404);
131+
}
132+
}
133+
}

sdk/cosmos/azure-spring-data-cosmos-test/src/test/java/com/azure/spring/data/cosmos/repository/MultiCosmosTemplateIT.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import com.azure.cosmos.CosmosAsyncClient;
66
import com.azure.cosmos.models.PartitionKey;
7+
import com.azure.spring.data.cosmos.CosmosFactory;
78
import com.azure.spring.data.cosmos.ReactiveIntegrationTestCollectionManager;
89
import com.azure.spring.data.cosmos.common.TestConstants;
910
import com.azure.spring.data.cosmos.core.ReactiveCosmosTemplate;
@@ -80,10 +81,12 @@ public void testSecondaryTemplateWithDiffDatabase() {
8081

8182
@Test
8283
public void testSingleCosmosClientForMultipleCosmosTemplate() throws IllegalAccessException {
83-
final Field cosmosAsyncClient = FieldUtils.getDeclaredField(ReactiveCosmosTemplate.class,
84-
"cosmosAsyncClient", true);
85-
CosmosAsyncClient client1 = (CosmosAsyncClient) cosmosAsyncClient.get(secondaryReactiveCosmosTemplate);
86-
CosmosAsyncClient client2 = (CosmosAsyncClient) cosmosAsyncClient.get(secondaryDiffDatabaseReactiveCosmosTemplate);
84+
final Field cosmosFactory = FieldUtils.getDeclaredField(ReactiveCosmosTemplate.class,
85+
"cosmosFactory", true);
86+
CosmosFactory cosmosFactory1 = (CosmosFactory) cosmosFactory.get(secondaryReactiveCosmosTemplate);
87+
CosmosAsyncClient client1 = cosmosFactory1.getCosmosAsyncClient();
88+
CosmosFactory cosmosFactory2 = (CosmosFactory) cosmosFactory.get(secondaryDiffDatabaseReactiveCosmosTemplate);
89+
CosmosAsyncClient client2 = cosmosFactory2.getCosmosAsyncClient();
8790
Assertions.assertThat(client1).isEqualTo(client2);
8891
}
8992
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
package com.azure.spring.data.cosmos.repository;
4+
5+
import com.azure.cosmos.CosmosAsyncClient;
6+
import com.azure.cosmos.CosmosClientBuilder;
7+
import com.azure.spring.data.cosmos.common.ResponseDiagnosticsTestUtils;
8+
import com.azure.spring.data.cosmos.common.TestConstants;
9+
import com.azure.spring.data.cosmos.config.AbstractCosmosConfiguration;
10+
import com.azure.spring.data.cosmos.config.CosmosConfig;
11+
import com.azure.spring.data.cosmos.core.MultiTenantDBCosmosFactory;
12+
import com.azure.spring.data.cosmos.core.mapping.event.SimpleCosmosMappingEventListener;
13+
import com.azure.spring.data.cosmos.repository.config.EnableCosmosRepositories;
14+
import com.azure.spring.data.cosmos.repository.config.EnableReactiveCosmosRepositories;
15+
import org.springframework.beans.factory.annotation.Value;
16+
import org.springframework.context.annotation.Bean;
17+
import org.springframework.context.annotation.Configuration;
18+
import org.springframework.context.annotation.PropertySource;
19+
import org.springframework.util.StringUtils;
20+
21+
import java.util.Arrays;
22+
import java.util.Collection;
23+
24+
@Configuration
25+
@PropertySource(value = { "classpath:application.properties" })
26+
@EnableCosmosRepositories
27+
@EnableReactiveCosmosRepositories
28+
public class MultiTenantTestRepositoryConfig extends AbstractCosmosConfiguration {
29+
@Value("${cosmos.uri:}")
30+
private String cosmosDbUri;
31+
32+
@Value("${cosmos.key:}")
33+
private String cosmosDbKey;
34+
35+
@Value("${cosmos.database:}")
36+
private String database;
37+
38+
@Value("${cosmos.queryMetricsEnabled}")
39+
private boolean queryMetricsEnabled;
40+
41+
@Value("${cosmos.maxDegreeOfParallelism}")
42+
private int maxDegreeOfParallelism;
43+
44+
@Value("${cosmos.maxBufferedItemCount}")
45+
private int maxBufferedItemCount;
46+
47+
@Value("${cosmos.responseContinuationTokenLimitInKb}")
48+
private int responseContinuationTokenLimitInKb;
49+
50+
@Bean
51+
public ResponseDiagnosticsTestUtils responseDiagnosticsTestUtils() {
52+
return new ResponseDiagnosticsTestUtils();
53+
}
54+
55+
@Bean
56+
public CosmosClientBuilder cosmosClientBuilder() {
57+
return new CosmosClientBuilder()
58+
.key(cosmosDbKey)
59+
.endpoint(cosmosDbUri)
60+
.contentResponseOnWriteEnabled(true);
61+
}
62+
63+
@Bean
64+
public MultiTenantDBCosmosFactory cosmosFactory(CosmosAsyncClient cosmosAsyncClient) {
65+
return new MultiTenantDBCosmosFactory(cosmosAsyncClient, getDatabaseName());
66+
}
67+
68+
@Bean
69+
@Override
70+
public CosmosConfig cosmosConfig() {
71+
return CosmosConfig.builder()
72+
.enableQueryMetrics(queryMetricsEnabled)
73+
.maxDegreeOfParallelism(maxDegreeOfParallelism)
74+
.maxBufferedItemCount(maxBufferedItemCount)
75+
.responseContinuationTokenLimitInKb(responseContinuationTokenLimitInKb)
76+
.responseDiagnosticsProcessor(responseDiagnosticsTestUtils().getResponseDiagnosticsProcessor())
77+
.build();
78+
}
79+
80+
@Override
81+
protected String getDatabaseName() {
82+
return StringUtils.hasText(this.database) ? this.database : TestConstants.DB_NAME;
83+
}
84+
85+
@Override
86+
protected Collection<String> getMappingBasePackages() {
87+
final Package mappingBasePackage = getClass().getPackage();
88+
final String entityPackage = "com.azure.spring.data.cosmos.domain";
89+
return Arrays.asList(mappingBasePackage.getName(), entityPackage);
90+
}
91+
92+
@Bean
93+
SimpleCosmosMappingEventListener simpleMappingEventListener() {
94+
return new SimpleCosmosMappingEventListener();
95+
}
96+
}

sdk/cosmos/azure-spring-data-cosmos/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
### 3.31.0-beta.1 (Unreleased)
44

55
#### Features Added
6+
* Added support for multi-tenancy at the Database level via `CosmosFactory` - See [PR 32516](https://github.com/Azure/azure-sdk-for-java/pull/32516)
67

78
#### Breaking Changes
89

sdk/cosmos/azure-spring-data-cosmos/README.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,32 @@ public class MultiDatabaseApplication implements CommandLineRunner {
934934
}
935935
```
936936

937+
### Multi-Tenancy at the Database Level
938+
- Azure-spring-data-cosmos supports multi-tenancy at the database level configuration by extending `CosmosFactory` and overriding the getDatabaseName() function.
939+
```java readme-sample-MultiTenantDBCosmosFactory
940+
public class MultiTenantDBCosmosFactory extends CosmosFactory {
941+
942+
private String tenantId;
943+
944+
/**
945+
* Validate config and initialization
946+
*
947+
* @param cosmosAsyncClient cosmosAsyncClient
948+
* @param databaseName databaseName
949+
*/
950+
public MultiTenantDBCosmosFactory(CosmosAsyncClient cosmosAsyncClient, String databaseName) {
951+
super(cosmosAsyncClient, databaseName);
952+
953+
this.tenantId = databaseName;
954+
}
955+
956+
@Override
957+
public String getDatabaseName() {
958+
return this.getCosmosAsyncClient().getDatabase(this.tenantId).toString();
959+
}
960+
}
961+
```
962+
937963
## Beta version package
938964

939965
Beta version built from `main` branch are available, you can refer to the [instruction](https://github.com/Azure/azure-sdk-for-java/blob/main/CONTRIBUTING.md#nightly-package-builds) to use beta version packages.

sdk/cosmos/azure-spring-data-cosmos/src/main/java/com/azure/spring/data/cosmos/CosmosFactory.java

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,10 @@ public class CosmosFactory {
2020

2121
private final CosmosAsyncClient cosmosAsyncClient;
2222

23-
private final String databaseName;
23+
/**
24+
* Database Name to be used for operations.
25+
*/
26+
protected String databaseName;
2427

2528
private static final String USER_AGENT_SUFFIX =
2629
Constants.USER_AGENT_SUFFIX + PropertyLoader.getProjectVersion();

0 commit comments

Comments
 (0)