|
1 | 1 | """Endpoint handlers for business logic.""" |
2 | 2 |
|
| 3 | +from typing import Any |
3 | 4 | from uuid import UUID |
4 | 5 |
|
5 | 6 | from fastapi import HTTPException |
@@ -725,41 +726,7 @@ async def _publish_to_marketplace( |
725 | 726 | username=marketplace.email, password=marketplace.password |
726 | 727 | ) |
727 | 728 |
|
728 | | - endpoint_type = ( |
729 | | - "model" if endpoint.model_id is not None else "data_source" |
730 | | - ) |
731 | | - policies = [ |
732 | | - { |
733 | | - "type": policy.policy_type, |
734 | | - "version": "1.0", |
735 | | - "enabled": True, |
736 | | - "description": policy.name, |
737 | | - "config": policy.configuration, |
738 | | - } |
739 | | - for policy in endpoint.policies |
740 | | - ] |
741 | | - connection_config = { |
742 | | - "path": f"/api/v1/endpoints/{endpoint.slug}/query", |
743 | | - } |
744 | | - |
745 | | - payload = { |
746 | | - "name": endpoint.name, |
747 | | - "description": endpoint.summary or "", |
748 | | - "type": endpoint_type, |
749 | | - "visibility": "public", |
750 | | - "version": "0.1.0", |
751 | | - "readme": endpoint.description or "", |
752 | | - "slug": endpoint.slug, |
753 | | - "policies": policies, |
754 | | - "connect": [ |
755 | | - { |
756 | | - "type": "https", |
757 | | - "enabled": True, |
758 | | - "description": "", |
759 | | - "config": connection_config, |
760 | | - } |
761 | | - ], |
762 | | - } |
| 729 | + payload = self._build_publish_payload(endpoint) |
763 | 730 | await client.publish_endpoint(payload, overwrite=True) |
764 | 731 |
|
765 | 732 | except SyftHubError as e: |
@@ -920,3 +887,125 @@ async def _check_marketplace_availability( |
920 | 887 | available=None, |
921 | 888 | error=str(e), |
922 | 889 | ) |
| 890 | + |
| 891 | + def _build_publish_payload(self, endpoint: Endpoint) -> dict[str, Any]: |
| 892 | + """Build the publish payload for an endpoint. |
| 893 | +
|
| 894 | + Args: |
| 895 | + endpoint: Endpoint entity |
| 896 | +
|
| 897 | + Returns: |
| 898 | + Dict payload for publish/sync APIs |
| 899 | + """ |
| 900 | + endpoint_type = "model" if endpoint.model_id is not None else "data_source" |
| 901 | + |
| 902 | + policies = [ |
| 903 | + { |
| 904 | + "type": policy.policy_type, |
| 905 | + "version": "1.0", |
| 906 | + "enabled": True, |
| 907 | + "description": policy.name, |
| 908 | + "config": policy.configuration, |
| 909 | + } |
| 910 | + for policy in endpoint.policies |
| 911 | + ] |
| 912 | + |
| 913 | + connection_config = { |
| 914 | + "path": f"/api/v1/endpoints/{endpoint.slug}/query", |
| 915 | + } |
| 916 | + |
| 917 | + return { |
| 918 | + "name": endpoint.name, |
| 919 | + "description": endpoint.summary or "", |
| 920 | + "type": endpoint_type, |
| 921 | + "visibility": "public", |
| 922 | + "version": "0.1.0", |
| 923 | + "readme": endpoint.description or "", |
| 924 | + "slug": endpoint.slug, |
| 925 | + "policies": policies, |
| 926 | + "connect": [ |
| 927 | + { |
| 928 | + "type": "https", |
| 929 | + "enabled": True, |
| 930 | + "description": "", |
| 931 | + "config": connection_config, |
| 932 | + } |
| 933 | + ], |
| 934 | + } |
| 935 | + |
| 936 | + async def sync_endpoints_to_marketplaces( |
| 937 | + self, tenant: Tenant |
| 938 | + ) -> dict[str, list[str]]: |
| 939 | + """Sync all published endpoints to their respective marketplaces. |
| 940 | +
|
| 941 | + Groups endpoints by marketplace and calls sync_endpoints API for each. |
| 942 | +
|
| 943 | + Args: |
| 944 | + tenant: Tenant context |
| 945 | +
|
| 946 | + Returns: |
| 947 | + Dict mapping marketplace_id -> list of synced endpoint slugs |
| 948 | + """ |
| 949 | + if not self.marketplace_repository: |
| 950 | + logger.warning("Marketplace repository not configured, skipping sync") |
| 951 | + return {} |
| 952 | + |
| 953 | + # Get all published endpoints |
| 954 | + endpoints = await self.endpoint_repository.get_published_endpoints(tenant.id) |
| 955 | + if not endpoints: |
| 956 | + logger.debug("No published endpoints to sync") |
| 957 | + return {} |
| 958 | + |
| 959 | + # Group endpoints by marketplace |
| 960 | + marketplace_endpoints: dict[UUID, list[Endpoint]] = {} |
| 961 | + for endpoint in endpoints: |
| 962 | + for marketplace_id in endpoint.published_to: |
| 963 | + marketplace_endpoints.setdefault(marketplace_id, []).append(endpoint) |
| 964 | + |
| 965 | + results: dict[str, list[str]] = {} |
| 966 | + |
| 967 | + # Sync to each marketplace |
| 968 | + for marketplace_id, eps in marketplace_endpoints.items(): |
| 969 | + try: |
| 970 | + marketplace = await self.marketplace_repository.get_by_id( |
| 971 | + UUID(marketplace_id), tenant.id |
| 972 | + ) |
| 973 | + if not marketplace: |
| 974 | + logger.warning(f"Marketplace {marketplace_id} not found, skipping") |
| 975 | + continue |
| 976 | + |
| 977 | + if not marketplace.is_active: |
| 978 | + logger.warning(f"Marketplace {marketplace_id} not active, skipping") |
| 979 | + continue |
| 980 | + |
| 981 | + if not marketplace.email or not marketplace.password: |
| 982 | + logger.warning( |
| 983 | + f"Marketplace {marketplace_id} missing credentials, skipping" |
| 984 | + ) |
| 985 | + continue |
| 986 | + |
| 987 | + # Build payloads |
| 988 | + payloads = [self._build_publish_payload(ep) for ep in eps] |
| 989 | + |
| 990 | + # Call sync API |
| 991 | + async with SyftHubClient(base_url=marketplace.url) as client: |
| 992 | + await client.login( |
| 993 | + username=marketplace.email, password=marketplace.password |
| 994 | + ) |
| 995 | + await client.sync_endpoints(payloads) |
| 996 | + |
| 997 | + results[marketplace_id] = [ep.slug for ep in eps] |
| 998 | + logger.info( |
| 999 | + f"Synced {len(eps)} endpoints to marketplace {marketplace.name}" |
| 1000 | + ) |
| 1001 | + |
| 1002 | + except SyftHubError as e: |
| 1003 | + logger.warning( |
| 1004 | + f"Failed to sync to marketplace {marketplace_id}: {e.message}" |
| 1005 | + ) |
| 1006 | + except Exception as e: |
| 1007 | + logger.error( |
| 1008 | + f"Unexpected error syncing to marketplace {marketplace_id}: {e}" |
| 1009 | + ) |
| 1010 | + |
| 1011 | + return results |
0 commit comments