Skip to content

Commit 9525b6c

Browse files
authored
Merge pull request #42 from drift-labs:goldhaxx/DATA-75/add-btc-perp-to-user-retention-report
Goldhaxx/DATA-75/add-btc-perp-to-user-retention-report
2 parents 4e440f0 + 1b88cf9 commit 9525b6c

11 files changed

+3345
-484
lines changed

backend/api/user_retention_api.py

Lines changed: 0 additions & 423 deletions
This file was deleted.
Lines changed: 239 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,239 @@
1+
# This new API endpoint will allow for dynamic exploration of user retention for a single market and custom date.
2+
3+
import os
4+
from datetime import datetime, timedelta, timezone
5+
from typing import Dict, List, Tuple, Set, Optional, Any
6+
7+
import pandas as pd
8+
from dateutil import tz, parser
9+
from pyathena import connect
10+
import warnings
11+
from fastapi import APIRouter, HTTPException, Query
12+
from pydantic import BaseModel
13+
import logging
14+
import json
15+
16+
import boto3
17+
18+
def load_markets_from_json(file_path: str) -> Dict[str, Dict[str, Any]]:
    """Load market definitions from a JSON file, keyed by market name.

    The file is read as a JSON array of objects; each object must carry the
    ``marketName``, ``marketIndex``, ``launchTs`` and ``category`` keys.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        A dict mapping each ``marketName`` to
        ``{"index": marketIndex, "launch_ts": launchTs, "category": category}``.
        An empty dict is returned (and the failure is logged) when the file is
        missing, is not valid JSON, or any other error occurs while reading or
        reshaping the entries.
    """
    try:
        # JSON is UTF-8 by specification; decoding must not depend on the
        # locale of the host the API happens to run on.
        with open(file_path, "r", encoding="utf-8") as f:
            markets_data = json.load(f)

        formatted_markets = {
            market["marketName"]: {
                "index": market["marketIndex"],
                "launch_ts": market["launchTs"],  # Keep original launch_ts for reference if needed
                "category": market["category"],
            }
            for market in markets_data
        }
        logger.info(f"Successfully loaded and formatted {len(formatted_markets)} markets from {file_path}")
        return formatted_markets
    except FileNotFoundError:
        logger.error(f"Market file not found at {file_path}. API will not have market data.")
        return {}
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON from {file_path}.")
        return {}
    except Exception as e:
        # Best-effort loader: a malformed entry (e.g. a missing key) must not
        # prevent the module from importing, so log and fall back to no markets.
        logger.error(f"An unexpected error occurred while loading markets: {e}")
        return {}
42+
43+
def log_current_identity():
    """Log the AWS caller identity reported by STS.

    Any failure (client creation, the STS call, or logging the result) is
    downgraded to a warning so the caller is never interrupted.
    """
    try:
        caller = boto3.client("sts").get_caller_identity()
        logger.info(f"Running as: {caller}")
    except Exception as e:
        logger.warning(f"Could not determine AWS identity: {e}")
50+
51+
# Ignore every UserWarning raised in this process.
warnings.filterwarnings("ignore", category=UserWarning)

# The logger must exist before load_markets_from_json() runs, because that
# function logs through the module-level ``logger``.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

# Market metadata keyed by market name; an empty dict when loading failed.
ALL_MARKETS = load_markets_from_json("shared/markets.json")

# Number of daily partitions scanned when looking for new traders.
NEW_TRADER_WINDOW_DAYS: int = 7
# Retention windows (in days) reported by the explorer.
RETENTION_WINDOWS_DAYS: List[int] = [14, 28]
# Maximum number of days covered by a single retention query.
CHUNK_DAYS: int = 28

