Skip to content

Commit 60401b7

Browse files
committed
attack link to org and feedback save task - first pass
1 parent d5ddec5 commit 60401b7

File tree

3 files changed

+55
-4
lines changed

3 files changed

+55
-4
lines changed

src/baskerville/db/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ class Attack(Base, SerializableMixin):
212212

213213
id = Column(BigInteger, primary_key=True)
214214
id_misp = Column(BigInteger)
215+
uuid_org = Column(TEXT())
215216
date = Column(DateTime(timezone=True))
216217
start = Column(DateTime(timezone=True))
217218
stop = Column(DateTime(timezone=True))
@@ -224,6 +225,7 @@ class Attack(Base, SerializableMixin):
224225
sync_stop = Column(DateTime(timezone=True))
225226
processed = Column(Integer)
226227
notes = Column(TEXT)
228+
progress_report = Column(TEXT)
227229
analysis_notebook = Column(TEXT)
228230

229231
request_sets = relationship(
@@ -234,6 +236,10 @@ class Attack(Base, SerializableMixin):
234236
'Attribute', secondary='attribute_attack_link',
235237
back_populates='attacks'
236238
)
239+
organization = relationship(
240+
'Organization',
241+
primaryjoin='foreign(Attack.uuid_org) == remote(Organization.uuid)'
242+
)
237243

238244

239245
class Attribute(Base, SerializableMixin):

src/baskerville/models/pipeline_tasks/feedback_pipeline.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from baskerville.db.dashboard_models import Feedback
77
from baskerville.models.pipeline_tasks.tasks_base import Task
88
from baskerville.models.config import BaskervilleConfig
9-
from baskerville.models.pipeline_tasks.tasks import GetDataKafka, Save
9+
from baskerville.models.pipeline_tasks.tasks import GetDataKafka, SaveFeedback
1010

1111

1212
def set_up_feedback_pipeline(config: BaskervilleConfig):
@@ -21,10 +21,10 @@ def set_up_feedback_pipeline(config: BaskervilleConfig):
2121
GetDataKafka(
2222
config,
2323
steps=[
24-
Save(
24+
SaveFeedback(
2525
config,
2626
table_model=Feedback,
27-
not_common=()
27+
not_common=('feedback_context', 'progress_count')
2828
),
2929
]),
3030
]

src/baskerville/models/pipeline_tasks/tasks.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,10 @@
2020
from pyspark.streaming import StreamingContext
2121
from functools import reduce
2222
from pyspark.sql import DataFrame
23+
from sqlalchemy.exc import SQLAlchemyError
2324

2425
from baskerville.db import get_jdbc_url
25-
from baskerville.db.models import RequestSet, Model
26+
from baskerville.db.models import RequestSet, Model, Attack
2627
from baskerville.models.banjax_report_consumer import BanjaxReportConsumer
2728
from baskerville.models.ip_cache import IPCache
2829
from baskerville.models.metrics.registry import metrics_registry
@@ -933,6 +934,50 @@ def run(self):
933934
return self.df
934935

935936

937+
class SaveFeedback(Save):
938+
def upsert_attack(self):
939+
new_ = False
940+
success = False
941+
try:
942+
feedback_context = self.df.select('feedback_context').collect()
943+
attack = self.db_tools.session.query(Attack).filter_by(
944+
uuid=feedback_context.uuid
945+
).filter_by(
946+
start=feedback_context.start
947+
).filter_by(stop=feedback_context.stop).first()
948+
if not attack:
949+
attack = Attack()
950+
new_ = True
951+
attack.uuid_org = feedback_context.uuid
952+
attack.start = feedback_context.start
953+
attack.stop = feedback_context.stop
954+
attack.target = feedback_context.target
955+
attack.ip_count = feedback_context.ip_count
956+
attack.notes = feedback_context.notes
957+
attack.progress_report = feedback_context.progress_report
958+
if new_:
959+
self.db_tools.session.add(attack)
960+
self.db_tools.session.commit()
961+
success = True
962+
except SQLAlchemyError as sqle:
963+
traceback.print_exc()
964+
self.db_tools.session.rollback()
965+
success = False
966+
self.logger.error(str(sqle))
967+
# todo: what should the handling be?
968+
except Exception as e:
969+
traceback.print_exc()
970+
success = False
971+
self.logger.error(str(e))
972+
# todo: what should the handling be?
973+
return success
974+
975+
def prepare_to_save(self):
976+
success = self.upsert_attack()
977+
if success:
978+
super(SaveFeedback, self).prepare_to_save()
979+
980+
936981
class RefreshCache(CacheTask):
937982
def run(self):
938983
self.service_provider.refresh_cache(self.df)

0 commit comments

Comments
 (0)