Skip to content
This repository was archived by the owner on Jun 13, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions api/internal/feature/views.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import logging
import pickle
from typing import Any, Dict, List

from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.views import APIView
from shared.django_apps.rollouts.models import FeatureFlag
Expand All @@ -20,15 +22,15 @@ class FeaturesView(APIView):
skip_feature_cache = get_config("setup", "skip_feature_cache", default=False)
timeout = 300

def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
self.redis = get_redis_connection()
super().__init__(*args, **kwargs)

def get_many_from_redis(self, keys):
def get_many_from_redis(self, keys: List) -> Dict[str, Any]:
ret = self.redis.mget(keys)
return {k: pickle.loads(v) for k, v in zip(keys, ret) if v is not None}

def set_many_to_redis(self, data):
def set_many_to_redis(self, data: Dict[str, Any]) -> None:
pipeline = self.redis.pipeline()
pipeline.mset({k: pickle.dumps(v) for k, v in data.items()})

Expand All @@ -38,7 +40,7 @@ def set_many_to_redis(self, data):
pipeline.expire(key, self.timeout)
pipeline.execute()

def post(self, request):
def post(self, request: Request) -> Response:
serializer = FeatureRequestSerializer(data=request.data)
if serializer.is_valid():
flag_evaluations = {}
Expand Down
39 changes: 20 additions & 19 deletions api/internal/owner/serializers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import logging
from datetime import datetime
from typing import Any, Dict

from dateutil.relativedelta import relativedelta
from django.conf import settings
Expand Down Expand Up @@ -37,7 +38,7 @@ class Meta:

read_only_fields = fields

def get_stats(self, obj):
def get_stats(self, obj: Owner) -> str | None:
if obj.cache and "stats" in obj.cache:
return obj.cache["stats"]

Expand All @@ -50,7 +51,7 @@ class StripeLineItemSerializer(serializers.Serializer):
plan_name = serializers.SerializerMethodField()
quantity = serializers.IntegerField()

def get_plan_name(self, line_item):
def get_plan_name(self, line_item: Dict[str, str]) -> str | None:
plan = line_item.get("plan")
if plan:
return plan.get("name")
Expand Down Expand Up @@ -85,7 +86,7 @@ class StripeDiscountSerializer(serializers.Serializer):
duration_in_months = serializers.IntegerField(source="coupon.duration_in_months")
expires = serializers.SerializerMethodField()

def get_expires(self, customer):
def get_expires(self, customer: Dict[str, Dict]) -> int | None:
coupon = customer.get("coupon")
if coupon:
months = coupon.get("duration_in_months")
Expand Down Expand Up @@ -121,7 +122,7 @@ class PlanSerializer(serializers.Serializer):
benefits = serializers.JSONField(read_only=True)
quantity = serializers.IntegerField(required=False)

def validate_value(self, value):
def validate_value(self, value: str) -> str:
current_org = self.context["view"].owner
current_owner = self.context["request"].current_owner

Expand All @@ -140,7 +141,7 @@ def validate_value(self, value):
)
return value

