4 changes: 3 additions & 1 deletion build.sbt
@@ -12,7 +12,9 @@ libraryDependencies += "org.apache.spark" %% "spark-core" % sparkVersion % Provided
libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % Provided
libraryDependencies += "org.apache.spark" %% "spark-hive" % sparkVersion % Provided
libraryDependencies += "com.databricks" % "dbutils-api_2.12" % "0.0.5" % Provided
libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.11.595" % Provided
libraryDependencies += "com.amazonaws" % "aws-java-sdk-s3" % "1.11.595"
libraryDependencies += "com.amazonaws" % "aws-java-sdk-secretsmanager" % "1.11.595"
libraryDependencies += "com.fasterxml.jackson.core" % "jackson-databind" % "2.10.0"
libraryDependencies += "io.delta" % "delta-core_2.12" % "1.0.0" % Provided
libraryDependencies += "org.scalaj" %% "scalaj-http" % "2.4.2"
//libraryDependencies += "org.apache.hive" % "hive-metastore" % "2.3.9"
@@ -88,14 +88,25 @@ class ParamDeserializer() extends StdDeserializer[OverwatchParams](classOf[OverwatchParams])
override def deserialize(jp: JsonParser, ctxt: DeserializationContext): OverwatchParams = {
val masterNode = jp.getCodec.readTree[JsonNode](jp)

val token = try {
Some(TokenSecret(
masterNode.get("tokenSecret").get("scope").asText(),
masterNode.get("tokenSecret").get("key").asText()))
} catch {
case e: Throwable =>
println("No Token Secret Defined", e)
None
// TODO: consider introducing an enum that captures each secret type's inner structure and
// turning the block below into a function that processes that enum in a loop
val token = {

val databricksToken =
for {
scope <- getOptionString(masterNode,"tokenSecret.scope")
key <- getOptionString(masterNode, "tokenSecret.key")
} yield TokenSecret(scope, key)

val finalToken = if (databricksToken.isEmpty)
for {
secretId <- getOptionString(masterNode,"tokenSecret.secretId")
region <- getOptionString(masterNode,"tokenSecret.region")
apiToken = getOptionString(masterNode,"tokenSecret.tokenKey")
} yield AwsTokenSecret(secretId, region, apiToken)
else databricksToken

finalToken
}

val rawAuditPath = getOptionString(masterNode, "auditLogConfig.rawAuditPath")
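
For reference, the rewritten deserializer now accepts either of two tokenSecret shapes. A minimal sketch of how the two JSON forms map onto the case classes, using the same illustrative values as the ParamDeserializerTest cases further down and assuming a mapper configured with the ParamDeserializer module as in that test:

// Databricks secret: scope + key
val databricksJson = """{"tokenSecret":{"scope":"overwatch","key":"test-key"}}"""
// AWS Secrets Manager secret: secretId + region, with an optional tokenKey
val awsJson = """{"tokenSecret":{"secretId":"overwatch","region":"us-east-2"}}"""

// mapper.readValue[OverwatchParams](databricksJson).tokenSecret == Some(TokenSecret("overwatch", "test-key"))
// mapper.readValue[OverwatchParams](awsJson).tokenSecret == Some(AwsTokenSecret("overwatch", "us-east-2"))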
@@ -265,7 +265,7 @@ class Initializer(config: Config) extends SparkSessionWrapper {
config.setExternalizeOptimize(rawParams.externalizeOptimize)

val overwatchScope = rawParams.overwatchScope.getOrElse(Seq("all"))
val tokenSecret = rawParams.tokenSecret

// TODO -- PRIORITY -- If data target is null -- default table gets dbfs:/null
val dataTarget = rawParams.dataTarget.getOrElse(
DataTarget(Some("overwatch"), Some("dbfs:/user/hive/warehouse/overwatch.db"), None))
@@ -275,24 +275,30 @@
if (overwatchScope.head == "all") config.setOverwatchScope(config.orderedOverwatchScope)
else config.setOverwatchScope(validateScope(overwatchScope))

// validate token secret requirements
// TODO - Validate if token has access to necessary assets. Warn/Fail if not
if (tokenSecret.nonEmpty && !disableValidations && !config.isLocalTesting) {
if (tokenSecret.get.scope.isEmpty || tokenSecret.get.key.isEmpty) {
throw new BadConfigException(s"Secret AND Key must be provided together or neither of them. " +
s"Either supply both or neither.")
if (rawParams.tokenSecret.nonEmpty && !disableValidations && !config.isLocalTesting) {
rawParams.tokenSecret.map {
case databricksSecret: TokenSecret =>
// validate token secret requirements
// TODO - Validate if databricks token has access to necessary assets. Warn/Fail if not

if (databricksSecret.scope.isEmpty || databricksSecret.key.isEmpty) {
throw new BadConfigException(s"Secret AND Key must be provided together or neither of them. " +
s"Either supply both or neither.")
}
val scopeCheck = dbutils.secrets.listScopes().map(_.getName()).toArray.filter(_ == databricksSecret.scope)
if (scopeCheck.length == 0) throw new BadConfigException(s"Scope ${databricksSecret.scope} does not exist " +
s"in this workspace. Please provide a scope available and accessible to this account.")
val scopeName = scopeCheck.head

val keyCheck = dbutils.secrets.list(scopeName).toArray.filter(_.key == databricksSecret.key)
if (keyCheck.length == 0) throw new BadConfigException(s"Key ${databricksSecret.key} does not exist " +
s"within the provided scope: ${databricksSecret.scope}. Please provide a scope and key " +
s"available and accessible to this account.")

config.registerWorkspaceMeta(Some(TokenSecret(scopeName, keyCheck.head.key)))

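// AWS Secrets Manager-backed secrets are registered as-is: the scope/key checks above apply only to
// Databricks secrets, and the actual token value is resolved later through SecretTools in registerWorkspaceMeta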
case awsSecret: AwsTokenSecret => config.registerWorkspaceMeta(Some(awsSecret))
}
val scopeCheck = dbutils.secrets.listScopes().map(_.getName()).toArray.filter(_ == tokenSecret.get.scope)
if (scopeCheck.length == 0) throw new BadConfigException(s"Scope ${tokenSecret.get.scope} does not exist " +
s"in this workspace. Please provide a scope available and accessible to this account.")
val scopeName = scopeCheck.head

val keyCheck = dbutils.secrets.list(scopeName).toArray.filter(_.key == tokenSecret.get.key)
if (keyCheck.length == 0) throw new BadConfigException(s"Key ${tokenSecret.get.key} does not exist " +
s"within the provided scope: ${tokenSecret.get.scope}. Please provide a scope and key " +
s"available and accessible to this account.")

config.registerWorkspaceMeta(Some(TokenSecret(scopeName, keyCheck.head.key)))
} else config.registerWorkspaceMeta(None)

// Validate data Target
@@ -4,6 +4,7 @@ import com.databricks.labs.overwatch.pipeline.TransformFunctions._
import com.databricks.labs.overwatch.utils._
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

class Module(
val moduleId: Int,
@@ -172,10 +173,27 @@
initState
}

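/**
 * Rewrites the report's inputConfig struct so that tokenSecret is always persisted as the two-field
 * (scope, key) struct produced by SecretTools.getTargetTableStruct, keeping the pipeline_report schema
 * identical whether the run was configured with a Databricks secret or an AWS Secrets Manager secret.
 */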
private def normalizeToken(secretToken: TokenSecret, reportDf: DataFrame): DataFrame = {
val inputConfigCols = reportDf.select($"inputConfig.*")
.columns
.filter(_ != "tokenSecret")
.map(name => col("inputConfig." + name))

reportDf
.withColumn(
"inputConfig",
struct(inputConfigCols :+ struct(lit(secretToken.scope), lit(secretToken.key)).as("tokenSecret"): _*)
)
}

private def finalizeModule(report: ModuleStatusReport): Unit = {
pipeline.updateModuleState(report.simple)
if (!pipeline.readOnly) {
pipeline.database.write(Seq(report).toDF, pipeline.pipelineStateTarget, pipeline.pipelineSnapTime.asColumnTS)
val secretToken = SecretTools(report.inputConfig.tokenSecret.get).getTargetTableStruct
val targetDf = normalizeToken(secretToken, Seq(report).toDF)
pipeline.database.write(
targetDf,
pipeline.pipelineStateTarget, pipeline.pipelineSnapTime.asColumnTS)
}
}

@@ -0,0 +1,46 @@
package com.databricks.labs.overwatch.utils

import com.amazonaws.services.secretsmanager.AWSSecretsManagerClientBuilder
import com.amazonaws.services.secretsmanager.model.GetSecretValueRequest
import org.apache.log4j.{Level, Logger}
import org.json4s.DefaultFormats
import org.json4s.jackson.JsonMethods.parse

import java.util.Base64

object AwsSecrets {
private val logger: Logger = Logger.getLogger(this.getClass)

def readApiToken(secretId: String, region: String, apiTokenKey: Option[String]): String = {
apiTokenKey match {
case Some(key) => secretValueAsMap(secretId, region)
.getOrElse(key, throw new IllegalStateException("apiTokenKey param not found"))
.asInstanceOf[String]
case None => readRawSecretFromAws(secretId, region)
}
}

def secretValueAsMap(secretId: String, region: String = "us-east-2"): Map[String, Any] =
parseJsonToMap(readRawSecretFromAws(secretId, region))

def readRawSecretFromAws(secretId: String, region: String): String = {
logger.log(Level.INFO, s"Looking up secret $secretId in AWS Secrets Manager")

val secretsClient = AWSSecretsManagerClientBuilder
.standard()
.withRegion(region)
.build()
val request = new GetSecretValueRequest().withSecretId(secretId)
val secretValue = secretsClient.getSecretValue(request)

if (secretValue.getSecretString != null)
secretValue.getSecretString
else
new String(Base64.getDecoder.decode(secretValue.getSecretBinary).array)
}

def parseJsonToMap(jsonStr: String): Map[String, Any] = {
implicit val formats: DefaultFormats.type = org.json4s.DefaultFormats
parse(jsonStr).extract[Map[String, Any]]
}
}
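
A short usage sketch of the object above; the secret name, region, and JSON key are illustrative values, not taken from this change:

// Secret stored as a JSON document such as {"overwatch-token":"dapi..."}: pull a single key out of it
val tokenFromJson = AwsSecrets.readApiToken("overwatch", "us-east-2", Some("overwatch-token"))
// Secret stored as a plain string: the entire secret value is returned
val tokenFromRaw = AwsSecrets.readApiToken("overwatch", "us-east-2", None)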
12 changes: 3 additions & 9 deletions src/main/scala/com/databricks/labs/overwatch/utils/Config.scala
@@ -312,20 +312,14 @@ class Config() {
* as the job owner or notebook user (if called from notebook)
* @return
*/
private[overwatch] def registerWorkspaceMeta(tokenSecret: Option[TokenSecret]): this.type = {
private[overwatch] def registerWorkspaceMeta(tokenSecret: Option[TokenSecretContainer]): this.type = {
var rawToken = ""
var scope = ""
var key = ""
try {
// Token secrets not supported in local testing
if (tokenSecret.nonEmpty && !_isLocalTesting) { // not local testing and secret passed
_workspaceUrl = dbutils.notebook.getContext().apiUrl.get
_cloudProvider = if (_workspaceUrl.toLowerCase().contains("azure")) "azure" else "aws"
scope = tokenSecret.get.scope
key = tokenSecret.get.key
rawToken = dbutils.secrets.get(scope, key)
val authMessage = s"Valid Secret Identified: Executing with token located in secret, $scope : $key"
logger.log(Level.INFO, authMessage)
rawToken = SecretTools(tokenSecret.get).getApiToken
_tokenType = "Secret"
} else {
if (_isLocalTesting) { // Local testing env vars
Expand All @@ -344,7 +338,7 @@ class Config() {
}
}
if (!rawToken.matches("^(dapi|dkea)[a-zA-Z0-9-]*$")) throw new BadConfigException(s"contents of secret " +
s"at scope:key $scope:$key is not in a valid format. Please validate the contents of your secret. It must be " +
s"is not in a valid format. Please validate the contents of your secret. It must be " +
s"a user access token. It should start with 'dapi' ")
setApiEnv(ApiEnv(isLocalTesting, workspaceURL, rawToken, packageVersion))
this
@@ -0,0 +1,57 @@
package com.databricks.labs.overwatch.utils

import com.databricks.dbutils_v1.DBUtilsHolder.dbutils
import org.apache.log4j.{Level, Logger}

/**
* SecretTools handles common functionality for working with secrets:
* 1) Fetch the Databricks API token stored in the specified secret
* 2) Normalize the secret structure so it can be stored in the pipeline_report Delta table under the nested inputConfig.tokenSecret struct column
* Two secret types are currently supported: Databricks secrets and AWS Secrets Manager
*/
trait SecretTools[T <: TokenSecretContainer] {
def getApiToken : String
def getTargetTableStruct: TokenSecret
}

object SecretTools {
private val logger: Logger = Logger.getLogger(this.getClass)
type DatabricksTokenSecret = TokenSecret

private class DatabricksSecretTools(tokenSecret : DatabricksTokenSecret) extends SecretTools[DatabricksTokenSecret] {
override def getApiToken: String = {
val scope = tokenSecret.scope
val key = tokenSecret.key
val authMessage = s"Executing with Databricks token located in secret, scope=$scope : key=$key"
logger.log(Level.INFO, authMessage)
dbutils.secrets.get(scope, key)
}

override def getTargetTableStruct: TokenSecret = {
TokenSecret(tokenSecret.scope,tokenSecret.key)
}
}

private class AwsSecretTools(tokenSecret : AwsTokenSecret) extends SecretTools[AwsTokenSecret] {
override def getApiToken: String = {
val secretId = tokenSecret.secretId
val region = tokenSecret.region
val tokenKey = tokenSecret.tokenKey
val authMessage = s"Executing with AWS token located in secret, secretId=$secretId : region=$region : tokenKey=$tokenKey"
logger.log(Level.INFO, authMessage)
AwsSecrets.readApiToken(secretId, region, tokenSecret.tokenKey)
}

override def getTargetTableStruct: TokenSecret = {
TokenSecret(tokenSecret.region, tokenSecret.secretId)
}
}

def apply(secretSource: TokenSecretContainer): SecretTools[_] = {
secretSource match {
case x: AwsTokenSecret => new AwsSecretTools(x)
case y: DatabricksTokenSecret => new DatabricksSecretTools(y)
case _ => throw new IllegalArgumentException(s"${secretSource.toString} not implemented")
}
}
}
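
To illustrate how the factory above ties the two secret types together (illustrative values; these are the same calls made from Config.registerWorkspaceMeta and Module.finalizeModule):

// Databricks-backed secret: the token is read with dbutils.secrets.get(scope, key)
val dbApiToken = SecretTools(TokenSecret("overwatch-scope", "overwatch-key")).getApiToken

// AWS Secrets Manager-backed secret: the token is read via AwsSecrets.readApiToken
val awsTools = SecretTools(AwsTokenSecret("overwatch", "us-east-2", Some("overwatch-token")))
val awsApiToken = awsTools.getApiToken
// Flattened form written to pipeline_report: region becomes scope, secretId becomes key
val awsAsStruct = awsTools.getTargetTableStruct // TokenSecret("us-east-2", "overwatch")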
@@ -21,7 +21,9 @@ case class SparkDetail()

case class GangliaDetail()

case class TokenSecret(scope: String, key: String)
abstract class TokenSecretContainer extends Product with Serializable
case class TokenSecret(scope: String, key: String) extends TokenSecretContainer
case class AwsTokenSecret(secretId: String, region: String, tokenKey: Option[String] = None) extends TokenSecretContainer

case class DataTarget(databaseName: Option[String], databaseLocation: Option[String], etlDataPathPrefix: Option[String],
consumerDatabaseName: Option[String] = None, consumerDatabaseLocation: Option[String] = None)
Expand Down Expand Up @@ -75,7 +77,7 @@ case class AuditLogConfig(
case class IntelligentScaling(enabled: Boolean = false, minimumCores: Int = 4, maximumCores: Int = 512, coeff: Double = 1.0)

case class OverwatchParams(auditLogConfig: AuditLogConfig,
tokenSecret: Option[TokenSecret] = None,
tokenSecret: Option[TokenSecretContainer] = None,
dataTarget: Option[DataTarget] = None,
badRecordsPath: Option[String] = None,
overwatchScope: Option[Seq[String]] = None,
@@ -349,9 +351,15 @@ object OverwatchEncoders {
implicit def overwatchScope: org.apache.spark.sql.Encoder[OverwatchScope] =
org.apache.spark.sql.Encoders.kryo[OverwatchScope]

/*
implicit def tokenSecret: org.apache.spark.sql.Encoder[TokenSecret] =
org.apache.spark.sql.Encoders.kryo[TokenSecret]

implicit def tokenSecretContainer: org.apache.spark.sql.Encoder[TokenSecretContainer] =
org.apache.spark.sql.Encoders.kryo[TokenSecretContainer]

*/

implicit def dataTarget: org.apache.spark.sql.Encoder[DataTarget] =
org.apache.spark.sql.Encoders.kryo[DataTarget]

@@ -11,7 +11,7 @@ import com.fasterxml.jackson.module.scala.DefaultScalaModule
import io.delta.tables.DeltaTable
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.hadoop.conf._
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.log4j.{Level, Logger}
import org.apache.spark.sql.functions._
import org.apache.spark.util.SerializableConfiguration
@@ -12,6 +12,34 @@ class ParamDeserializerTest extends AnyFunSpec {

describe("ParamDeserializer") {

val paramModule: SimpleModule = new SimpleModule()
.addDeserializer(classOf[OverwatchParams], new ParamDeserializer)
val mapper: ObjectMapper with ScalaObjectMapper = (new ObjectMapper() with ScalaObjectMapper)
.registerModule(DefaultScalaModule)
.registerModule(paramModule)
.asInstanceOf[ObjectMapper with ScalaObjectMapper]

it("should decode passed token string as AWS secrets") {
val AWSsecrets = """
|{"tokenSecret":{"secretId":"overwatch","region":"us-east-2"}}
|""".stripMargin


val expected = Some(AwsTokenSecret("overwatch", "us-east-2"))
val parsed = mapper.readValue[OverwatchParams](AWSsecrets)
assertResult(expected)(parsed.tokenSecret)
}

it("should decode passed token string as Databricks secrets") {
val Databrickssecrets = """
|{"tokenSecret":{"scope":"overwatch", "key":"test-key"}}
|""".stripMargin

val expected = Some(TokenSecret("overwatch", "test-key"))
val parsed = mapper.readValue[OverwatchParams](Databrickssecrets)
assertResult(expected)(parsed.tokenSecret)
}

it("should decode incomplete parameters") {
val incomplete = """
|{"auditLogConfig":{"azureAuditLogEventhubConfig":{"connectionString":"test","eventHubName":"overwatch-evhub",
@@ -24,13 +52,6 @@ class ParamDeserializerTest extends AnyFunSpec {
|"workspace_name":"myTestWorkspace", "externalizeOptimizations":"false"}
|""".stripMargin

val paramModule: SimpleModule = new SimpleModule()
.addDeserializer(classOf[OverwatchParams], new ParamDeserializer)
val mapper: ObjectMapper with ScalaObjectMapper = (new ObjectMapper() with ScalaObjectMapper)
.registerModule(DefaultScalaModule)
.registerModule(paramModule)
.asInstanceOf[ObjectMapper with ScalaObjectMapper]

val expected = OverwatchParams(
AuditLogConfig(
azureAuditLogEventhubConfig = Some(AzureAuditLogEventhubConfig(