diff --git a/requirements.txt b/requirements.txt
index 148cacd7..bfb80a13 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,5 @@
-jinja2==2.10
+jinja2>=2.10
 # numpy==1.14.3
-PyYAML==3.12
 cryptography==2.2.2
 python-geoip==1.2
 python-geoip-geolite2==2015.303
@@ -33,3 +32,4 @@ kafka-python==2.0.1
 cachetools==4.1.1
 spark-testing-base==0.10.0
 passlib==1.7.4
+pyaml-env==1.1.1
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 4b1adc26..eb4897aa 100644
--- a/setup.py
+++ b/setup.py
@@ -42,6 +42,7 @@
         'baskerville.features',
         'baskerville.util',
         'baskerville.models',
+        'baskerville.models.pipeline_tasks',
         'baskerville.models.metrics',
         'baskerville.db',
         'baskerville.spark',
diff --git a/src/baskerville/db/__init__.py b/src/baskerville/db/__init__.py
index 6d33588e..02a6cc2b 100644
--- a/src/baskerville/db/__init__.py
+++ b/src/baskerville/db/__init__.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
+import traceback
 import uuid
 
 from baskerville.db.data_partitioning import get_temporal_partitions
@@ -158,13 +159,16 @@ def set_up_db(conf, create=True, partition=True):
                 isolation_level='AUTOCOMMIT',
                 **conf.get('db_conn_args', {})
         ).connect() as connection:
-            connection.execute(f'CREATE DATABASE {conf.get("name")} if not exists')
+            connection.execute(f'CREATE DATABASE {conf.get("name")}')
             connection.execute(
                 'CREATE CAST (VARCHAR AS JSON) '
                 'WITHOUT FUNCTION AS IMPLICIT'
             )
-    except ProgrammingError:
-        pass
+    except ProgrammingError as e:
+        if 'already exists' in str(e):
+            pass
+        else:
+            raise e
 
     engine = create_engine(
         get_db_connection_str(conf),
@@ -184,13 +188,12 @@
     if not database_exists(engine.url):
         create_database(engine.url)
 
-    Session = scoped_session(sessionmaker(bind=engine))
+    session = scoped_session(sessionmaker(bind=engine))()
     Base.metadata.create_all(bind=engine)
 
-    # session = Session()
-    if Session.query(Organization).count() == 0 and \
-            Session.query(User).count() == 0 and \
-            Session.query(UserCategory).count() == 0:
+    if session.query(Organization).count() == 0 and \
+            session.query(User).count() == 0 and \
+            session.query(UserCategory).count() == 0:
         try:
             organization = Organization()
             organization.uuid = 'test'
@@ -204,15 +207,13 @@
             user.category = category
             user.email = 'email'
             user.name = 'default_user'
-            Session.add(organization)
-            Session.add(user)
-            Session.commit()
+            session.add(organization)
+            session.add(user)
+            session.commit()
         except Exception as err:
-            Session.rollback()
+            session.rollback()
             raise err
-
-    # create data partition
 
     maintenance_conf = conf.get('maintenance')
     if conf.get('type') == 'postgres' \
@@ -220,7 +221,7 @@
             and maintenance_conf['data_partition'] \
             and create \
             and partition:
-        Session.execute(text(get_temporal_partitions(maintenance_conf)))
+        session.execute(text(get_temporal_partitions(maintenance_conf)))
         print('Partitioning done...')
 
-    return Session, engine
+    return session, engine
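Note on the set_up_db change above: Postgres has no `CREATE DATABASE IF NOT EXISTS`, so the old `... if not exists` suffix was invalid SQL; the patch instead attempts the create and treats only an "already exists" `ProgrammingError` as benign. A minimal standalone sketch of that pattern, with a hypothetical DSN and helper name, assuming SQLAlchemy 1.x where `connection.execute` accepts a raw SQL string:

```python
from sqlalchemy import create_engine
from sqlalchemy.exc import ProgrammingError

def create_db_if_missing(admin_dsn, name):
    """Idempotent CREATE DATABASE: swallow only 'already exists' errors."""
    engine = create_engine(admin_dsn, isolation_level='AUTOCOMMIT')
    try:
        with engine.connect() as connection:
            # Postgres lacks CREATE DATABASE IF NOT EXISTS, so just try it
            # and let the except clause decide whether the failure matters.
            connection.execute(f'CREATE DATABASE {name}')
    except ProgrammingError as e:
        if 'already exists' not in str(e):
            raise
```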
diff --git a/src/baskerville/db/dashboard_models.py b/src/baskerville/db/dashboard_models.py
index e01eee2e..8a88f520 100644
--- a/src/baskerville/db/dashboard_models.py
+++ b/src/baskerville/db/dashboard_models.py
@@ -88,6 +88,9 @@ class FeedbackContext(Base, SerializableMixin):
     __tablename__ = 'feedback_contexts'
     id = Column(BigInteger, primary_key=True, autoincrement=True, unique=True)
     uuid_organization = Column(String(300), nullable=False)
+    # id_user should be used only on the user module side,
+    # it is not communicated back and forth with the clearinghouse
+    id_user = Column(BigInteger(), ForeignKey('users.id'), nullable=True)
     reason = Column(Enum(FeedbackContextTypeEnum))
     reason_descr = Column(TEXT())
     start = Column(DateTime(timezone=True))
diff --git a/src/baskerville/main.py b/src/baskerville/main.py
index ae2bc61b..d37d6514 100644
--- a/src/baskerville/main.py
+++ b/src/baskerville/main.py
@@ -18,16 +18,17 @@
 from dateutil.tz import tzutc
 from prometheus_client import start_http_server
+from pyaml_env import parse_config
 
 from baskerville import src_dir
 from baskerville.db import set_up_db
 from baskerville.db.models import Model
+from baskerville.models.anomaly_model_sklearn import AnomalyModelSklearn
 from baskerville.models.config import DatabaseConfig
 from baskerville.models.engine import BaskervilleAnalyticsEngine
 from baskerville.simulation.real_timeish_simulation import simulation
 from baskerville.util.git_helpers import git_clone
-from baskerville.util.helpers import get_logger, parse_config, \
-    get_default_data_path
+from baskerville.util.helpers import get_logger, get_default_data_path
 
 PROCESS_LIST = []
@@ -59,14 +60,14 @@ def run_simulation(conf, spark=None):
     simulation_process = Process(
         name='SimulationThread',
         target=simulation,
-        args=[
+        args=(
             engine_conf.simulation.log_file,
             timedelta(seconds=engine_conf.time_bucket),
-        ],
+        ),
         kwargs={
             'topic_name': kafka_conf['consume_topic'],
             'sleep': engine_conf.simulation.sleep,
-            'kafka_url': kafka_conf.connection['bootstrap_servers']['bootstrap_servers'],
+            'kafka_url': kafka_conf.connection['bootstrap_servers'],
             'zookeeper_url': kafka_conf.zookeeper,
             'verbose': engine_conf.simulation.verbose,
             'spark': spark,
@@ -76,6 +77,33 @@
     print('Set up Simulation...')
 
 
+def add_model_to_database(database_config):
+    """
+    Load the test model and save it in the database
+    :param dict[str, T] database_config:
+    :return:
+    """
+    global logger
+    path = os.path.join(get_default_data_path(), 'samples', 'models', 'AnomalyModel')
+    logger.info(f'Loading test model from: {path}')
+    model = AnomalyModelSklearn()
+    model.load(path=path)
+
+    db_cfg = DatabaseConfig(database_config).validate()
+    session, _ = set_up_db(db_cfg.__dict__, partition=False)
+
+    db_model = Model()
+    db_model.algorithm = 'baskerville.models.anomaly_model_sklearn.AnomalyModelSklearn'
+    db_model.created_at = datetime.now(tz=tzutc())
+    db_model.parameters = json.dumps(model.get_params())
+    db_model.classifier = bytearray(path.encode('utf8'))
+
+    # save to db
+    session.add(db_model)
+    session.commit()
+    session.close()
+
+
 def main():
     """
     Baskerville commandline arguments
@@ -155,7 +183,12 @@
             raise RuntimeError('Cannot start exporter without metrics config')
         port = baskerville_engine.config.engine.metrics.port
         start_http_server(port)
-        logger.info(f'Starting Baskerville Exporter at http://localhost:{port}')
+        logger.info(f'Starting Baskerville Exporter at '
+                    f'http://localhost:{port}')
+
+    # populate with test data if specified
+    if args.test_model:
+        add_model_to_database(conf['database'])
 
     for p in PROCESS_LIST[::-1]:
         print(f"{p.name} starting...")
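The main.py hunk swaps the in-repo `parse_config` for `pyaml_env.parse_config` (pinned as `pyaml-env==1.1.1` in requirements.txt), which resolves `!ENV ${VAR}` tags the same way the removed helper did (see the helpers.py hunk further down). A sketch of the expected behaviour; the `data=` keyword mirrors the old helper's signature and is assumed to carry over:

```python
import os
from pyaml_env import parse_config

os.environ['DB_HOST'] = 'localhost'

# Values tagged with !ENV are resolved against the environment at load time.
config = parse_config(data='''
database:
  host: !ENV ${DB_HOST}
  name: baskerville
''')
assert config['database']['host'] == 'localhost'
```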
diff --git a/src/baskerville/models/banjax_report_consumer.py b/src/baskerville/models/banjax_report_consumer.py
index ccc87480..74a15669 100644
--- a/src/baskerville/models/banjax_report_consumer.py
+++ b/src/baskerville/models/banjax_report_consumer.py
@@ -10,7 +10,7 @@
 from baskerville.db import set_up_db
 from baskerville.models.config import KafkaConfig
 from baskerville.models.ip_cache import IPCache
-from baskerville.util.helpers import parse_config
+from pyaml_env import parse_config
 import argparse
 import os
 from baskerville import src_dir
diff --git a/src/baskerville/models/config.py b/src/baskerville/models/config.py
index 2b85283d..d7f0cd6b 100644
--- a/src/baskerville/models/config.py
+++ b/src/baskerville/models/config.py
@@ -875,7 +875,6 @@ def validate(self):
             self.shuffle_partitions = int(self.shuffle_partitions)
         except ValueError:
             self.add_error(ConfigError(
-                'Spark shuffle_partitions should be an integer',
                 'Spark shuffle_partitions should be an integer',
                 ['shuffle_partitions'],
             ))
diff --git a/src/baskerville/models/pipeline_tasks/feedback_pipeline.py b/src/baskerville/models/pipeline_tasks/feedback_pipeline.py
index 41accd6c..aa496c98 100644
--- a/src/baskerville/models/pipeline_tasks/feedback_pipeline.py
+++ b/src/baskerville/models/pipeline_tasks/feedback_pipeline.py
@@ -32,9 +32,11 @@ def set_up_feedback_pipeline(config: BaskervilleConfig):
         ),
         SendToKafka(
             config,
-            columns=('uuid_organization', 'id_context', 'success'),
-            topic=None,
-            client_topic='feedback',
+            ('uuid_organization', 'id_context', 'success'),
+            'feedback',
+            cmd='feedback_center',
+            cc_to_client=True,
+            client_only=True
         )
     ]),
 ]
diff --git a/src/baskerville/models/pipeline_tasks/service_provider.py b/src/baskerville/models/pipeline_tasks/service_provider.py
index 54435988..3e3212bf 100644
--- a/src/baskerville/models/pipeline_tasks/service_provider.py
+++ b/src/baskerville/models/pipeline_tasks/service_provider.py
@@ -35,6 +35,8 @@ def __init__(self, config: BaskervilleConfig):
         self.config = config
         self.start_time = datetime.datetime.utcnow()
         self.runtime = None
+        self.user = None
+        self.org = None
         self.request_set_cache = None
         self.spark = None
         self.tools = None
@@ -90,22 +92,22 @@ def refresh_model(self):
 
     def create_runtime(self):
         from baskerville.db.dashboard_models import User, Organization
-        org = self.tools.session.query(Organization).filter_by(
+        self.org = self.tools.session.query(Organization).filter_by(
             uuid=self.config.user_details.organization_uuid
         ).first()
-        if not org:
+        if not self.org:
             raise ValueError(f'No such organization.')
-        user = self.tools.session.query(User).filter_by(
+        self.user = self.tools.session.query(User).filter_by(
             username=self.config.user_details.username).filter_by(
-            id_organization=org.id
+            id_organization=self.org.id
         ).first()
-        if not user:
+        if not self.user:
             raise ValueError(f'No such user.')
         self.runtime = self.tools.create_runtime(
             start=self.start_time,
             conf=self.config.engine,
-            id_user=user.id
+            id_user=self.user.id
         )
         self.logger.info(f'Created runtime {self.runtime.id}')
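The feedback pipeline above now passes the columns and the topic positionally and adds `cmd`, `cc_to_client`, and `client_only`; the matching `SendToKafka.__init__` change is in the tasks.py hunk below. A keyword-argument reading of that call, using a stand-in stub for the real class so the snippet runs on its own:

```python
# Stub with the same signature as the new SendToKafka.__init__ below;
# the real class lives in baskerville.models.pipeline_tasks.tasks.
class SendToKafka:
    def __init__(self, config, columns, topic, cmd='prediction_center',
                 cc_to_client=False, client_topic=None, client_only=True,
                 send_to_clearing_house=False, steps=()):
        self.columns, self.topic, self.cmd = columns, topic, cmd

# Keyword equivalent of the positional call in feedback_pipeline.py:
task = SendToKafka(
    config={},  # placeholder for the BaskervilleConfig instance
    columns=('uuid_organization', 'id_context', 'success'),
    topic='feedback',  # previously topic=None with client_topic='feedback'
    cmd='feedback_center',
    cc_to_client=True,
    client_only=True,
)
```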
diff --git a/src/baskerville/models/pipeline_tasks/tasks.py b/src/baskerville/models/pipeline_tasks/tasks.py
index 568553f1..0c7dba51 100644
--- a/src/baskerville/models/pipeline_tasks/tasks.py
+++ b/src/baskerville/models/pipeline_tasks/tasks.py
@@ -43,11 +43,11 @@
     get_feedback_context_schema, get_features_schema
 from kafka import KafkaProducer
 from dateutil.tz import tzutc
+from pyaml_env import parse_config
 
 # broadcasts
 from baskerville.util.enums import LabelEnum
-from baskerville.util.helpers import instantiate_from_str, get_model_path, \
-    parse_config
+from baskerville.util.helpers import instantiate_from_str, get_model_path
 from baskerville.util.helpers import instantiate_from_str, get_model_path
 from baskerville.util.kafka_helpers import send_to_kafka, read_from_kafka_from_the_beginning
 from baskerville.util.mail_sender import MailSender
@@ -1430,16 +1430,20 @@ def __init__(
             self,
             config: BaskervilleConfig,
             columns,
-            topic=None,
+            topic,
+            cmd='prediction_center',
             cc_to_client=False,
             client_topic=None,
+            client_only=True,
             send_to_clearing_house=False,
             steps: list = (),
     ):
         super().__init__(config, steps)
         self.columns = columns
         self.topic = topic
+        self.cmd = cmd
         self.cc_to_client = cc_to_client
+        self.client_only = client_only
         self.client_topic = client_topic
 
         if send_to_clearing_house:
@@ -1771,7 +1775,8 @@ def classify_anomalies(self):
         self.df = self.df.withColumn(
             'prediction',
             F.when(F.col('score') > self.config.engine.anomaly_threshold,
-                   F.lit(1.0)).otherwise(F.lit(0.))
+                   F.lit(LabelEnum.malicious.value)
+                   ).otherwise(F.lit(LabelEnum.benign.value))
         )
 
     def detect_low_rate_attack(self):
@@ -1834,7 +1839,7 @@ class Challenge(Task):
     def __init__(
             self, config: BaskervilleConfig, steps=(),
-            attack_cols=('prediction', 'low_rate_attack')
+            attack_cols=('prediction', 'attack_prediction', 'low_rate_attack')
     ):
         super().__init__(config, steps)
         self.attack_cols = attack_cols
diff --git a/src/baskerville/spark/__init__.py b/src/baskerville/spark/__init__.py
index a6a08f8d..2a191f25 100644
--- a/src/baskerville/spark/__init__.py
+++ b/src/baskerville/spark/__init__.py
@@ -30,6 +30,14 @@ def get_or_create_spark_session(spark_conf):
     if spark_conf.redis_password:
         conf.set('spark.redis.auth', spark_conf.redis_password)
 
+    if spark_conf.spark_executor_instances:
+        conf.set('spark.executor.instances',
+                 spark_conf.spark_executor_instances)
+        # conf.set('spark.streaming.dynamicAllocation.minExecutors', spark_conf.spark_executor_instances)
+    if spark_conf.spark_executor_cores:
+        conf.set('spark.executor.cores', spark_conf.spark_executor_cores)
+    if spark_conf.spark_executor_memory:
+        conf.set('spark.executor.memory', spark_conf.spark_executor_memory)
     # todo: https://stackoverflow.com/questions/
     # 49672181/spark-streaming-dynamic-allocation-do-not-remove-executors-in-middle-of-window
     # https://medium.com/@pmatpadi/spark-streaming-dynamic-scaling-and-backpressure-in-action-6ebdbc782a69
@@ -119,7 +127,8 @@ def get_or_create_spark_session(spark_conf):
     if spark_conf.spark_executor_cores:
         conf.set('spark.executor.cores', spark_conf.spark_executor_cores)
     if spark_conf.spark_executor_instances:
-        conf.set('spark.executor.instances', spark_conf.spark_executor_instances)
+        conf.set('spark.executor.instances',
+                 spark_conf.spark_executor_instances)
     if spark_conf.spark_executor_memory:
         conf.set('spark.executor.memory', spark_conf.spark_executor_memory)
     if spark_conf.serializer:
@@ -188,8 +197,8 @@
     if spark_conf.spark_kubernetes_executor_memoryOverhead:
         conf.set('spark.kubernetes.executor.memoryOverhead',
                  spark_conf.spark_kubernetes_executor_memoryOverhead)
-    conf.set('spark.kubernetes.driver.pod.name', os.environ['MY_POD_NAME'])
-    conf.set('spark.driver.host', os.environ['MY_POD_IP'])
+    conf.set('spark.kubernetes.driver.pod.name', os.environ.get('MY_POD_NAME'))
+    conf.set('spark.driver.host', os.environ.get('MY_POD_IP', 'localhost'))
    conf.set('spark.driver.port', 20020)
 
     spark = SparkSession.builder \
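classify_anomalies() above replaces the bare literals 1.0/0.0 with `LabelEnum` members. A runnable sketch of the same thresholding, with a stand-in enum (Baskerville's real `LabelEnum` lives in baskerville.util.enums) and a hypothetical threshold value:

```python
from enum import Enum
from pyspark.sql import SparkSession, functions as F

class LabelEnum(Enum):  # stand-in for baskerville.util.enums.LabelEnum
    benign = 0.0
    malicious = 1.0

spark = SparkSession.builder.master('local[1]').getOrCreate()
df = spark.createDataFrame([(0.2,), (0.9,)], ['score'])
anomaly_threshold = 0.5  # stands in for config.engine.anomaly_threshold

# Same shape as the classify_anomalies() hunk: scores above the threshold
# get the malicious label, everything else the benign label.
df = df.withColumn(
    'prediction',
    F.when(F.col('score') > anomaly_threshold,
           F.lit(LabelEnum.malicious.value)
           ).otherwise(F.lit(LabelEnum.benign.value))
)
df.show()
```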
diff --git a/src/baskerville/spark/udfs.py b/src/baskerville/spark/udfs.py
index 8bca7cbe..24a90620 100644
--- a/src/baskerville/spark/udfs.py
+++ b/src/baskerville/spark/udfs.py
@@ -14,7 +14,6 @@
 from pyspark.ml.linalg import Vectors, VectorUDT
 from pyspark.sql import functions as F
 from pyspark.sql import types as T
-from tzwhere import tzwhere
 
 import numpy as np
 
@@ -52,6 +51,7 @@ def compute_geotime(lat, lon, t, feature_default):
         # todo: how do latitude/longitude appear in raw ats record?
         feature_value = feature_default
     else:
+        from tzwhere import tzwhere
         tz = tzwhere.tzwhere()
         timezone_str = tz.tzNameAt(lat, lon)
         t = t.astimezone(pytz.timezone(timezone_str))
@@ -321,3 +321,4 @@
 udf_bulk_update_request_sets = F.udf(bulk_update_request_sets, T.BooleanType())
 udf_to_dense_vector = F.udf(lambda l: Vectors.dense(l), VectorUDT())
 udf_add_to_dense_vector = F.udf(lambda features, arr: Vectors.dense(np.append(features, [v for v in arr])), VectorUDT())
+udf_send_to_kafka = F.udf(send_to_kafka, T.BooleanType())
diff --git a/src/baskerville/util/baskerville_tools.py b/src/baskerville/util/baskerville_tools.py
index 95d997b2..88d348b5 100644
--- a/src/baskerville/util/baskerville_tools.py
+++ b/src/baskerville/util/baskerville_tools.py
@@ -28,7 +28,7 @@ def connect_to_db(self):
         bothound_tools instance and will be used to save data back
         to the db
         """
-        self.session, self.engine = set_up_db(self.conf.__dict__)
+        self.session, self.engine = set_up_db(self.conf.__dict__, create=True)
 
     def create_runtime(
             self,
diff --git a/src/baskerville/util/helpers.py b/src/baskerville/util/helpers.py
index ea40d93d..845c5c5e 100644
--- a/src/baskerville/util/helpers.py
+++ b/src/baskerville/util/helpers.py
@@ -20,61 +20,6 @@
 FOLDER_CACHE = 'cache'
 
 
-def parse_config(path=None, data=None, tag='!ENV'):
-    """
-    Load a yaml configuration file and resolve any environment variables
-    The environment variables must have !ENV before them and be in this format
-    to be parsed: ${VAR_NAME}.
-    E.g.:
-
-    database:
-        host: !ENV ${HOST}
-        port: !ENV ${PORT}
-    engine:
-        log_path: !ENV '/var/${LOG_PATH}'
-        something_else: !ENV '${AWESOME_ENV_VAR}/var/${WOOHOO_VAR}'
-
-    :param path: the path to the yaml file
-    :param data: the yaml data itself as a stream
-    :param tag: the tag to look for
-    :return: the dict configuration
-    :rtype: dict[str, T]
-    """
-    # pattern for global vars
-    pattern = re.compile(r'.*?\${(\w+)}.*?')
-    loader = yaml.SafeLoader
-    loader.add_implicit_resolver(tag, pattern, None)
-
-    def constructor_env_variables(loader, node):
-        """
-        Extracts the environment variable from the node's value
-        :param yaml.Loader loader: the yaml loader
-        :param node: the current node in the yaml
-        :return: the parsed string that contains the value of the environment
-        variable
-        """
-        value = loader.construct_scalar(node)
-        match = pattern.findall(value)
-        if match:
-            full_value = value
-            for g in match:
-                full_value = full_value.replace(
-                    f'${{{g}}}', os.environ.get(g, g)
-                )
-            return full_value
-        return value
-
-    loader.add_constructor(tag, constructor_env_variables)
-
-    if path:
-        with open(path) as conf_data:
-            return yaml.load(conf_data, Loader=loader)
-    elif data:
-        return yaml.load(data, Loader=loader)
-    else:
-        raise ValueError('Either a path or data should be defined as input')
-
-
 def get_logger(
         name, logging_level=logging.DEBUG, output_file='baskerville.log'
 ):
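The udfs.py hunk above moves `from tzwhere import tzwhere` from module scope into the UDF body: `tzwhere` loads large timezone polygon data at construction time, so deferring the import keeps `import baskerville.spark.udfs` cheap on the driver and makes executors pay the cost only when the geotime branch actually runs. A generic sketch of the pattern; the function name and the returned local-hour feature are illustrative, not Baskerville's actual feature value:

```python
def compute_geotime_sketch(lat, lon, t, feature_default):
    """Return the local hour at (lat, lon), or a default when unknown."""
    if lat is None or lon is None:
        return feature_default
    # Deferred imports: only workers that hit this branch load tzwhere.
    import pytz
    from tzwhere import tzwhere
    tz = tzwhere.tzwhere()
    timezone_str = tz.tzNameAt(lat, lon)
    if timezone_str is None:
        return feature_default
    return t.astimezone(pytz.timezone(timezone_str)).hour
```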
diff --git a/src/baskerville/util/model_evaluation/evaluation.py b/src/baskerville/util/model_evaluation/evaluation.py
index 44e5d39c..694f3bc1 100644
--- a/src/baskerville/util/model_evaluation/evaluation.py
+++ b/src/baskerville/util/model_evaluation/evaluation.py
@@ -7,6 +7,7 @@
 from pyspark.ml.tuning import ParamGridBuilder, CrossValidator, \
     CrossValidatorModel
 from pyspark.sql import functions as F
+from pyaml_env import parse_config
 
 from baskerville.db.models import RequestSet
 from baskerville.models.anomaly_model import AnomalyModel
@@ -14,7 +15,6 @@
 from baskerville.spark import get_or_create_spark_session
 from baskerville.spark.helpers import load_df_from_table
 from baskerville.util.enums import LabelEnum
-from baskerville.util.helpers import parse_config
 
 # # https://stackoverflow.com/questions/52847408/pyspark-extract-roc-curve
diff --git a/src/baskerville/util/model_evaluation/evaluation_from_notebook.py b/src/baskerville/util/model_evaluation/evaluation_from_notebook.py
index 7b303f94..0d16c4b2 100644
--- a/src/baskerville/util/model_evaluation/evaluation_from_notebook.py
+++ b/src/baskerville/util/model_evaluation/evaluation_from_notebook.py
@@ -7,10 +7,10 @@
 from pyspark.sql import SparkSession
 from pyspark.sql.types import *
 from pyspark.mllib.evaluation import BinaryClassificationMetrics
-from baskerville.db import get_jdbc_url
+from pyaml_env import parse_config
+from baskerville.db import get_jdbc_url
 from baskerville.db.models import Attack
-from baskerville.util.helpers import parse_config
 from baskerville.models.config import BaskervilleConfig
 from baskerville.util.baskerville_tools import BaskervilleDBTools
 from baskerville.models.anomaly_model import AnomalyModel
diff --git a/tests/unit/baskerville_tests/utils_tests/test_helpers.py b/tests/unit/baskerville_tests/utils_tests/test_helpers.py
index ce8c1233..4fc15a1a 100644
--- a/tests/unit/baskerville_tests/utils_tests/test_helpers.py
+++ b/tests/unit/baskerville_tests/utils_tests/test_helpers.py
@@ -10,7 +10,8 @@
 from datetime import datetime
 
 import pyspark
-from baskerville.util.helpers import parse_config, periods_overlap, \
+from pyaml_env import parse_config
+from baskerville.util.helpers import periods_overlap, \
     instantiate_from_str, class_from_str
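With `parse_config` gone from baskerville.util.helpers, the tests now import it from pyaml_env instead. A pytest-style sketch (hypothetical test, not part of this diff) of what the migrated suite can still assert, assuming pyaml_env keeps the old `!ENV ${VAR}` semantics:

```python
from pyaml_env import parse_config

def test_parse_config_resolves_env(monkeypatch):
    # !ENV-tagged values should be substituted from the environment,
    # exactly as the removed in-repo helper did.
    monkeypatch.setenv('HOST', 'db.example.org')
    config = parse_config(data='host: !ENV ${HOST}')
    assert config['host'] == 'db.example.org'
```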