Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def create_default_project_for_workspaces(session: Connection):
for workspace in workspaces:
# Create a new default project for each workspace
get_or_create_workspace_default_project(
session=session,
workspace=workspace, # type: ignore
session=session, workspace=workspace # type: ignore
)

# Commit the changes for the current batch
Expand Down
10 changes: 4 additions & 6 deletions api/ee/databases/postgres/migrations/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@ async def get_current_migration_head_from_db(engine: AsyncEngine):

async with engine.connect() as connection:
try:
result = await connection.execute(
text("SELECT version_num FROM alembic_version")
) # type: ignore
result = await connection.execute(text("SELECT version_num FROM alembic_version")) # type: ignore
except (asyncpg.exceptions.UndefinedTableError, ProgrammingError):
# Note: If the alembic_version table does not exist, it will result in raising an UndefinedTableError exception.
# We need to suppress the error and return a list with the alembic_version table name to inform the user that there is a pending migration \
Expand All @@ -85,9 +83,9 @@ async def get_current_migration_head_from_db(engine: AsyncEngine):
return "alembic_version"

migration_heads = [row[0] for row in result.fetchall()]
assert len(migration_heads) == 1, (
"There can only be one migration head stored in the database."
)
assert (
len(migration_heads) == 1
), "There can only be one migration head stored in the database."
return migration_heads[0]


Expand Down
10 changes: 4 additions & 6 deletions api/ee/databases/postgres/migrations/tracing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,7 @@ async def get_current_migration_head_from_db(engine: AsyncEngine):

async with engine.connect() as connection:
try:
result = await connection.execute(
text("SELECT version_num FROM alembic_version")
) # type: ignore
result = await connection.execute(text("SELECT version_num FROM alembic_version")) # type: ignore
except (asyncpg.exceptions.UndefinedTableError, ProgrammingError):
# Note: If the alembic_version table does not exist, it will result in raising an UndefinedTableError exception.
# We need to suppress the error and return a list with the alembic_version table name to inform the user that there is a pending migration \
Expand All @@ -78,9 +76,9 @@ async def get_current_migration_head_from_db(engine: AsyncEngine):
return "alembic_version"

migration_heads = [row[0] for row in result.fetchall()]
assert len(migration_heads) == 1, (
"There can only be one migration head stored in the database."
)
assert (
len(migration_heads) == 1
), "There can only be one migration head stored in the database."
return migration_heads[0]


Expand Down
4 changes: 2 additions & 2 deletions api/ee/docker/Dockerfile.dev
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ RUN cat -A /etc/cron.d/meters-cron
RUN chmod +x /meters.sh \
&& chmod 0644 /etc/cron.d/meters-cron

COPY ./oss/src/crons/queries.sh /queries.sh
COPY ./oss/src/crons/queries.txt /etc/cron.d/queries-cron
COPY ./ee/src/crons/queries.sh /queries.sh
COPY ./ee/src/crons/queries.txt /etc/cron.d/queries-cron
RUN sed -i -e '$a\' /etc/cron.d/queries-cron
RUN cat -A /etc/cron.d/queries-cron

Expand Down
4 changes: 2 additions & 2 deletions api/ee/docker/Dockerfile.gh
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ RUN cat -A /etc/cron.d/meters-cron
RUN chmod +x /meters.sh \
&& chmod 0644 /etc/cron.d/meters-cron

COPY ./oss/src/crons/queries.sh /queries.sh
COPY ./oss/src/crons/queries.txt /etc/cron.d/queries-cron
COPY ./ee/src/crons/queries.sh /queries.sh
COPY ./ee/src/crons/queries.txt /etc/cron.d/queries-cron
RUN sed -i -e '$a\' /etc/cron.d/queries-cron
RUN cat -A /etc/cron.d/queries-cron

Expand Down
93 changes: 43 additions & 50 deletions api/ee/src/apis/fastapi/billing/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

stripe.api_key = environ.get("STRIPE_API_KEY")

MAC_ADDRESS = ":".join(f"{(getnode() >> ele) & 0xFF:02x}" for ele in range(40, -1, -8))
MAC_ADDRESS = ":".join(f"{(getnode() >> ele) & 0xff:02x}" for ele in range(40, -1, -8))
STRIPE_WEBHOOK_SECRET = environ.get("STRIPE_WEBHOOK_SECRET")
STRIPE_TARGET = environ.get("STRIPE_TARGET") or MAC_ADDRESS
AGENTA_PRICING = loads(environ.get("AGENTA_PRICING") or "{}")
Expand Down Expand Up @@ -824,13 +824,12 @@ async def create_portal_user_route(
self,
request: Request,
):
if is_ee():
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.EDIT_BILLING,
):
return FORBIDDEN_RESPONSE
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.EDIT_BILLING,
):
return FORBIDDEN_RESPONSE