def validate(self, plan):
def validate(self, plan: Dict[str, Any]) -> Dict[str, Any]:
current_org = self.context["view"].owner
if current_org.account:
raise serializers.ValidationError(
Expand Down Expand Up @@ -206,7 +207,7 @@ class StripeScheduledPhaseSerializer(serializers.Serializer):
plan = serializers.SerializerMethodField()
quantity = serializers.SerializerMethodField()

def get_plan(self, phase):
def get_plan(self, phase: Dict[str, Any]) -> str:
plan_id = phase["items"][0]["plan"]
stripe_plan_dict = settings.STRIPE_PLAN_IDS
plan_name = list(stripe_plan_dict.keys())[
Expand All @@ -215,15 +216,15 @@ def get_plan(self, phase):
marketing_plan_name = PAID_PLANS[plan_name].billing_rate
return marketing_plan_name

def get_quantity(self, phase):
def get_quantity(self, phase: Dict[str, Any]) -> int:
return phase["items"][0]["quantity"]


class ScheduleDetailSerializer(serializers.Serializer):
id = serializers.CharField()
scheduled_phase = serializers.SerializerMethodField()

def get_scheduled_phase(self, schedule):
def get_scheduled_phase(self, schedule: Dict[str, Any]) -> Dict[str, Any] | None:
if len(schedule["phases"]) > 1:
return StripeScheduledPhaseSerializer(schedule["phases"][-1]).data
else:
Expand Down Expand Up @@ -291,44 +292,44 @@ class Meta:
"uses_invoice",
)

def _get_billing(self):
def _get_billing(self) -> BillingService:
current_owner = self.context["request"].current_owner
return BillingService(requesting_user=current_owner)

def get_subscription_detail(self, owner):
def get_subscription_detail(self, owner: Owner) -> Dict[str, Any] | None:
subscription_detail = self._get_billing().get_subscription(owner)
if subscription_detail:
return SubscriptionDetailSerializer(subscription_detail).data

def get_schedule_detail(self, owner):
def get_schedule_detail(self, owner: Owner) -> Dict[str, Any] | None:
schedule_detail = self._get_billing().get_schedule(owner)
if schedule_detail:
return ScheduleDetailSerializer(schedule_detail).data

def get_checkout_session_id(self, _):
def get_checkout_session_id(self, _: Any) -> str:
return self.context.get("checkout_session_id")

def get_activated_student_count(self, owner):
def get_activated_student_count(self, owner: Owner) -> int:
if owner.account:
return owner.account.activated_student_count
return owner.activated_student_count

def get_activated_user_count(self, owner):
def get_activated_user_count(self, owner: Owner) -> int:
if owner.account:
return owner.account.activated_user_count
return owner.activated_user_count

def get_delinquent(self, owner):
def get_delinquent(self, owner: Owner) -> bool:
if owner.account:
return owner.account.is_delinquent
return owner.delinquent

def get_uses_invoice(self, owner):
def get_uses_invoice(self, owner: Owner) -> bool:
if owner.account:
return owner.account.invoice_billing.filter(is_active=True).exists()
return owner.uses_invoice

def update(self, instance, validated_data):
def update(self, instance: Owner, validated_data: Dict[str, Any]) -> object:
if "pretty_plan" in validated_data:
desired_plan = validated_data.pop("pretty_plan")
checkout_session_id_or_none = self._get_billing().update_plan(
Expand Down Expand Up @@ -367,7 +368,7 @@ class Meta:
"last_pull_timestamp",
)

def update(self, instance, validated_data):
def update(self, instance: Owner, validated_data: Dict[str, Any]) -> object:
owner = self.context["view"].owner

if "activated" in validated_data:
Expand All @@ -391,7 +392,7 @@ def update(self, instance, validated_data):
# Re-fetch from DB to set activated and admin fields
return self.context["view"].get_object()

def get_last_pull_timestamp(self, obj):
def get_last_pull_timestamp(self, obj: Owner) -> str | None:
# this field comes from an annotation that may not always be applied to the queryset
if hasattr(obj, "last_pull_timestamp"):
return obj.last_pull_timestamp
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def convert_yaml_to_dict(self, yaml_input: str) -> Optional[dict]:
message = f"Error at {str(e.error_location)}: {e.error_message}"
raise ValidationError(message)

def yaml_side_effects(self, old_yaml: dict, new_yaml: dict):
def yaml_side_effects(self, old_yaml: dict | None, new_yaml: dict | None) -> None:
old_yaml_branch = old_yaml and old_yaml.get("codecov", {}).get("branch")
new_yaml_branch = new_yaml and new_yaml.get("codecov", {}).get("branch")

Expand Down
3 changes: 2 additions & 1 deletion codecov_auth/management/commands/set_trial_status_values.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any

from django.core.management.base import BaseCommand, CommandParser
from django.db.models import Q
Expand All @@ -20,7 +21,7 @@ class Command(BaseCommand):
def add_arguments(self, parser: CommandParser) -> None:
parser.add_argument("trial_status_type", type=str)

def handle(self, *args, **options) -> None:
def handle(self, *args: Any, **options: Any) -> None:
trial_status_type = options.get("trial_status_type", {})

# NOT_STARTED
Expand Down
4 changes: 2 additions & 2 deletions core/commands/pull/interactors/fetch_pull_request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from datetime import datetime, timedelta

from shared.django_apps.core.models import Pull
from shared.django_apps.core.models import Pull, Repository

from codecov.commands.base import BaseInteractor
from codecov.db import sync_to_async
Expand All @@ -17,7 +17,7 @@ def _should_sync_pull(self, pull: Pull | None) -> bool:
)

