Skip to content

Commit 351b5b8

Browse files
Merge pull request #22 from DataKitchen/release/2.3.0
Release: 2.3.0
2 parents 44e791f + 678a040 commit 351b5b8

File tree

30 files changed

+252
-102
lines changed

30 files changed

+252
-102
lines changed

cli/entry_points/init.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ def initialize_database(self) -> None:
168168
description="This key is utilized by the product demo",
169169
)
170170

171-
action_args = {"from_address": "[email protected]", "recipients": [], "template": "NotifyTemplate"}
171+
action_args = {"recipients": [], "template": "NotifyTemplate"}
172172
action = Action.create(
173173
name="Send Email", action_impl="SEND_EMAIL", company=company, action_args=action_args
174174
)

common/actions/send_email_action.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020

2121
class SendEmailAction(BaseAction):
22-
required_arguments = {"from_address", "recipients", "template"}
22+
required_arguments = {"recipients", "template"}
2323
requires_action_template = True
2424

2525
def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> ActionResult:
@@ -29,8 +29,8 @@ def _run(self, event: EVENT_TYPE, rule: Rule, journey_id: Optional[UUID]) -> Act
2929
return ActionResult(False, None, e)
3030
try:
3131
response = EmailService.send_email(
32-
self.arguments["smtp_config"],
33-
self.arguments["from_address"],
32+
self.arguments.get("smtp_config", {}),
33+
self.arguments.get("from_address"),
3434
self.arguments["recipients"],
3535
self.arguments["template"],
3636
context,

common/email/email_service.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,13 @@ class EmailService:
1717
@staticmethod
1818
def send_email(
1919
smtp_config: dict,
20-
from_address: str,
20+
from_address: str | None,
2121
recipients: list[str],
2222
template_name: str,
2323
template_context_vars: Mapping,
2424
) -> dict:
2525
try:
26+
from_address = from_address or settings.SMTP["from_address"]
2627
content, subject = HandlebarsEmailRenderer.render(template_name, template_context_vars)
2728
context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
2829
message = MIMEMultipart("alternative")

common/entities/event.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from enum import Enum, IntEnum
44

5-
from peewee import ForeignKeyField, Node, fn
5+
from peewee import ForeignKeyField
66
from playhouse.mysql_ext import JSONField
77

88
from common.peewee_extensions.fields import EnumIntField, EnumStrField, UTCTimestampField
@@ -44,10 +44,6 @@ class EventEntity(BaseEntity):
4444
instance_set = ForeignKeyField(InstanceSet, null=True, backref="events", on_delete="SET NULL")
4545
v2_payload = JSONField(null=False)
4646

47-
@classmethod
48-
def timestamp_coalesce(cls) -> Node:
49-
return fn.COALESCE(cls.timestamp, cls.created_timestamp)
50-
5147
@property
5248
def components(self) -> list[Component]:
5349
return [self.component] if self.component else []

common/entity_services/company_service.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def get_organizations_with_rules(company_id: str, rules: ListRules) -> Page[Orga
2626

2727
@staticmethod
2828
def get_users_with_rules(company_id: str, rules: ListRules) -> Page[User]:
29-
query = User.select().join(Company).where(Company.id == company_id)
29+
query = User.select().join(Company).where(User.primary_company == company_id)
3030
return Page[User].get_paginated_results(query, User.name, rules)
3131

3232
@staticmethod

common/entity_services/event_service.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -37,11 +37,12 @@ def get_events_with_rules(*, rules: ListRules, filters: ProjectEventFilters) ->
3737
filter_list: list[object] = [EventEntity.project << filters.project_ids]
3838

3939
if filters.instance_ids or filters.journey_ids:
40-
query = query.join(InstanceSet).join(InstancesInstanceSets).join(Instance).switch(EventEntity)
40+
instance_set_subquery = InstanceSet.select(InstanceSet.id).join(InstancesInstanceSets).join(Instance)
4141
if filters.journey_ids:
42-
filter_list.append(Instance.journey << filters.journey_ids)
42+
instance_set_subquery = instance_set_subquery.where(Instance.journey << filters.journey_ids)
4343
if filters.instance_ids:
44-
filter_list.append(Instance.id << filters.instance_ids)
44+
instance_set_subquery = instance_set_subquery.where(Instance.id << filters.instance_ids)
45+
query = query.where(EventEntity.instance_set.in_(instance_set_subquery))
4546
if filters.event_types:
4647
filter_list.append(EventEntity.type << filters.event_types)
4748
if filters.component_ids:
@@ -53,16 +54,12 @@ def get_events_with_rules(*, rules: ListRules, filters: ProjectEventFilters) ->
5354
if filters.task_ids:
5455
filter_list.append(EventEntity.task << filters.task_ids)
5556
if filters.date_range_start:
56-
filter_list.append(
57-
EventEntity.timestamp_coalesce() >= EventEntity.timestamp.db_value(filters.date_range_start)
58-
)
57+
filter_list.append(EventEntity.timestamp >= filters.date_range_start)
5958
if filters.date_range_end:
60-
filter_list.append(
61-
EventEntity.timestamp_coalesce() < EventEntity.timestamp.db_value(filters.date_range_end)
62-
)
59+
filter_list.append(EventEntity.timestamp < filters.date_range_end)
6360

6461
query = query.where(*filter_list)
65-
page = Page[EventEntity].get_paginated_results(query, EventEntity.timestamp_coalesce(), rules)
62+
page = Page[EventEntity].get_paginated_results(query, EventEntity.timestamp, rules)
6663

6764
# Using a single query to fetch the Instance and Journey data
6865
instance_set_ids = {e.instance_set_id for e in page.results}

common/entity_services/helpers/list_rules.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,11 @@
99

1010
from marshmallow import EXCLUDE, Schema
1111
from marshmallow.fields import Enum, Int
12-
from peewee import Field, Ordering, Select
12+
from peewee import JOIN, Field, Ordering, Select
1313
from werkzeug.datastructures import MultiDict
1414

15+
from common.entities import BaseEntity
16+
1517
DEFAULT_PAGE = 1
1618
DEFAULT_COUNT = 10
1719
T = TypeVar("T")
@@ -51,9 +53,24 @@ def __len__(self) -> int:
5153

5254
@classmethod
5355
def get_paginated_results(cls, query: Select, order_by: Field, list_rules: ListRules) -> Page[T]:
56+
model: BaseEntity = query.model
5457
ordering = list_rules.order_by_field(order_by)
55-
paginated_query = query.order_by(ordering).paginate(list_rules.page, list_rules.count)
56-
return cls(results=list(paginated_query), total=query.count())
58+
59+
paginated_subquery: Select = (
60+
model.select(model.id)
61+
.where(query._where)
62+
.order_by(ordering)
63+
.paginate(list_rules.page, list_rules.count)
64+
.alias("results")
65+
)
66+
67+
results_query = query.clone()
68+
results_query._where = None
69+
results_query = results_query.join(
70+
paginated_subquery, join_type=JOIN.INNER, on=(model.id == paginated_subquery.c.id)
71+
).order_by(ordering)
72+
73+
return cls(results=list(results_query), total=paginated_subquery.count(clear_limit=True))
5774

5875

5976
@dataclass

common/entity_services/journey_service.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,18 @@ def get_action_by_implementation(journey_id: UUID, action_impl: str) -> Optional
5151

5252
@staticmethod
5353
def get_components_with_rules(journey_id: str, rules: ListRules, filters: ComponentFilters) -> Page[Component]:
54-
query = JourneyDagEdge.journey_id == journey_id
54+
join_on = (JourneyDagEdge.journey_id == journey_id) & (
55+
(JourneyDagEdge.left == Component.id) | (JourneyDagEdge.right == Component.id)
56+
)
57+
query = Component.select(Component).join(JourneyDagEdge, on=join_on)
5558
if rules.search is not None:
56-
query &= Component.key ** f"%{rules.search}%"
59+
query = query.where(Component.key ** f"%{rules.search}%")
5760
if filters.component_types:
58-
query &= Component.type.in_(filters.component_types)
61+
query = query.where(Component.type.in_(filters.component_types))
5962
if filters.tools:
60-
query &= Component.tool.in_(filters.tools)
61-
join_on = (JourneyDagEdge.left == Component.id) | (JourneyDagEdge.right == Component.id)
62-
query = Component.select(Component).join(JourneyDagEdge, on=join_on).where(query).distinct()
63-
return Page[Component].get_paginated_results(query, Component.key, rules)
63+
query = query.where(Component.tool.in_(filters.tools))
64+
65+
return Page[Component].get_paginated_results(query.distinct(), Component.key, rules)
6466

6567
@staticmethod
6668
def get_upstream_nodes(journey: Journey, component_id: UUID) -> set:

common/entity_services/project_service.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,34 +55,36 @@ class ProjectService:
5555
def get_test_outcomes_with_rules(
5656
project: Project, rules: ListRules, filters: TestOutcomeItemFilters
5757
) -> Page[TestOutcome]:
58-
clauses = [TestOutcome.component.in_(project.components)]
58+
query = TestOutcome.select(TestOutcome).where(TestOutcome.component.in_(project.components))
5959
if rules.search is not None:
60-
clauses.append((TestOutcome.name ** f"%{rules.search}%") | (TestOutcome.description ** f"%{rules.search}%"))
60+
query = query.where(
61+
((TestOutcome.name ** f"%{rules.search}%") | (TestOutcome.description ** f"%{rules.search}%"))
62+
)
6163
if filters:
6264
if statuses := filters.statuses:
63-
clauses.append(TestOutcome.status.in_(statuses))
65+
query = query.where(TestOutcome.status.in_(statuses))
6466
if component_ids := filters.component_ids:
65-
clauses.append(TestOutcome.component << component_ids)
67+
query = query.where(TestOutcome.component << component_ids)
6668
if start_range_begin := filters.start_range_begin:
67-
clauses.append(TestOutcome.start_time >= start_range_begin)
69+
query = query.where(TestOutcome.start_time >= start_range_begin)
6870
if start_range_end := filters.start_range_end:
69-
clauses.append(TestOutcome.start_time < start_range_end)
71+
query = query.where(TestOutcome.start_time < start_range_end)
7072
if end_range_begin := filters.end_range_begin:
71-
clauses.append(TestOutcome.end_time >= end_range_begin)
73+
query = query.where(TestOutcome.end_time >= end_range_begin)
7274
if end_range_end := filters.end_range_end:
73-
clauses.append(TestOutcome.end_time < end_range_end)
75+
query = query.where(TestOutcome.end_time < end_range_end)
7476
if run_ids := filters.run_ids:
75-
clauses.append(TestOutcome.run.in_(run_ids))
77+
query = query.where(TestOutcome.run.in_(run_ids))
7678
if instance_ids := filters.instance_ids:
77-
clauses.append(InstancesInstanceSets.instance.in_(instance_ids))
79+
instance_set_subquery = (
80+
InstanceSet.select(InstanceSet.id)
81+
.join(InstancesInstanceSets)
82+
.where(InstancesInstanceSets.instance.in_(instance_ids))
83+
)
84+
query = query.where(TestOutcome.instance_set.in_(instance_set_subquery))
7885
if key := filters.key:
79-
clauses.append(TestOutcome.key == key)
86+
query = query.where(TestOutcome.key == key)
8087

81-
# If filtering on instance_ids we need to join the InstanceSet tables
82-
if filters and filters.instance_ids:
83-
query = TestOutcome.select(TestOutcome).join(InstanceSet).join(InstancesInstanceSets).where(*clauses)
84-
else:
85-
query = TestOutcome.select(TestOutcome).where(*clauses)
8688
return Page[TestOutcome].get_paginated_results(query, TestOutcome.start_time, rules)
8789

8890
@staticmethod

common/schemas/action_schemas.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class SMTPConfigSchema(Schema):
1616

1717

1818
class EmailActionArgsSchema(Schema):
19-
from_address = Str(validate=not_empty(max=255))
19+
from_address = Str(required=False, validate=not_empty(max=255))
2020
template = Str(validate=not_empty(max=255))
2121
recipients = List(Email(), required=True, validate=Length(max=50))
2222
smtp_config = Nested(SMTPConfigSchema(), required=False)

0 commit comments

Comments
 (0)