Skip to content
This repository was archived by the owner on Feb 20, 2025. It is now read-only.

Commit a66bcac

Browse files
committed
cleanup: pydantic for parsing claims
1 parent ed35943 commit a66bcac

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

hatchet_sdk/token.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,31 @@
11
import base64
2-
import json
3-
from typing import Any
42

3+
from pydantic import BaseModel
54

6-
def get_tenant_id_from_jwt(token: str) -> str:
7-
claims = extract_claims_from_jwt(token)
85

9-
return claims.get("sub")
6+
class Claims(BaseModel):
7+
sub: str
8+
server_url: str
9+
grpc_broadcast_address: str
10+
11+
12+
def get_tenant_id_from_jwt(token: str) -> str:
13+
return extract_claims_from_jwt(token).sub
1014

1115

1216
def get_addresses_from_jwt(token: str) -> tuple[str, str]:
1317
claims = extract_claims_from_jwt(token)
1418

15-
return claims["server_url"], claims["grpc_broadcast_address"]
19+
return claims.server_url, claims.grpc_broadcast_address
1620

1721

18-
def extract_claims_from_jwt(token: str) -> dict[str, Any]:
22+
def extract_claims_from_jwt(token: str) -> Claims:
1923
parts = token.split(".")
2024
if len(parts) != 3:
2125
raise ValueError("Invalid token format")
2226

2327
claims_part = parts[1]
2428
claims_part += "=" * ((4 - len(claims_part) % 4) % 4) # Padding for base64 decoding
2529
claims_data = base64.urlsafe_b64decode(claims_part)
26-
claims = json.loads(claims_data)
2730

28-
return claims
31+
return Claims.model_validate_json(claims_data)

0 commit comments

Comments
 (0)