Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from pydantic import Field
from servicelib.aiohttp.requests_validation import parse_request_body_as
from servicelib.request_keys import RQT_USERID_KEY
from yarl import URL

from .._meta import API_VTAG as VTAG
from ..constants import RQ_PRODUCT_KEY
Expand Down Expand Up @@ -53,21 +52,19 @@ async def generate_invitation(request: web.Request):
extra_credits_in_usd=body.extra_credits_in_usd,
product=req_ctx.product_name,
),
request.url,
)
assert request.url.host # nosec
assert generated.product == req_ctx.product_name # nosec
assert generated.guest == body.guest # nosec

url = URL(f"{generated.invitation_url}")
invitation_link = request.url.with_path(url.path).with_fragment(url.raw_fragment)

invitation = InvitationGenerated(
product_name=generated.product,
issuer=generated.issuer,
guest=generated.guest,
trial_account_days=generated.trial_account_days,
extra_credits_in_usd=generated.extra_credits_in_usd,
created=generated.created,
invitation_link=f"{invitation_link}", # type: ignore[arg-type]
invitation_link=generated.invitation_url,
)
return envelope_json_response(invitation.model_dump(exclude_none=True))
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
ApiInvitationInputs,
)
from models_library.emails import LowerCaseEmailStr
from pydantic import AnyHttpUrl, TypeAdapter, ValidationError
from pydantic import AnyHttpUrl, HttpUrl, TypeAdapter, ValidationError
from yarl import URL

from ..groups.api import is_user_by_email_in_group
from ..products.models import Product
Expand Down Expand Up @@ -134,7 +135,9 @@ async def extract_invitation(


async def generate_invitation(
app: web.Application, params: ApiInvitationInputs
app: web.Application,
params: ApiInvitationInputs,
product_origin_url: URL,
) -> ApiInvitationContentAndLink:
"""
Raises:
Expand All @@ -145,4 +148,10 @@ async def generate_invitation(
invitation: ApiInvitationContentAndLink = await get_invitations_service_api(
app=app
).generate_invitation(params)

_normalized_url = URL(f"{invitation.invitation_url}")
invitation.invitation_url = HttpUrl(
f"{product_origin_url.with_path(_normalized_url.path).with_fragment(_normalized_url.raw_fragment)}"
)

return invitation
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,13 @@ async def approve_user_account(request: web.Request) -> web.Response:
guest=approval_data.email,
trial_account_days=approval_data.invitation.trial_account_days,
extra_credits_in_usd=approval_data.invitation.extra_credits_in_usd,
product=req_ctx.product_name,
)

invitation_result = await invitations_service.generate_invitation(
request.app, params=invitation_params
request.app,
params=invitation_params,
product_origin_url=request.url.origin(),
)

assert ( # nosec
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,22 @@ async def _register_account(invitation_url: HttpUrl, product_deployed: ProductNa
invitation_product_a = await generate_invitation(
client.app,
ApiInvitationInputs(issuer="PO", guest=guest_email, product=product_a),
product_origin_url=URL("http://product_a.com/some/path").origin(),
)
# 2. PO creates invitation for product B
invitation_product_b = await generate_invitation(
client.app,
ApiInvitationInputs(issuer="PO", guest=guest_email, product=product_b),
product_origin_url=URL("http://product_b.com/some/path").origin(),
)

# CAN register for product A in deploy of product A
assert invitation_product_a.invitation_url.host == "product_a.com"
response = await _register_account(invitation_product_a.invitation_url, product_a)
await assert_status(response, status.HTTP_200_OK)

# CANNOT register in product B in deploy of product A
assert invitation_product_b.invitation_url.host == "product_b.com"
response = await _register_account(invitation_product_b.invitation_url, product_a)
await assert_status(response, status.HTTP_409_CONFLICT)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -595,6 +595,7 @@ async def test_approve_user_account_with_full_invitation_details(
assert invitation_data["issuer"] == str(logged_user["id"])
assert invitation_data["trial_account_days"] == 30
assert invitation_data["extra_credits_in_usd"] == 100.0
assert invitation_data["product"] == product_name
assert "invitation_url" in invitation_data


Expand Down
Loading