# Athena connection settings, overridable through environment variables.
DATABASE = os.getenv("ATHENA_DATABASE", "mainnet-beta-archive")
REGION = os.getenv("AWS_REGION", "eu-west-1")
S3_OUTPUT = os.getenv("ATHENA_S3_OUTPUT", "s3://mainnet-beta-data-ingestion-bucket/athena/")
67+
68+
class RetentionExplorerItem(BaseModel):
    # Response model of the /calculate endpoint. Documented with comments
    # rather than a docstring so the generated API schema stays unchanged.

    # Market that was analysed and the category list from its configuration.
    market: str
    category: List[str]
    # Start date exactly as supplied by the caller.
    start_date: str

    # Users returned by the new-trader query, and how many there are.
    new_traders: int
    new_traders_list: List[str]

    # 14-day window: retained user count, retained / new ratio, retained users.
    retained_users_14d: int
    retention_ratio_14d: float
    retained_users_14d_list: List[str]

    # 28-day window: same three figures.
    retained_users_28d: int
    retention_ratio_28d: float
    retained_users_28d_list: List[str]
80+
81+
# Shared UTC tzinfo used for every timezone-aware datetime in this module.
UTC = tz.tzutc()


def dt_from_ms(ms: int) -> datetime:
    """Convert a Unix timestamp in milliseconds to a UTC-aware datetime."""
    seconds = ms / 1_000
    return datetime.fromtimestamp(seconds, tz=UTC)
85+
86+
def partition_tuples(start: datetime, days: int) -> Set[Tuple[str, str, str]]:
    """Return the (year, month, day) partition keys for consecutive days.

    Args:
        start: First day of the range.
        days: Number of consecutive days to cover, starting at ``start``.

    Returns:
        A set of zero-padded ``("YYYY", "MM", "DD")`` string tuples, one per
        day; empty when ``days`` is zero or negative.
    """
    keys: Set[Tuple[str, str, str]] = set()
    for offset in range(days):
        current = start + timedelta(days=offset)
        keys.add((current.strftime("%Y"), current.strftime("%m"), current.strftime("%d")))
    return keys
91+
92+
def partition_pred(parts: Set[Tuple[str, str, str]]) -> str:
    """Build a SQL predicate matching any of the given date partitions.

    Args:
        parts: ``(year, month, day)`` string tuples.

    Returns:
        One parenthesised ``year=... AND month=... AND day=...`` clause per
        tuple, in sorted order, joined with ``OR``. An empty set yields an
        empty string.
    """
    clauses = []
    for year, month, day in sorted(parts):
        clauses.append(f"(year='{year}' AND month='{month}' AND day='{day}')")
    return " OR ".join(clauses)
97+
98+
def sql_new_traders(mkt_idx: int, start_dt: datetime) -> str:
    """Build the Athena query that lists new traders for one market.

    The query reads ``eventtype_orderrecord`` restricted to the
    ``NEW_TRADER_WINDOW_DAYS`` daily partitions starting at ``start_dt``,
    keeps rows for market ``mkt_idx`` whose order id is 0 or 1 (presumably a
    user's first orders -- confirm against the data model), and returns one
    row per user with that user's minimum slot and minimum timestamp.

    Args:
        mkt_idx: Market index to filter on.
        start_dt: First day of the new-trader window.

    Returns:
        The SQL text.
    """
    window_partitions = partition_tuples(start_dt, NEW_TRADER_WINDOW_DAYS)
    parts = partition_pred(window_partitions)
    query = f"""
        SELECT "user",
               MIN(slot) AS first_slot,
               MIN(ts) AS first_ts
        FROM eventtype_orderrecord
        WHERE ({parts})
          AND "order".marketindex = {mkt_idx}
          AND ("order".orderid = 0 OR "order".orderid = 1)
        GROUP BY "user"
    """
    return query
