1
+ # This new API endpoint will allow for dynamic exploration of user retention for a single market and custom date.
2
+
3
+ import os
4
+ from datetime import datetime , timedelta , timezone
5
+ from typing import Dict , List , Tuple , Set , Optional , Any
6
+
7
+ import pandas as pd
8
+ from dateutil import tz , parser
9
+ from pyathena import connect
10
+ import warnings
11
+ from fastapi import APIRouter , HTTPException , Query
12
+ from pydantic import BaseModel
13
+ import logging
14
+ import json
15
+
16
+ import boto3
17
+
18
def load_markets_from_json(file_path: str) -> Dict[str, Dict[str, Any]]:
    """Loads market data from a JSON file and formats it for the API.

    Returns a dict keyed by market name; on any failure (missing file,
    bad JSON, unexpected error) logs the problem and returns an empty dict.
    """
    try:
        with open(file_path, 'r') as fh:
            raw_markets = json.load(fh)

        by_name = {
            entry["marketName"]: {
                "index": entry["marketIndex"],
                "launch_ts": entry["launchTs"],  # Keep original launch_ts for reference if needed
                "category": entry["category"],
            }
            for entry in raw_markets
        }
        logger.info(f"Successfully loaded and formatted {len(by_name)} markets from {file_path}")
        return by_name
    except FileNotFoundError:
        logger.error(f"Market file not found at {file_path}. API will not have market data.")
        return {}
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON from {file_path}.")
        return {}
    except Exception as e:
        logger.error(f"An unexpected error occurred while loading markets: {e}")
        return {}
42
+
43
def log_current_identity():
    """Best-effort log of the current AWS caller identity; never raises."""
    try:
        identity = boto3.client("sts").get_caller_identity()
        logger.info(f"Running as: {identity}")
    except Exception as e:
        logger.warning(f"Could not determine AWS identity: {e}")
50
+
51
# --- Module-level setup (runs once on import) ---

# Silence UserWarnings globally — presumably to quiet pandas' DB-API warning
# from pd.read_sql over the raw PyAthena connection; TODO confirm.
warnings.filterwarnings("ignore", category=UserWarning)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

# Market catalogue keyed by market name; empty dict if shared/markets.json
# is missing or malformed (see load_markets_from_json).
ALL_MARKETS = load_markets_from_json("shared/markets.json")

# A "new trader" is a user whose first order in the market falls within this
# many days of the requested start date.
NEW_TRADER_WINDOW_DAYS: int = 7
# Retention is measured at each of these horizons (days after start date).
RETENTION_WINDOWS_DAYS: List[int] = [14, 28]
# Retention scans are issued in Athena-query chunks of at most this many days.
CHUNK_DAYS: int = 28

# Athena connection settings, overridable via environment variables.
DATABASE = os.environ.get("ATHENA_DATABASE", "mainnet-beta-archive")
REGION = os.environ.get("AWS_REGION", "eu-west-1")
S3_OUTPUT = os.environ.get("ATHENA_S3_OUTPUT", "s3://mainnet-beta-data-ingestion-bucket/athena/")
67
+
68
class RetentionExplorerItem(BaseModel):
    """Response model for the retention explorer /calculate endpoint."""
    # Market name and its category tags (as loaded from shared/markets.json).
    market: str
    category: List[str]
    # Analysis start date, echoed back as the string the caller supplied.
    start_date: str
    # Cohort: users whose first order in this market fell in the new-trader window.
    new_traders: int
    new_traders_list: List[str]
    # Retention at 14 days: cohort members seen trading OTHER markets,
    # as count, ratio of the cohort, and the user list itself.
    retained_users_14d: int
    retention_ratio_14d: float
    retained_users_14d_list: List[str]
    # Same retention metrics measured over a 28-day window.
    retained_users_28d: int
    retention_ratio_28d: float
    retained_users_28d_list: List[str]
80
+
81
# UTC tzinfo constant. Uses the stdlib timezone.utc (already imported at the
# top of this file) instead of dateutil's tz.tzutc(); both denote UTC and are
# interchangeable for constructing and comparing aware datetimes.
UTC = timezone.utc

