diff --git a/build.sbt b/build.sbt
index fa1d84c70..6827d78c9 100644
--- a/build.sbt
+++ b/build.sbt
@@ -12,7 +12,9 @@ libraryDependencies += "org.apache.spark" %% "spark-core" % sparkVersion % Provi
 libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % Provided
 libraryDependencies += "org.apache.spark" %% "spark-hive" % sparkVersion % Provided
 libraryDependencies += "com.databricks" % "dbutils-api_2.12" % "0.0.5" % Provided
-libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.11.595" % Provided
+libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.11.595"
+libraryDependencies += "com.amazonaws" % "aws-java-sdk-secretsmanager" % "1.11.595"
+libraryDependencies += "com.fasterxml.jackson.core" % "jackson-databind" % "2.10.0"
 libraryDependencies += "io.delta" % "delta-core_2.12" % "1.0.0" % Provided
 libraryDependencies += "org.scalaj" %% "scalaj-http" % "2.4.2"
 //libraryDependencies += "org.apache.hive" % "hive-metastore" % "2.3.9"
diff --git a/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala b/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala
index cfa3f61f3..ded53367e 100644
--- a/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/ParamDeserializer.scala
@@ -88,14 +88,25 @@ class ParamDeserializer() extends StdDeserializer[OverwatchParams](classOf[Overw
   override def deserialize(jp: JsonParser, ctxt: DeserializationContext): OverwatchParams = {
     val masterNode = jp.getCodec.readTree[JsonNode](jp)
 
-    val token = try {
-      Some(TokenSecret(
-        masterNode.get("tokenSecret").get("scope").asText(),
-        masterNode.get("tokenSecret").get("key").asText()))
-    } catch {
-      case e: Throwable =>
-        println("No Token Secret Defined", e)
-        None
+    // TODO: consider modeling the supported secret types as an enum and replacing the
+    // block below with a function that processes the enum entries in a loop
+    val token = {
+
+      val databricksToken =
+        for {
+          scope <- getOptionString(masterNode, "tokenSecret.scope")
+          key <- getOptionString(masterNode, "tokenSecret.key")
+        } yield TokenSecret(scope, key)
+
+      val awsToken =
+        for {
+          secretId <- getOptionString(masterNode, "tokenSecret.secretId")
+          region <- getOptionString(masterNode, "tokenSecret.region")
+          apiToken = getOptionString(masterNode, "tokenSecret.tokenKey")
+        } yield AwsTokenSecret(secretId, region, apiToken)
+
+      // a Databricks-style secret takes precedence; otherwise fall back to AWS
+      databricksToken.orElse(awsToken)
     }
 
     val rawAuditPath = getOptionString(masterNode, "auditLogConfig.rawAuditPath")
diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala
index 3db555902..418469555 100644
--- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Initializer.scala
@@ -265,7 +265,7 @@ class Initializer(config: Config) extends SparkSessionWrapper {
     config.setExternalizeOptimize(rawParams.externalizeOptimize)
 
     val overwatchScope = rawParams.overwatchScope.getOrElse(Seq("all"))
-    val tokenSecret = rawParams.tokenSecret
+
     // TODO -- PRIORITY -- If data target is null -- default table gets dbfs:/null
     val dataTarget = rawParams.dataTarget.getOrElse(
       DataTarget(Some("overwatch"), Some("dbfs:/user/hive/warehouse/overwatch.db"), None))
@@ -275,24 +275,30 @@ class Initializer(config: Config) extends SparkSessionWrapper {
     if (overwatchScope.head == "all") config.setOverwatchScope(config.orderedOverwatchScope)
     else
       config.setOverwatchScope(validateScope(overwatchScope))
 
-    // validate token secret requirements
-    // TODO - Validate if token has access to necessary assets. Warn/Fail if not
-    if (tokenSecret.nonEmpty && !disableValidations && !config.isLocalTesting) {
-      if (tokenSecret.get.scope.isEmpty || tokenSecret.get.key.isEmpty) {
-        throw new BadConfigException(s"Secret AND Key must be provided together or neither of them. " +
-          s"Either supply both or neither.")
+    if (rawParams.tokenSecret.nonEmpty && !disableValidations && !config.isLocalTesting) {
+      rawParams.tokenSecret.foreach {
+        case databricksSecret: TokenSecret =>
+          // validate Databricks token secret requirements
+          // TODO - Validate if the Databricks token has access to necessary assets. Warn/Fail if not
+
+          if (databricksSecret.scope.isEmpty || databricksSecret.key.isEmpty) {
+            throw new BadConfigException(s"Secret scope AND key must be provided together. " +
+              s"Either supply both or neither.")
+          }
+          val scopeCheck = dbutils.secrets.listScopes().map(_.getName()).toArray.filter(_ == databricksSecret.scope)
+          if (scopeCheck.length == 0) throw new BadConfigException(s"Scope ${databricksSecret.scope} does not exist " +
+            s"in this workspace. Please provide a scope available and accessible to this account.")
+          val scopeName = scopeCheck.head
+
+          val keyCheck = dbutils.secrets.list(scopeName).toArray.filter(_.key == databricksSecret.key)
+          if (keyCheck.length == 0) throw new BadConfigException(s"Key ${databricksSecret.key} does not exist " +
+            s"within the provided scope: ${databricksSecret.scope}. Please provide a scope and key " +
+            s"available and accessible to this account.")
+
+          config.registerWorkspaceMeta(Some(TokenSecret(scopeName, keyCheck.head.key)))
+
+        case awsSecret: AwsTokenSecret => config.registerWorkspaceMeta(Some(awsSecret))
       }
-      val scopeCheck = dbutils.secrets.listScopes().map(_.getName()).toArray.filter(_ == tokenSecret.get.scope)
-      if (scopeCheck.length == 0) throw new BadConfigException(s"Scope ${tokenSecret.get.scope} does not exist " +
-        s"in this workspace. Please provide a scope available and accessible to this account.")
-      val scopeName = scopeCheck.head
-
-      val keyCheck = dbutils.secrets.list(scopeName).toArray.filter(_.key == tokenSecret.get.key)
-      if (keyCheck.length == 0) throw new BadConfigException(s"Key ${tokenSecret.get.key} does not exist " +
-        s"within the provided scope: ${tokenSecret.get.scope}. Please provide a scope and key " +
-        s"available and accessible to this account.")
-
-      config.registerWorkspaceMeta(Some(TokenSecret(scopeName, keyCheck.head.key)))
     } else config.registerWorkspaceMeta(None)
 
     // Validate data Target
diff --git a/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala b/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala
index 2b3bbaf9a..a0d3ad6c2 100644
--- a/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/pipeline/Module.scala
@@ -4,6 +4,7 @@ import com.databricks.labs.overwatch.pipeline.TransformFunctions._
 import com.databricks.labs.overwatch.utils._
 import org.apache.log4j.{Level, Logger}
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.functions._
 
 class Module(
              val moduleId: Int,
@@ -172,10 +173,30 @@ class Module(
     initState
   }
 
+  private def normalizeToken(secretToken: TokenSecret, reportDf: DataFrame): DataFrame = {
+    // rebuild inputConfig, replacing tokenSecret with the normalized two-field struct
+    val inputConfigCols = reportDf.select($"inputConfig.*")
+      .columns
+      .filter(_ != "tokenSecret")
+      .map(name => col("inputConfig." + name))
+
+    reportDf
+      .withColumn(
+        "inputConfig",
+        struct(inputConfigCols :+ struct(lit(secretToken.scope).as("scope"), lit(secretToken.key).as("key")).as("tokenSecret"): _*)
+      )
+  }
+
   private def finalizeModule(report: ModuleStatusReport): Unit = {
     pipeline.updateModuleState(report.simple)
     if (!pipeline.readOnly) {
-      pipeline.database.write(Seq(report).toDF, pipeline.pipelineStateTarget, pipeline.pipelineSnapTime.asColumnTS)
+      val targetDf = report.inputConfig.tokenSecret match {
+        case Some(secret) => normalizeToken(SecretTools(secret).getTargetTableStruct, Seq(report).toDF)
+        case None => Seq(report).toDF
+      }
+      pipeline.database.write(
+        targetDf,
+        pipeline.pipelineStateTarget, pipeline.pipelineSnapTime.asColumnTS)
     }
   }
diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/AwsSecrets.scala b/src/main/scala/com/databricks/labs/overwatch/utils/AwsSecrets.scala
new file mode 100644
index 000000000..132cdf797
--- /dev/null
+++ b/src/main/scala/com/databricks/labs/overwatch/utils/AwsSecrets.scala
@@ -0,0 +1,46 @@
+package com.databricks.labs.overwatch.utils
+
+import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder
+import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest
+import org.apache.log4j.{Level, Logger}
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods.parse
+
+import java.util.Base64
+
+object AwsSecrets {
+  private val logger: Logger = Logger.getLogger(this.getClass)
+
+  def readApiToken(secretId: String, region: String, apiTokenKey: Option[String]): String = {
+    apiTokenKey match {
+      case Some(key) => secretValueAsMap(secretId, region)
+        .getOrElse(key, throw new IllegalStateException(s"key '$key' not found in secret $secretId"))
+        .asInstanceOf[String]
+      case None => readRawSecretFromAws(secretId, region)
+    }
+  }
+
+  def secretValueAsMap(secretId: String, region: String = "us-east-2"): Map[String, Any] =
+    parseJsonToMap(readRawSecretFromAws(secretId, region))
+
+  def readRawSecretFromAws(secretId: String, region: String): String = {
+    logger.log(Level.INFO, s"Looking up secret $secretId in AWS Secrets Manager")
+
+    val secretsClient = AWSSecretsManagerClientBuilder
+      .standard()
+      .withRegion(region)
+      .build()
+    val request = new GetSecretValueRequest().withSecretId(secretId)
+    val secretValue = secretsClient.getSecretValue(request)
+
+    if (secretValue.getSecretString != null)
+      secretValue.getSecretString
+    else
+      new String(Base64.getDecoder.decode(secretValue.getSecretBinary).array)
+  }
+
+  def parseJsonToMap(jsonStr: String): Map[String, Any] = {
+    implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats
+    parse(jsonStr).extract[Map[String, Any]]
+  }
+}
diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala
index 39ca73460..1e2876a13 100644
--- a/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/utils/Config.scala
@@ -312,20 +312,14 @@ class Config() {
    * as the job owner or notebook user (if called from notebook)
    * @return
    */
-  private[overwatch] def registerWorkspaceMeta(tokenSecret: Option[TokenSecret]): this.type = {
+  private[overwatch] def registerWorkspaceMeta(tokenSecret: Option[TokenSecretContainer]): this.type = {
     var rawToken = ""
-    var scope = ""
-    var key = ""
     try {
       // Token secrets not supported in local testing
       if (tokenSecret.nonEmpty && !_isLocalTesting) { // not local testing and secret passed
         _workspaceUrl = dbutils.notebook.getContext().apiUrl.get
         _cloudProvider = if (_workspaceUrl.toLowerCase().contains("azure")) "azure" else "aws"
-        scope = tokenSecret.get.scope
-        key = tokenSecret.get.key
-        rawToken = dbutils.secrets.get(scope, key)
-        val authMessage = s"Valid Secret Identified: Executing with token located in secret, $scope : $key"
-        logger.log(Level.INFO, authMessage)
+        rawToken = SecretTools(tokenSecret.get).getApiToken
         _tokenType = "Secret"
       } else {
         if (_isLocalTesting) { // Local testing env vars
@@ -344,7 +338,7 @@ class Config() {
       }
     }
     if (!rawToken.matches("^(dapi|dkea)[a-zA-Z0-9-]*$")) throw new BadConfigException(s"contents of secret " +
-      s"at scope:key $scope:$key is not in a valid format. Please validate the contents of your secret. It must be " +
+      s"are not in a valid format. Please validate the contents of your secret. It must be " +
It should start with 'dapi' ") setApiEnv(ApiEnv(isLocalTesting, workspaceURL, rawToken, packageVersion)) this diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/SecretTools.scala b/src/main/scala/com/databricks/labs/overwatch/utils/SecretTools.scala new file mode 100644 index 000000000..3adb9f42f --- /dev/null +++ b/src/main/scala/com/databricks/labs/overwatch/utils/SecretTools.scala @@ -0,0 +1,57 @@ +package com.databricks.labs.overwatch.utils + +import com.databricks.dbutils_v1.DBUtilsHolder.dbutils +import org.apache.log4j.{Level, Logger} + +/** + * SecretTools handles common functionality related to working with secrets: + * 1) Get Databricks API token stored in specified secret + * 2) Normalize secret structure to be stored at Delta table pipeline_report under inputConfig.tokenSecret nested struct column + * There are two secret types available now - AWS Secrets Manager, Databricks secrets + */ +trait SecretTools[T <: TokenSecretContainer] { + def getApiToken : String + def getTargetTableStruct: TokenSecret +} + +object SecretTools { + private val logger: Logger = Logger.getLogger(this.getClass) + type DatabricksTokenSecret = TokenSecret + + private class DatabricksSecretTools(tokenSecret : DatabricksTokenSecret) extends SecretTools[DatabricksTokenSecret] { + override def getApiToken: String = { + val scope = tokenSecret.scope + val key = tokenSecret.key + val authMessage = s"Executing with Databricks token located in secret, scope=$scope : key=$key" + logger.log(Level.INFO, authMessage) + dbutils.secrets.get(scope, key) + } + + override def getTargetTableStruct: TokenSecret = { + TokenSecret(tokenSecret.scope,tokenSecret.key) + } + } + + private class AwsSecretTools(tokenSecret : AwsTokenSecret) extends SecretTools[AwsTokenSecret] { + override def getApiToken: String = { + val secretId = tokenSecret.secretId + val region = tokenSecret.region + val tokenKey = tokenSecret.tokenKey + val authMessage = s"Executing with AWS token located in secret, secretId=$secretId : region=$region : tokenKey=$tokenKey" + logger.log(Level.INFO, authMessage) + AwsSecrets.readApiToken(secretId, region, tokenSecret.tokenKey) + } + + override def getTargetTableStruct: TokenSecret = { + TokenSecret(tokenSecret.region, tokenSecret.secretId) + } + } + + def apply(secretSource: TokenSecretContainer): SecretTools[_] = { + secretSource match { + case x: AwsTokenSecret => new AwsSecretTools(x) + case y: DatabricksTokenSecret => new DatabricksSecretTools(y) + case _ => throw new IllegalArgumentException(s"${secretSource.toString} not implemented") + } + } +} diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala index 17bcb0ba9..46b6c6083 100644 --- a/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala +++ b/src/main/scala/com/databricks/labs/overwatch/utils/Structures.scala @@ -21,7 +21,9 @@ case class SparkDetail() case class GangliaDetail() -case class TokenSecret(scope: String, key: String) +abstract class TokenSecretContainer extends Product with Serializable +case class TokenSecret(scope: String, key: String) extends TokenSecretContainer +case class AwsTokenSecret(secretId: String, region: String, tokenKey: Option[String] = None) extends TokenSecretContainer case class DataTarget(databaseName: Option[String], databaseLocation: Option[String], etlDataPathPrefix: Option[String], consumerDatabaseName: Option[String] = None, consumerDatabaseLocation: Option[String] = None) @@ -75,7 +77,7 @@ 
 case class IntelligentScaling(enabled: Boolean = false, minimumCores: Int = 4, maximumCores: Int = 512, coeff: Double = 1.0)
 
 case class OverwatchParams(auditLogConfig: AuditLogConfig,
-                           tokenSecret: Option[TokenSecret] = None,
+                           tokenSecret: Option[TokenSecretContainer] = None,
                            dataTarget: Option[DataTarget] = None,
                            badRecordsPath: Option[String] = None,
                            overwatchScope: Option[Seq[String]] = None,
@@ -349,9 +351,15 @@ object OverwatchEncoders {
   implicit def overwatchScope: org.apache.spark.sql.Encoder[OverwatchScope] =
     org.apache.spark.sql.Encoders.kryo[OverwatchScope]
 
+  /*
   implicit def tokenSecret: org.apache.spark.sql.Encoder[TokenSecret] =
     org.apache.spark.sql.Encoders.kryo[TokenSecret]
 
+  implicit def tokenSecretContainer: org.apache.spark.sql.Encoder[TokenSecretContainer] =
+    org.apache.spark.sql.Encoders.kryo[TokenSecretContainer]
+
+  */
+
   implicit def dataTarget: org.apache.spark.sql.Encoder[DataTarget] =
     org.apache.spark.sql.Encoders.kryo[DataTarget]
diff --git a/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala b/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala
index 350f97b88..76e887858 100644
--- a/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala
+++ b/src/main/scala/com/databricks/labs/overwatch/utils/Tools.scala
@@ -11,7 +11,7 @@ import com.fasterxml.jackson.module.scala.DefaultScalaModule
 import io.delta.tables.DeltaTable
 import org.apache.commons.lang3.StringEscapeUtils
 import org.apache.hadoop.conf._
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
 import org.apache.log4j.{Level, Logger}
 import org.apache.spark.sql.functions._
 import org.apache.spark.util.SerializableConfiguration
diff --git a/src/test/scala/com/databricks/labs/overwatch/ParamDeserializerTest.scala b/src/test/scala/com/databricks/labs/overwatch/ParamDeserializerTest.scala
index f7cc714d5..890eeeb56 100644
--- a/src/test/scala/com/databricks/labs/overwatch/ParamDeserializerTest.scala
+++ b/src/test/scala/com/databricks/labs/overwatch/ParamDeserializerTest.scala
@@ -12,6 +12,33 @@ class ParamDeserializerTest extends AnyFunSpec {
 
   describe("ParamDeserializer") {
 
+    val paramModule: SimpleModule = new SimpleModule()
+      .addDeserializer(classOf[OverwatchParams], new ParamDeserializer)
+    val mapper: ObjectMapper with ScalaObjectMapper = (new ObjectMapper() with ScalaObjectMapper)
+      .registerModule(DefaultScalaModule)
+      .registerModule(paramModule)
+      .asInstanceOf[ObjectMapper with ScalaObjectMapper]
+
+    it("should decode passed token string as AWS secrets") {
+      val awsSecrets = """
+        |{"tokenSecret":{"secretId":"overwatch","region":"us-east-2"}}
+        |""".stripMargin
+
+      val expected = Some(AwsTokenSecret("overwatch", "us-east-2"))
+      val parsed = mapper.readValue[OverwatchParams](awsSecrets)
+      assertResult(expected)(parsed.tokenSecret)
+    }
+
+    it("should decode passed token string as Databricks secrets") {
+      val databricksSecrets = """
+        |{"tokenSecret":{"scope":"overwatch", "key":"test-key"}}
+        |""".stripMargin
+
+      val expected = Some(TokenSecret("overwatch", "test-key"))
+      val parsed = mapper.readValue[OverwatchParams](databricksSecrets)
+      assertResult(expected)(parsed.tokenSecret)
+    }
+
     it("should decode incomplete parameters") {
       val incomplete = """
         |{"auditLogConfig":{"azureAuditLogEventhubConfig":{"connectionString":"test","eventHubName":"overwatch-evhub",
@@ -24,13 +51,6 @@ class ParamDeserializerTest extends AnyFunSpec {
        |"workspace_name":"myTestWorkspace", "externalizeOptimizations":"false"}
"externalizeOptimizations":"false"} |""".stripMargin - val paramModule: SimpleModule = new SimpleModule() - .addDeserializer(classOf[OverwatchParams], new ParamDeserializer) - val mapper: ObjectMapper with ScalaObjectMapper = (new ObjectMapper() with ScalaObjectMapper) - .registerModule(DefaultScalaModule) - .registerModule(paramModule) - .asInstanceOf[ObjectMapper with ScalaObjectMapper] - val expected = OverwatchParams( AuditLogConfig( azureAuditLogEventhubConfig = Some(AzureAuditLogEventhubConfig(