110+
111+
def sql_retention_users_chunk(traders: List[str],
                              mkt_idx: int,
                              chunk_start: datetime,
                              chunk_days: int) -> str:
    """Build the Athena query that finds which traders were active elsewhere.

    The query returns the distinct users from ``traders`` that have at least
    one order record in a market *other than* ``mkt_idx`` whose timestamp
    falls between ``chunk_start`` and ``chunk_start + chunk_days`` (both ends
    inclusive). The ``year``/``month``/``day`` partition columns are bounded
    to the same range.

    Args:
        traders: User identifiers to look for.
        mkt_idx: Market index to exclude.
        chunk_start: Start of the time range.
        chunk_days: Length of the time range in days.

    Returns:
        The SQL text.
    """
    chunk_end = chunk_start + timedelta(days=chunk_days)
    start_ts = int(chunk_start.timestamp())
    end_ts = int(chunk_end.timestamp())
    from_date = chunk_start.strftime("%Y%m%d")
    to_date = chunk_end.strftime("%Y%m%d")

    # The identifiers are spliced into SQL string literals, so an embedded
    # single quote would terminate the literal early. Doubling it is the
    # standard SQL escape and leaves quote-free identifiers untouched.
    trader_list = "', '".join(trader.replace("'", "''") for trader in traders)

    return f'''
        WITH time_range AS (
            SELECT
                {start_ts} AS from_ts,
                {end_ts} AS to_ts,
                '{from_date}' AS from_date,
                '{to_date}' AS to_date
        )
        SELECT DISTINCT "user"
        FROM eventtype_orderrecord, time_range
        WHERE CAST(ts AS INT) BETWEEN time_range.from_ts AND time_range.to_ts
          AND CONCAT(year, month, day) BETWEEN time_range.from_date AND time_range.to_date
          AND "order".marketindex <> {mkt_idx}
          AND "user" IN ('{trader_list}')
    '''
137+
138+
async def calculate_retention_for_market(market_name: str, start_date_str: str) -> Dict[str, Any]:
    """Compute new-trader and retention figures for one market.

    Steps:
      1. Parse ``start_date_str`` and look the market up in ``ALL_MARKETS``.
      2. Query Athena for the market's new traders in the window that starts
         at the parsed date.
      3. For every window in ``RETENTION_WINDOWS_DAYS``, query (in chunks of at
         most ``CHUNK_DAYS`` days) which of those traders placed orders in any
         *other* market, and derive the retained count and ratio.

    Args:
        market_name: Key of the market in ``ALL_MARKETS``.
        start_date_str: Start date in any format ``dateutil`` can parse; the
            parsed value is stamped as UTC.

    Returns:
        A dict with ``market``, ``category``, ``start_date``, ``new_traders``,
        ``new_traders_list`` and, for each window ``N``, ``retained_users_Nd``,
        ``retention_ratio_Nd`` (rounded to 4 places) and
        ``retained_users_Nd_list`` (sorted).

    Raises:
        HTTPException: 400 when the date cannot be parsed, 404 when the market
            is unknown, 500 for any other failure.
    """
    # NOTE(review): pd.read_sql below is a blocking call inside an async
    # function; consider moving it to a worker thread if it stalls other requests.
    conn = None
    try:
        try:
            start_date = parser.parse(start_date_str).replace(tzinfo=UTC)
        except (ValueError, OverflowError) as exc:
            # A malformed date is a client error, not a server failure.
            raise HTTPException(
                status_code=400,
                detail=f"Invalid start date '{start_date_str}'.",
            ) from exc

        market_config = ALL_MARKETS.get(market_name)
        if not market_config:
            raise HTTPException(status_code=404, detail=f"Market '{market_name}' not found.")

        logger.info(f"Connecting to Athena. S3 staging: {S3_OUTPUT}, Region: {REGION}, DB: {DATABASE}")
        conn = connect(s3_staging_dir=S3_OUTPUT, region_name=REGION, schema_name=DATABASE)
        logger.info("Successfully connected to Athena.")
        log_current_identity()

        # 1. Find new traders for the given market and date
        logger.info(f"Scanning for new traders for {market_name} from {start_date_str}...")
        q_new_traders = sql_new_traders(market_config["index"], start_date)
        new_traders_df = pd.read_sql(q_new_traders, conn)
        logger.info(f"Found {len(new_traders_df)} new traders for {market_name}.")

        mkt_traders = new_traders_df["user"].tolist()
        new_traders_count = len(mkt_traders)

        result = {
            "market": market_name,
            "category": market_config.get("category", []),
            "start_date": start_date_str,
            "new_traders": new_traders_count,
            "new_traders_list": mkt_traders,
        }

        # 2. Calculate retention for each window
        if not mkt_traders:
            # Nothing to retain: report zeros without querying Athena again.
            for win in RETENTION_WINDOWS_DAYS:
                result[f"retained_users_{win}d"] = 0
                result[f"retention_ratio_{win}d"] = 0.0
                result[f"retained_users_{win}d_list"] = []
            return result

        retention_period_start_dt = start_date
        for win in RETENTION_WINDOWS_DAYS:
            offset = 0
            retained_set: Set[str] = set()

            while offset < win:
                chunk_start_dt = retention_period_start_dt + timedelta(days=offset)
                span = min(CHUNK_DAYS, win - offset)
                if span <= 0:
                    # Only reachable with a non-positive CHUNK_DAYS; prevents an endless loop.
                    break

                logger.info(f"Fetching retention for {market_name}, window {win}d, chunk: {chunk_start_dt.strftime('%Y-%m-%d')} for {span} days")
                q_retention_chunk = sql_retention_users_chunk(mkt_traders, market_config["index"], chunk_start_dt, span)
                retained_users_df = pd.read_sql(q_retention_chunk, conn)
                retained_set.update(retained_users_df["user"].tolist())
                offset += CHUNK_DAYS

            retained_list = sorted(retained_set)
            retained_count = len(retained_list)
            retention_ratio = (retained_count / new_traders_count) if new_traders_count > 0 else 0.0

            result[f"retained_users_{win}d"] = retained_count
            result[f"retention_ratio_{win}d"] = round(retention_ratio, 4)
            result[f"retained_users_{win}d_list"] = retained_list

        logger.info(f"Successfully calculated retention for {market_name}.")
        return result

    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # deliberate 400/404 responses above would be rewritten as a 500.
        raise
    except Exception as e:
        logger.error(f"Error in calculate_retention_for_market: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to process retention data: {str(e)}")
    finally:
        if conn:
            conn.close()
            logger.info("Athena connection closed.")