return await self.create_portal(
organization_id=request.state.organization_id,
Expand All @@ -852,13 +851,12 @@ async def create_checkout_user_route(
plan: Plan = Query(...),
success_url: str = Query(...), # find a way to make this optional or moot
):
if is_ee():
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.EDIT_BILLING,
):
return FORBIDDEN_RESPONSE
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.EDIT_BILLING,
):
return FORBIDDEN_RESPONSE

return await self.create_checkout(
organization_id=request.state.organization_id,
Expand All @@ -884,13 +882,12 @@ async def fetch_plan_user_route(
self,
request: Request,
):
if is_ee():
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.VIEW_BILLING,
):
return FORBIDDEN_RESPONSE
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.VIEW_BILLING,
):
return FORBIDDEN_RESPONSE

return await self.fetch_plans(
organization_id=request.state.organization_id,
Expand All @@ -902,13 +899,12 @@ async def switch_plans_user_route(
request: Request,
plan: Plan = Query(...),
):
if is_ee():
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.EDIT_BILLING,
):
return FORBIDDEN_RESPONSE
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.EDIT_BILLING,
):
return FORBIDDEN_RESPONSE

return await self.switch_plans(
organization_id=request.state.organization_id,
Expand All @@ -931,13 +927,12 @@ async def fetch_subscription_user_route(
self,
request: Request,
):
if is_ee():
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.VIEW_BILLING,
):
return FORBIDDEN_RESPONSE
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.VIEW_BILLING,
):
return FORBIDDEN_RESPONSE

return await self.fetch_subscription(
organization_id=request.state.organization_id,
Expand All @@ -948,13 +943,12 @@ async def cancel_subscription_user_route(
self,
request: Request,
):
if is_ee():
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.EDIT_BILLING,
):
return FORBIDDEN_RESPONSE
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.EDIT_BILLING,
):
return FORBIDDEN_RESPONSE

return await self.cancel_subscription(
organization_id=request.state.organization_id,
Expand All @@ -974,13 +968,12 @@ async def fetch_usage_user_route(
self,
request: Request,
):
if is_ee():
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.VIEW_BILLING,
):
return FORBIDDEN_RESPONSE
if not await check_action_access(
user_uid=request.state.user_id,
project_id=request.state.project_id,
permission=Permission.VIEW_BILLING,
):
return FORBIDDEN_RESPONSE

return await self.fetch_usage(
organization_id=request.state.organization_id,
Expand Down
2 changes: 1 addition & 1 deletion api/ee/src/core/subscriptions/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

stripe.api_key = environ.get("STRIPE_SECRET_KEY")

MAC_ADDRESS = ":".join(f"{(getnode() >> ele) & 0xFF:02x}" for ele in range(40, -1, -8))
MAC_ADDRESS = ":".join(f"{(getnode() >> ele) & 0xff:02x}" for ele in range(40, -1, -8))
STRIPE_TARGET = environ.get("STRIPE_TARGET") or MAC_ADDRESS
AGENTA_PRICING = loads(environ.get("AGENTA_PRICING") or "{}")

Expand Down
File renamed without changes.
File renamed without changes.
25 changes: 24 additions & 1 deletion api/ee/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@

from oss.src.utils.logging import get_module_logger

from ee.src.routers import workspace_router, organization_router
from ee.src.routers import (
workspace_router,
organization_router,
evaluation_router,
human_evaluation_router,
)

from ee.src.dbs.postgres.meters.dao import MetersDAO
from ee.src.dbs.postgres.subscriptions.dao import SubscriptionsDAO
Expand Down Expand Up @@ -66,11 +71,29 @@ def extend_main(app: FastAPI):
prefix="/workspaces",
)

app.include_router(
evaluation_router.router,
prefix="/evaluations",
tags=["Evaluations"],
)

app.include_router(
human_evaluation_router.router,
prefix="/human-evaluations",
tags=["Human-Evaluations"],
)

# --------------------------------------------------------------------------

return app


def load_tasks():
import ee.src.tasks.evaluations.live
import ee.src.tasks.evaluations.legacy
import ee.src.tasks.evaluations.batch


def extend_app_schema(app: FastAPI):
app.openapi()["info"]["title"] = "Agenta API"
app.openapi()["info"]["description"] = "Agenta API"
Expand Down
Loading
Loading