@sync_to_async
def execute(self, repository, id):
def execute(self, repository: Repository, id: int) -> Pull:
pull = repository.pull_requests.filter(pullid=id).first()
if self._should_sync_pull(pull):
TaskService().pulls_sync(repository.repoid, id)
Expand Down
4 changes: 2 additions & 2 deletions core/commands/repository/interactors/erase_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


class EraseRepositoryInteractor(BaseInteractor):
def validate_owner(self, owner: Owner):
def validate_owner(self, owner: Owner) -> None:
if not current_user_part_of_org(self.current_owner, owner):
raise Unauthorized()

Expand All @@ -23,7 +23,7 @@ def validate_owner(self, owner: Owner):
raise Unauthorized()

@sync_to_async
def execute(self, repo_name: str, owner: Owner):
def execute(self, repo_name: str, owner: Owner) -> None:
self.validate_owner(owner)
repo = Repository.objects.filter(author_id=owner.pk, name=repo_name).first()
if not repo:
Expand Down
2 changes: 1 addition & 1 deletion graphql_api/actions/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from timeseries.models import Interval, MeasurementName


def flags_for_repo(repository: Repository, filters: Mapping = None) -> QuerySet:
def flags_for_repo(repository: Repository, filters: Mapping = {}) -> QuerySet:
queryset = RepositoryFlag.objects.filter(
repository=repository,
deleted__isnot=True,
Expand Down
7 changes: 5 additions & 2 deletions graphql_api/types/branch/branch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Optional

from ariadne import ObjectType
from graphql import GraphQLResolveInfo

from core.models import Branch, Commit
from graphql_api.dataloader.commit import CommitLoader
Expand All @@ -9,13 +10,15 @@


@branch_bindable.field("headSha")
def resolve_head_sha(branch: Branch, info) -> str:
def resolve_head_sha(branch: Branch, info: GraphQLResolveInfo) -> str:
head = branch.head
return head


@branch_bindable.field("head")
async def resolve_head_commit(branch: Branch, info) -> Optional[Commit]:
async def resolve_head_commit(
branch: Branch, info: GraphQLResolveInfo
) -> Optional[Commit]:
head = branch.head
if head:
loader = CommitLoader.loader(info, branch.repository_id)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Any, Dict

from ariadne import UnionType
from graphql import GraphQLResolveInfo

from graphql_api.helpers.mutation import (
require_authenticated,
Expand All @@ -9,7 +12,9 @@

@wrap_error_handling_mutation
@require_authenticated
async def resolve_save_okta_config(_, info, input):
async def resolve_save_okta_config(
_: Any, info: GraphQLResolveInfo, input: Dict[str, Any]
) -> None:
command = info.context["executor"].get_command("owner")
return await command.save_okta_config(input)

Expand Down
5 changes: 4 additions & 1 deletion graphs/mixins.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from typing import Any

from django.http import HttpResponse
from rest_framework import status
from rest_framework.request import Request
from rest_framework.response import Response


class GraphBadgeAPIMixin(object):
def get(self, request, *args, **kwargs):
def get(self, request: Request, *args: Any, **kwargs: Any) -> Response:
ext = self.kwargs.get("ext")
if ext not in self.extensions:
return Response(
Expand Down
Loading
Loading