Commit 85d9988
ServicePrincipal with client certificate auth for azure-cosmos-spark (Azure#40325)
* ServicePrincipal with client certificate auth for azure-cosmos-spark
* Update spark.yml
* Fixing test failure
* Update CosmosConfig.scala
* Fixing test failures
* Update AccountTokenResolverSample.ipynb
* Update basicScenarioAadCert.scala
* Update databricks-notebooks-install.sh
* Changelog
* Adding option to include sendCertificateChain
* Update AccountTokenResolverSample.ipynb
* Fixing Spark live tests
* Update basicScenarioAadCert.scala
* Update CosmosConfigSpec.scala
* Fixing Spark live tests
* Adding logs in notebooks
* Update basicScenarioAadCert.scala
* Notebook adjustments
* Fixing notebooks
* Update databricks-notebooks-install.sh
* Create basicScenarioAadMisspelled.scala
* Update basicScenarioAadMisspelled.scala
* Update basicScenarioAadMisspelled.scala
1 parent 3234333 commit 85d9988
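In short, the connector can now authenticate with an AAD service principal that presents a client certificate instead of a client secret. A rough sketch of what that credential construction involves, assuming azure-identity's ClientCertificateCredentialBuilder and a base64-encoded PEM certificate supplied as a plain string (the connector's actual wiring lives in the CosmosConfig.scala changes in this commit):

import com.azure.core.credential.TokenRequestContext
import com.azure.identity.ClientCertificateCredentialBuilder

import java.io.ByteArrayInputStream
import java.util.Base64

// Sketch only: tenant/client IDs are placeholders, and the base64-encoded
// PEM certificate is assumed to arrive as a string (as in the sample
// notebook below).
val certPemBase64: String = "<base64EncodedPemCert>"

val credential = new ClientCertificateCredentialBuilder()
  .tenantId("<tenantId>")
  .clientId("<clientId>")
  .pemCertificate(new ByteArrayInputStream(Base64.getDecoder.decode(certPemBase64)))
  .sendCertificateChain(true) // the newly added opt-in to send the full certificate chain
  .build()

// Tokens are then requested for the Cosmos DB AAD scope.
val token = credential
  .getToken(new TokenRequestContext().addScopes("https://cosmos.azure.com/.default"))
  .block()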

File tree: 22 files changed, +535 −162 lines


sdk/cosmos/azure-cosmos-spark-account-data-resolver-sample/sample/AccountTokenResolverSample.ipynb

Lines changed: 56 additions & 64 deletions
@@ -5,43 +5,48 @@
   "execution_count": 0,
   "metadata": {
     "application/vnd.databricks.v1+cell": {
-      "cellmetadata": {
-        "bytelimit": 2048000,
-        "rowlimit": 10000
+      "cellMetadata": {
+        "byteLimit": 2048000,
+        "rowLimit": 10000
       },
-      "inputwidgets": {},
+      "inputWidgets": {},
       "nuid": "edf2bb67-c3fa-459c-bb4b-83dec2075401",
-      "showtitle": false,
+      "showTitle": false,
       "title": ""
     }
   },
   "outputs": [],
   "source": [
-    "cosmosendpoint = \"https://fabianm-oltp-spark-workshop-cdb.documents.azure.com:443/\"\n",
-    "cosmosmasterkey = \"\"\n",
-    "cosmosserviceprincipalpassword=\"\"\n"
+    "cosmosendpoint = \"<YourEndpoint>\"\n",
+    "cosmosmasterkey = \"<YourKey>\"\n",
+    "cosmosserviceprincipalpassword=\"\"\n",
+    "accountDataResolverName = \"com.azure.cosmos.spark.samples.MasterKeyAccountDataResolver\"\n",
+    "#accountDataResolverName = \"com.azure.cosmos.spark.samples.ServicePrincipalAccountDataResolver\"\n",
+    "#accountDataResolverName = \"com.azure.cosmos.spark.samples.ManagedIdentityAccountDataResolver\""
   ]
 },
 {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
     "application/vnd.databricks.v1+cell": {
-      "cellmetadata": {
-        "bytelimit": 2048000,
-        "rowlimit": 10000
+      "cellMetadata": {
+        "byteLimit": 2048000,
+        "rowLimit": 10000
       },
-      "inputwidgets": {},
+      "inputWidgets": {},
       "nuid": "67f2404c-a6b6-4342-9dac-638a2bd7731c",
-      "showtitle": false,
+      "showTitle": false,
       "title": ""
     }
   },
   "outputs": [],
   "source": [
     "import base64\n",
     "import os\n",
-    "cert_file= open(\"/workspace/users/[email protected]/fabianm-spark-auth-sp-cert.pem\",\"rb\")\n",
+    "\n",
+    "\n",
+    "cert_file= open(\"/Workspace/Users/[email protected]/someCert.pem\",\"rb\")\n",
     "cert_data_binary = cert_file.read()\n",
     "cert_data = (base64.b64encode(cert_data_binary)).decode('ascii')\n"
   ]
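The cell above simply loads a PEM certificate from the workspace and base64-encodes it, so the certificate can travel as an ordinary string-valued option. A Scala equivalent, with the path as a placeholder, would be:

import java.nio.file.{Files, Paths}
import java.util.Base64

// Read the PEM certificate and encode it as base64 text so it can be passed
// as a string-valued Spark option.
val certBytes = Files.readAllBytes(Paths.get("/Workspace/Users/<you>/someCert.pem"))
val certData = Base64.getEncoder.encodeToString(certBytes)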
@@ -51,31 +56,22 @@
   "execution_count": 0,
   "metadata": {
     "application/vnd.databricks.v1+cell": {
-      "cellmetadata": {
-        "bytelimit": 2048000,
-        "rowlimit": 10000
+      "cellMetadata": {
+        "byteLimit": 2048000,
+        "rowLimit": 10000
       },
-      "inputwidgets": {},
+      "inputWidgets": {},
       "nuid": "bfbd87f9-7628-489c-8f8a-1f0d5d14d2be",
-      "showtitle": false,
+      "showTitle": false,
       "title": ""
     }
   },
-  "outputs": [
-    {
-      "output_type": "stream",
-      "name": "stdout",
-      "output_type": "stream",
-      "text": [
-        "pk: 7d90716b-4ea1-4753-8090-4b72e4a2b93b\nroot\n |-- id: string (nullable = false)\n |-- pk: string (nullable = false)\n |-- emptycolumn: string (nullable = true)\n |-- nullcolumn: string (nullable = true)\n |-- defaultcolumn: integer (nullable = true)\n |-- largecolumn: string (nullable = true)\n\n+------------------------------------+------------------------------------+-----------+----------+-------------+----------------------------------------------------------------------------------------------------+\n| id| pk|emptycolumn|nullcolumn|defaultcolumn| largecolumn|\n+------------------------------------+------------------------------------+-----------+----------+-------------+----------------------------------------------------------------------------------------------------+\n|fa1a3854-4d41-4ffb-b992-a00c82585ddc|7d90716b-4ea1-4753-8090-4b72e4a2b93b| | null| 0|ixcqsfjhwqelwcpjtzaqquhaxlmemdpeheyfxosdobyqvbihrvrftuaicllsfllgmfzwrbefkszobvpihkqxqfyulggqgrznd...|\n|3ef8e8c0-e9e7-4a2e-887f-2a9826f7b987|7d90716b-4ea1-4753-8090-4b72e4a2b93b| | null| 0|obltfpuoonfywvusviupkloeojqolqqyabzhcssnefwwptgvwqgnajesmnsyslvogtclasksjwpltsqrqwkeqgazarodmvbmv...|\n+------------------------------------+------------------------------------+-----------+----------+-------------+----------------------------------------------------------------------------------------------------+\n\n"
-      ]
-    }
-  ],
+  "outputs": [],
   "source": [
     "import random\n",
     "import string\n",
     "import uuid\n",
-    "from pyspark.sql.types import structtype,structfield, stringtype, integertype\n",
+    "from pyspark.sql.types import StructType,StructField, StringType, IntegerType\n",
     " \n",
     "def random_string_generator(str_size, allowed_chars):\n",
     "    return ''.join(random.choice(allowed_chars) for x in range(str_size))\n",
@@ -85,54 +81,55 @@
     " \n",
     "chars = string.ascii_letters\n",
     "data = [\\\n",
-    " (str(uuid.uuid4()), pk, \"\", none, 0, random_string_generator(16000, chars)),\\\n",
-    " (str(uuid.uuid4()), pk, \"\", none, 0, random_string_generator(random.randint(16000, 170000), chars)),\\\n",
+    " (str(uuid.uuid4()), pk, \"\", None, 0, random_string_generator(16000, chars)),\\\n",
+    " (str(uuid.uuid4()), pk, \"\", None, 0, random_string_generator(random.randint(16000, 170000), chars)),\\\n",
     " ]\n",
     "\n",
-    "schema = structtype([ \\\n",
-    " structfield(\"id\",stringtype(),false), \\\n",
-    " structfield(\"pk\",stringtype(),false), \\\n",
-    " structfield(\"emptycolumn\",stringtype(),true), \\\n",
-    " structfield(\"nullcolumn\",stringtype(),true), \\\n",
-    " structfield(\"defaultcolumn\",integertype(),true), \\\n",
-    " structfield(\"largecolumn\", stringtype(), true)\\\n",
+    "schema = StructType([ \\\n",
+    " StructField(\"id\",StringType(),False), \\\n",
+    " StructField(\"pk\",StringType(),False), \\\n",
+    " StructField(\"emptycolumn\",StringType(),True), \\\n",
+    " StructField(\"nullcolumn\",StringType(),True), \\\n",
+    " StructField(\"defaultcolumn\",IntegerType(),True), \\\n",
+    " StructField(\"largecolumn\", StringType(), True)\\\n",
     " ])\n",
     " \n",
-    "df = spark.createdataframe(data=data,schema=schema)\n",
-    "df.printschema()\n",
+    "df = spark.createDataFrame(data=data,schema=schema)\n",
+    "df.printSchema()\n",
     "df.show(truncate=100)\n",
     "\n",
     "writecfg = {\n",
     " \"spark.cosmos.accountendpoint\": cosmosendpoint,\n",
-    " \"spark.cosmos.database\": \"test\",\n",
-    " \"spark.cosmos.container\": \"testitems\",\n",
+    " \"spark.cosmos.accountDataResolverServiceName\": accountDataResolverName,\n",
+    " \"spark.cosmos.database\": \"Test\",\n",
+    " \"spark.cosmos.container\": \"TestItems\",\n",
     " \"spark.cosmos.write.strategy\": \"itemappend\",\n",
     " \"spark.cosmos.write.bulk.enabled\": \"true\", \n",
     " \"cosmos.auth.sample.enabled\": \"true\",\n",
     " # masterkey\n",
-    " #\"cosmos.auth.sample.authtype\": \"masterkey\",\n",
+    " #\"cosmos.auth.sample.authType\": \"masterkey\",\n",
     " #\"cosmos.auth.sample.key.secret\": cosmosmasterkey,\n",
     " #\n",
     " # aad auth with managed identity\n",
-    " #\"cosmos.auth.sample.authtype\": \"managedidentity\",\n",
-    " #\"cosmos.auth.sample.tenantid\": \"72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
-    " #\"cosmos.auth.sample.subscriptionid\": \"8fba6d4f-7c37-4d13-9063-fd58ad2b86e2\",\n",
+    " #\"cosmos.auth.sample.authType\": \"managedidentity\",\n",
+    " #\"cosmos.auth.sample.tenantId\": \"72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
+    " #\"cosmos.auth.sample.subscriptionId\": \"8fba6d4f-7c37-4d13-9063-fd58ad2b86e2\",\n",
     " #\"cosmos.auth.sample.resourcegroupname\": \"fabianm-oltp-spark-workshop\"\n",
     " #\n",
     " # aad auth with service principal (password)\n",
-    " #\"cosmos.auth.sample.authtype\": \"serviceprincipal\",\n",
-    " #\"cosmos.auth.sample.tenantid\": \"72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
-    " #\"cosmos.auth.sample.subscriptionid\": \"8fba6d4f-7c37-4d13-9063-fd58ad2b86e2\",\n",
+    " #\"cosmos.auth.sample.authType\": \"serviceprincipal\",\n",
+    " #\"cosmos.auth.sample.tenantId\": \"72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
+    " #\"cosmos.auth.sample.subscriptionId\": \"8fba6d4f-7c37-4d13-9063-fd58ad2b86e2\",\n",
     " #\"cosmos.auth.sample.resourcegroupname\": \"fabianm-oltp-spark-workshop\",\n",
-    " #\"cosmos.auth.sample.serviceprincipal.clientid\": \"bd559cf4-786d-43ae-9ff6-eb83c5952c73\",\n",
+    " #\"cosmos.auth.sample.serviceprincipal.clientId\": \"bd559cf4-786d-43ae-9ff6-eb83c5952c73\",\n",
     " #\"cosmos.auth.sample.serviceprincipal.clientsecret\": cosmosserviceprincipalpassword\n",
     " #\n",
     " # aad auth with service principal (cert)\n",
-    " #\"cosmos.auth.sample.authtype\": \"serviceprincipal\",\n",
-    " #\"cosmos.auth.sample.tenantid\": \"72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
-    " #\"cosmos.auth.sample.subscriptionid\": \"8fba6d4f-7c37-4d13-9063-fd58ad2b86e2\",\n",
+    " #\"cosmos.auth.sample.authType\": \"serviceprincipal\",\n",
+    " #\"cosmos.auth.sample.tenantId\": \"72f988bf-86f1-41af-91ab-2d7cd011db47\",\n",
+    " #\"cosmos.auth.sample.subscriptionId\": \"8fba6d4f-7c37-4d13-9063-fd58ad2b86e2\",\n",
     " #\"cosmos.auth.sample.resourcegroupname\": \"fabianm-oltp-spark-workshop\",\n",
-    " #\"cosmos.auth.sample.serviceprincipal.clientid\": \"88436299-945f-4824-8183-2cbddf981388\",\n",
+    " #\"cosmos.auth.sample.serviceprincipal.clientId\": \"88436299-945f-4824-8183-2cbddf981388\",\n",
     " #\"cosmos.auth.sample.serviceprincipal.cert\": cert_data\n",
     "}\n",
     "\n",
@@ -148,17 +145,12 @@
 "metadata": {
   "application/vnd.databricks.v1+notebook": {
     "dashboards": [],
+    "environmentMetadata": null,
     "language": "python",
-    "notebookmetadata": {
-      "mostrecentlyexecutedcommandwithimplicitdf": {
-        "commandid": 3298457839905717,
-        "dataframes": [
-          "_sqldf"
-        ]
-      },
-      "pythonindentunit": 4
+    "notebookMetadata": {
+      "pythonIndentUnit": 4
     },
-    "notebookname": "accounttokenresolversample",
+    "notebookName": "AccountTokenResolverSample",
     "widgets": {}
   }
 },
AccountDataResolver service registration (META-INF/services)

Lines changed: 3 additions & 1 deletion
@@ -1 +1,3 @@
-com.azure.cosmos.spark.samples.SampleAccountDataResolver
+com.azure.cosmos.spark.samples.ManagedIdentityAccountDataResolver
+com.azure.cosmos.spark.samples.MasterKeyAccountDataResolver
+com.azure.cosmos.spark.samples.ServicePrincipalAccountDataResolver
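The file above is the Java ServiceLoader registration list: each line names an AccountDataResolver implementation that can be discovered on the classpath, and `spark.cosmos.accountDataResolverServiceName` selects among them. A minimal sketch of that discovery mechanism (the connector's internal lookup may differ):

import java.util.ServiceLoader

import com.azure.cosmos.spark.AccountDataResolver

import scala.collection.JavaConverters._

// Enumerate every AccountDataResolver registered via META-INF/services and
// pick the one whose fully qualified class name matches.
def resolverByName(name: String): Option[AccountDataResolver] =
  ServiceLoader
    .load(classOf[AccountDataResolver])
    .iterator()
    .asScala
    .find(_.getClass.getName == name)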
com.azure.cosmos.spark.samples.ManagedIdentityAccountDataResolver (new file)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark.samples

import com.azure.core.credential.{TokenCredential, TokenRequestContext}
import com.azure.cosmos.spark.{AccountDataResolver, CosmosAccessToken}
import com.azure.identity.ManagedIdentityCredentialBuilder

// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import

class ManagedIdentityAccountDataResolver extends AccountDataResolver with BasicLoggingTrait {
  override def getAccountDataConfig(configs: Map[String, String]): Map[String, String] = {
    if (isEnabled(configs)) {
      configs +
        ("spark.cosmos.auth.type" -> "AccessToken") +
        ("spark.cosmos.account.tenantId" -> getRequiredConfig(configs, SampleConfigNames.TenantId)) +
        ("spark.cosmos.account.subscriptionId" -> getRequiredConfig(configs, SampleConfigNames.SubscriptionId)) +
        ("spark.cosmos.account.resourceGroupName" -> getRequiredConfig(configs, SampleConfigNames.ResourceGroupName))
    } else {
      configs
    }
  }

  private def getRequiredConfig(configs: Map[String, String], configName: String): String = {
    val valueOpt = configs.get(configName)
    assert(valueOpt.isDefined, s"Parameter '$configName' is missing.")
    valueOpt.get
  }

  private def isEnabled(configs: Map[String, String]): Boolean = {
    val enabled = configs.get(SampleConfigNames.CustomAuthEnabled)
    enabled.isDefined && enabled.get.toBoolean
  }

  private def getManagedIdentityTokenCredential(configs: Map[String, String]): Option[TokenCredential] = {
    logInfo(s"Constructing ManagedIdentity TokenCredential")
    val tokenCredentialBuilder = new ManagedIdentityCredentialBuilder()
    if (configs.contains(SampleConfigNames.ManagedIdentityClientId)) {
      tokenCredentialBuilder.clientId(configs(SampleConfigNames.ManagedIdentityClientId))
    }

    if (configs.contains(SampleConfigNames.ManagedIdentityResourceId)) {
      tokenCredentialBuilder.resourceId(configs(SampleConfigNames.ManagedIdentityResourceId))
    }

    Some(tokenCredentialBuilder.build())
  }

  private def getTokenCredential(configs: Map[String, String]): Option[TokenCredential] = {
    val authType = getRequiredConfig(configs, SampleConfigNames.AuthType)
    if (authType.equalsIgnoreCase(SampleAuthTypes.ManagedIdentity)) {
      logInfo(s"Managed identity used")
      getManagedIdentityTokenCredential(configs)
    } else {
      logError(s"Invalid authType '$authType'.")
      assert(assertion = false, s"Invalid authType '$authType'.")
      None
    }
  }

  override def getAccessTokenProvider(configs: Map[String, String]): Option[List[String] => CosmosAccessToken] = {
    if (isEnabled(configs)) {
      val tokenCredential = getTokenCredential(configs)

      if (tokenCredential.isDefined) {
        logInfo(s"TokenCredential found - and access token provider used")
        Some((tokenRequestContextStrings: List[String]) => {
          val tokenRequestContext = new TokenRequestContext
          tokenRequestContext.setScopes(tokenRequestContextStrings.asJava)
          val accessToken = tokenCredential
            .get
            .getToken(tokenRequestContext)
            .block()
          CosmosAccessToken(accessToken.getToken, accessToken.getExpiresAt)
        })
      } else {
        logWarning(s"No TokenCredential provided")
        None
      }
    } else {
      logInfo(s"ManagedIdentityAccountDataResolver is disabled")
      None
    }
  }

  private[this] object SampleConfigNames {
    val AuthType = "cosmos.auth.sample.authType"
    val CustomAuthEnabled = "cosmos.auth.sample.enabled"
    val ManagedIdentityClientId = "cosmos.auth.sample.managedIdentity.clientId"
    val ManagedIdentityResourceId = "cosmos.auth.sample.managedIdentity.resourceId"
    val ResourceGroupName = "cosmos.auth.sample.resourceGroupName"
    val SubscriptionId = "cosmos.auth.sample.subscriptionId"
    val TenantId = "cosmos.auth.sample.tenantId"
  }

  private[this] object SampleAuthTypes {
    val ManagedIdentity: String = "managedidentity"
  }
}
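To exercise this resolver, enable it and select managed identity auth through the `cosmos.auth.sample.*` options it reads; the resolver then rewrites them into `spark.cosmos.auth.type = AccessToken` plus the tenant, subscription, and resource-group settings, and serves tokens from a ManagedIdentityCredential. A hypothetical configuration (IDs are placeholders):

// Option names are taken from SampleConfigNames above; clientId/resourceId
// are optional and select a specific user-assigned managed identity.
val resolverCfg = Map(
  "spark.cosmos.accountDataResolverServiceName" ->
    "com.azure.cosmos.spark.samples.ManagedIdentityAccountDataResolver",
  "cosmos.auth.sample.enabled" -> "true",
  "cosmos.auth.sample.authType" -> "managedidentity",
  "cosmos.auth.sample.tenantId" -> "<tenantId>",
  "cosmos.auth.sample.subscriptionId" -> "<subscriptionId>",
  "cosmos.auth.sample.resourceGroupName" -> "<resourceGroupName>"
)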
com.azure.cosmos.spark.samples.MasterKeyAccountDataResolver (new file)
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
package com.azure.cosmos.spark.samples

import com.azure.cosmos.spark.{AccountDataResolver, CosmosAccessToken}

class MasterKeyAccountDataResolver extends AccountDataResolver with BasicLoggingTrait {
  override def getAccountDataConfig(configs: Map[String, String]): Map[String, String] = {
    if (isEnabled(configs)) {
      configs + ("spark.cosmos.accountKey" -> getRequiredConfig(configs, SampleConfigNames.MasterKeySecret))
    } else {
      configs
    }
  }

  private def getRequiredConfig(configs: Map[String, String], configName: String): String = {
    val valueOpt = configs.get(configName)
    assert(valueOpt.isDefined, s"Parameter '$configName' is missing.")
    valueOpt.get
  }

  private def isEnabled(configs: Map[String, String]): Boolean = {
    val enabled = configs.get(SampleConfigNames.CustomAuthEnabled)
    enabled.isDefined && enabled.get.toBoolean
  }

  private[this] object SampleConfigNames {
    val CustomAuthEnabled = "cosmos.auth.sample.enabled"
    val MasterKeySecret = "cosmos.auth.sample.key.secret"
  }

  /**
   * This method will be invoked by the Cosmos DB Spark connector to retrieve access tokens. It will only
   * be used when the config `spark.cosmos.auth.type` is set to `AccessToken` - and in this case
   * the implementation of this trait will need to provide a function that can be used to produce
   * access tokens, or None if no auth can be provided for the specified configuration.
   *
   * @param configs the user configuration originally provided
   * @return A function that can be used to provide access tokens
   */
  override def getAccessTokenProvider(configs: Map[String, String]): Option[List[String] => CosmosAccessToken] =
    // Master-key auth injects the account key via getAccountDataConfig, so this
    // resolver never needs to supply an access token provider.
    None
}
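Unlike the managed identity resolver, this one never hands out tokens: getAccountDataConfig injects the account key directly (`spark.cosmos.accountKey`), so key-based auth applies and getAccessTokenProvider has nothing to supply. A hypothetical configuration (the key value is a placeholder):

val masterKeyCfg = Map(
  "spark.cosmos.accountDataResolverServiceName" ->
    "com.azure.cosmos.spark.samples.MasterKeyAccountDataResolver",
  "cosmos.auth.sample.enabled" -> "true",
  "cosmos.auth.sample.key.secret" -> "<YourCosmosAccountKey>"
)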