def dt_from_ms(ms: int) -> datetime:
    """Convert a unix timestamp in milliseconds to an aware UTC datetime."""
    return datetime.fromtimestamp(ms / 1_000, tz=UTC)
85
+
86
def partition_tuples(start: datetime, days: int) -> Set[Tuple[str, str, str]]:
    """Return the set of (year, month, day) partition strings covering
    `days` consecutive calendar days starting at `start`."""
    parts: Set[Tuple[str, str, str]] = set()
    for offset in range(days):
        day = start + timedelta(offset)
        parts.add((day.strftime("%Y"), day.strftime("%m"), day.strftime("%d")))
    return parts
91
+
92
def partition_pred(parts: Set[Tuple[str, str, str]]) -> str:
    """Build an OR-joined Athena partition predicate from (y, m, d) tuples,
    emitted in sorted order for deterministic SQL."""
    clauses = (
        f"(year='{y}' AND month='{m}' AND day='{d}')"
        for y, m, d in sorted(parts)
    )
    return " OR ".join(clauses)
97
+
98
def sql_new_traders(mkt_idx: int, start_dt: datetime) -> str:
    """Build the SQL that finds users whose first order in market `mkt_idx`
    falls inside the new-trader window starting at `start_dt`."""
    where_partitions = partition_pred(partition_tuples(start_dt, NEW_TRADER_WINDOW_DAYS))
    return f"""
    SELECT "user",
           MIN(slot) AS first_slot,
           MIN(ts) AS first_ts
    FROM eventtype_orderrecord
    WHERE ({where_partitions})
      AND "order".marketindex = {mkt_idx}
      AND ("order".orderid = 0 OR "order".orderid = 1)
    GROUP BY "user"
    """
110
+
111
def sql_retention_users_chunk(traders: List[str],
                              mkt_idx: int,
                              chunk_start: datetime,
                              chunk_days: int) -> str:
    """Build the SQL selecting the distinct `traders` who traded any market
    OTHER than `mkt_idx` during the `chunk_days`-day window at `chunk_start`."""
    chunk_end = chunk_start + timedelta(days=chunk_days)
    start_ts, end_ts = int(chunk_start.timestamp()), int(chunk_end.timestamp())
    from_date, to_date = chunk_start.strftime('%Y%m%d'), chunk_end.strftime('%Y%m%d')
    trader_list = "', '".join(traders)

    return f'''
    WITH time_range AS (
        SELECT
            {start_ts} AS from_ts,
            {end_ts} AS to_ts,
            '{from_date}' AS from_date,
            '{to_date}' AS to_date
    )
    SELECT DISTINCT "user"
    FROM eventtype_orderrecord, time_range
    WHERE CAST(ts AS INT) BETWEEN time_range.from_ts AND time_range.to_ts
      AND CONCAT(year, month, day) BETWEEN time_range.from_date AND time_range.to_date
      AND "order".marketindex <> {mkt_idx}
      AND "user" IN ('{trader_list}')
    '''