210+
211+
@router.get("/markets", response_model=List[str])
async def get_available_markets():
    """Returns a list of available market names for the explorer."""
    # Iterating the dict yields its keys, so sorting it directly gives the
    # alphabetised market names.
    if ALL_MARKETS:
        return sorted(ALL_MARKETS)

    logger.warning("No markets loaded from shared/markets.json")
    return []
218+
219+
220+
@router.get("/calculate", response_model=RetentionExplorerItem)
async def get_retention_for_market(
    market_name: str = Query(..., description="The name of the market to analyze."),
    start_date: str = Query(..., description="The start date for the analysis (YYYY-MM-DD).")
):
    """
    Calculates user retention for a specific market from a given start date.
    - Identifies 'new traders' within 7 days of the start date for that market.
    - Measures retention in other markets at 14 and 28 days.
    """
    try:
        logger.info(f"Received request for /calculate: market='{market_name}', date='{start_date}'")
        # Input validation for date format can be added here if needed
        payload = await calculate_retention_for_market(market_name, start_date)
        response = RetentionExplorerItem(**payload)
        return response
    except HTTPException:
        # Errors that already carry an HTTP status pass through untouched.
        raise
    except Exception as e:
        # Anything else (e.g. the payload failing model validation) becomes a generic 500.
        logger.error(f"Unhandled error in /calculate endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail="An internal server error occurred during calculation.")
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# This API endpoint serves a static, pre-computed user retention summary from a local JSON file.
2+
3+
import os
4+
from typing import List, Dict, Any, Optional
5+
from fastapi import APIRouter, HTTPException
6+
from pydantic import BaseModel
7+
import logging
8+
import json
9+
10+
# Module-level logger; basicConfig sets the root logger to INFO.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Router collecting this module's endpoints.
router = APIRouter()

# Location of the pre-computed summary served by /summary.
STATIC_DATA_PATH = "shared/user_retention_summary.json"
18+
19+
# ──────────────────────── 1A. Pydantic Models ──────────────────────── #
20+
21+
class RetentionSummaryItem(BaseModel):
    # One entry of the static retention summary. Documented with comments
    # rather than a docstring so the generated API schema stays unchanged.

    # Required fields.
    market: str
    category: List[str]
    new_traders: int

    # Optional fields: default to None when absent from the JSON entry.
    new_traders_list: Optional[List[str]] = None

    # 14-day window figures.
    retained_users_14d: Optional[int] = None
    retention_ratio_14d: Optional[float] = None
    retained_users_14d_list: Optional[List[str]] = None

    # 28-day window figures.
    retained_users_28d: Optional[int] = None
    retention_ratio_28d: Optional[float] = None
    retained_users_28d_list: Optional[List[str]] = None

    class Config:
        orm_mode = True
35+
36+
# ──────────────────────── 2. API Endpoint ───────────────────────── #
37+
38+
@router.get("/summary", response_model=List[RetentionSummaryItem])
async def get_user_retention_summary():
    """
    Provides a pre-computed summary of user retention for "hype" markets
    by reading from a static JSON file.
    """
    logger.info(f"Received request for /summary endpoint. Reading from {STATIC_DATA_PATH}.")

    try:
        # JSON is UTF-8 by specification; decode it explicitly instead of
        # relying on the locale default of the host.
        with open(STATIC_DATA_PATH, "r", encoding="utf-8") as f:
            data = json.load(f)

        # Validate data with Pydantic model
        validated_results = [RetentionSummaryItem(**item) for item in data]
        logger.info(f"Successfully loaded and validated {len(validated_results)} summary items from static file.")
        return validated_results

    except FileNotFoundError:
        # The summary is generated out of band; report its absence as 404.
        logger.error(f"Static data file not found at: {STATIC_DATA_PATH}")
        raise HTTPException(
            status_code=404,
            detail=(
                "The user retention summary file was not found. "
                "Please ensure it has been generated and placed in the 'shared' directory."
            )
        )
    except json.JSONDecodeError:
        logger.error(f"Error decoding JSON from {STATIC_DATA_PATH}")
        raise HTTPException(status_code=500, detail="Failed to parse the summary data file.")
    except Exception as e:
        # Covers entries that fail model validation and any other unexpected error.
        logger.error(f"Unhandled error in /summary endpoint: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"An internal server error occurred: {e}")

backend/app.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
ucache,
2525
vaults,
2626
high_leverage_api,
27-
user_retention_api,
27+
user_retention_summary_api,
28+
user_retention_explorer_api,
2829
)
2930
from backend.middleware.cache_middleware import CacheMiddleware
3031
from backend.middleware.readiness import ReadinessMiddleware
@@ -93,7 +94,9 @@ async def lifespan(app: FastAPI):
9394
app.include_router(market_recommender_api.router, prefix="/api/market-recommender", tags=["market-recommender"])
9495
app.include_router(open_interest_api.router, prefix="/api/open-interest", tags=["open-interest"])
9596
app.include_router(high_leverage_api.router, prefix="/api/high-leverage", tags=["high-leverage"])
96-
app.include_router(user_retention_api.router, prefix="/api/user-retention", tags=["user-retention"])
97+
app.include_router(user_retention_summary_api.router, prefix="/api/user-retention-summary", tags=["user-retention-summary"])
98+
app.include_router(user_retention_explorer_api.router, prefix="/api/user-retention-explorer", tags=["user-retention-explorer"])
99+
97100
# NOTE: All other routes should be in /api/* within the /api folder. Routes outside of /api are not exposed in k8s
98101
@app.get("/")
99102
async def root():

0 commit comments

Comments
 (0)