137
+
138
async def calculate_retention_for_market(market_name: str, start_date_str: str) -> Dict[str, Any]:
    """Compute new-trader counts and cross-market retention for one market.

    Args:
        market_name: Key into ALL_MARKETS (as loaded from shared/markets.json).
        start_date_str: Analysis start date in any format dateutil can parse.

    Returns:
        Dict whose keys match the fields of RetentionExplorerItem.

    Raises:
        HTTPException: 404 if the market is unknown; 500 on any other failure.

    NOTE(review): pd.read_sql is a blocking call inside an async function and
    will stall the event loop for the duration of each Athena query — confirm
    whether this endpoint should offload to a threadpool.
    """
    conn = None
    try:
        # Naive input dates are pinned to UTC.
        start_date = parser.parse(start_date_str).replace(tzinfo=UTC)
        market_config = ALL_MARKETS.get(market_name)
        if not market_config:
            raise HTTPException(status_code=404, detail=f"Market '{market_name}' not found.")

        logger.info(f"Connecting to Athena. S3 staging: {S3_OUTPUT}, Region: {REGION}, DB: {DATABASE}")
        conn = connect(s3_staging_dir=S3_OUTPUT, region_name=REGION, schema_name=DATABASE)
        logger.info("Successfully connected to Athena.")
        log_current_identity()

        # 1. Find new traders for the given market and date
        logger.info(f"Scanning for new traders for {market_name} from {start_date_str}...")
        q_new_traders = sql_new_traders(market_config["index"], start_date)
        new_traders_df = pd.read_sql(q_new_traders, conn)
        logger.info(f"Found {len(new_traders_df)} new traders for {market_name}.")

        mkt_traders = new_traders_df["user"].tolist()
        new_traders_count = len(mkt_traders)

        result = {
            "market": market_name,
            "category": market_config.get("category", []),
            "start_date": start_date_str,
            "new_traders": new_traders_count,
            "new_traders_list": mkt_traders,
        }

        # 2. Calculate retention for each window
        if not mkt_traders:
            # Empty cohort: short-circuit with zeroed metrics for every window.
            for win in RETENTION_WINDOWS_DAYS:
                result[f"retained_users_{win}d"] = 0
                result[f"retention_ratio_{win}d"] = 0.0
                result[f"retained_users_{win}d_list"] = []
            return result

        retention_period_start_dt = start_date
        for win in RETENTION_WINDOWS_DAYS:
            offset = 0
            retained_set: Set[str] = set()

            # Scan the window in CHUNK_DAYS slices so each Athena query's
            # date range stays bounded.
            while offset < win:
                chunk_start_dt = retention_period_start_dt + timedelta(days=offset)
                span = min(CHUNK_DAYS, win - offset)
                if span <= 0:
                    break

                logger.info(f"Fetching retention for {market_name}, window {win}d, chunk: {chunk_start_dt.strftime('%Y-%m-%d')} for {span} days")
                q_retention_chunk = sql_retention_users_chunk(mkt_traders, market_config["index"], chunk_start_dt, span)
                retained_users_df = pd.read_sql(q_retention_chunk, conn)
                retained_set.update(retained_users_df["user"].tolist())
                offset += CHUNK_DAYS

            retained_list = sorted(retained_set)
            retained_count = len(retained_list)
            retention_ratio = (retained_count / new_traders_count) if new_traders_count > 0 else 0.0

            result[f"retained_users_{win}d"] = retained_count
            result[f"retention_ratio_{win}d"] = round(retention_ratio, 4)
            result[f"retained_users_{win}d_list"] = retained_list

        logger.info(f"Successfully calculated retention for {market_name}.")
        return result

    except HTTPException:
        # Bug fix: the 404 raised above was previously caught by the generic
        # handler below and re-raised as a 500. Let HTTP errors propagate
        # with their original status code.
        raise
    except Exception as e:
        logger.error(f"Error in calculate_retention_for_market: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to process retention data: {str(e)}")
    finally:
        if conn:
            conn.close()
            logger.info("Athena connection closed.")
210
+
211
@router.get("/markets", response_model=List[str])
async def get_available_markets():
    """Returns a list of available market names for the explorer."""
    if ALL_MARKETS:
        return sorted(ALL_MARKETS.keys())
    logger.warning("No markets loaded from shared/markets.json")
    return []
218
+
219
+
220
@router.get("/calculate", response_model=RetentionExplorerItem)
async def get_retention_for_market(
    market_name: str = Query(..., description="The name of the market to analyze."),
    start_date: str = Query(..., description="The start date for the analysis (YYYY-MM-DD).")
):
    """
    Calculates user retention for a specific market from a given start date.
    - Identifies 'new traders' within 7 days of the start date for that market.
    - Measures retention in other markets at 14 and 28 days.
    """
    try:
        logger.info(f"Received request for /calculate: market='{market_name}', date='{start_date}'")
        # Input validation for date format can be added here if needed
        payload = await calculate_retention_for_market(market_name, start_date)
        return RetentionExplorerItem(**payload)
    except HTTPException:
        # HTTP errors (e.g. the 404 for unknown markets) pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Unhandled error in /calculate endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="An internal server error occurred during calculation.")