Skip to content

Commit d26c0c6

Browse files
committed
fix: mypy typing
1 parent acfe3dc commit d26c0c6

File tree

6 files changed

+34
-18
lines changed

6 files changed

+34
-18
lines changed

codegen-on-oss/codegen_on_oss/cli.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def cli():
4949
)
5050
def run_one(
5151
url: str,
52-
cache_dir: str = cachedir,
53-
output_path: Path = "metrics.csv",
52+
cache_dir: str | Path = str(cachedir),
53+
output_path: str = "metrics.csv",
5454
commit_hash: str | None = None,
55-
error_output_path: Path = cachedir / "errors.log",
55+
error_output_path: Path = str(cachedir / "errors.log"),
5656
debug: bool = False,
5757
):
5858
"""
@@ -69,7 +69,7 @@ def run_one(
6969
@cli.command()
7070
@click.option(
7171
"--source",
72-
type=click.Choice(all_sources.keys()),
72+
type=click.Choice(list(all_sources.keys())),
7373
default="csv",
7474
)
7575
@click.option(
@@ -97,9 +97,9 @@ def run_one(
9797
)
9898
def run(
9999
source: str,
100-
output_path: Path,
101-
error_output_path: Path,
102-
cache_dir: Path,
100+
output_path: str,
101+
error_output_path: str,
102+
cache_dir: str,
103103
debug: bool,
104104
):
105105
"""

codegen-on-oss/codegen_on_oss/metrics.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,14 @@
33
import os
44
import time
55
from contextlib import contextmanager
6-
from typing import Any
6+
from pathlib import Path
7+
from typing import TYPE_CHECKING, Any
78

89
import psutil
9-
from loguru import _Logger as LoggerType
10+
11+
if TYPE_CHECKING:
12+
# Logger only available in type checking context.
13+
from loguru import Logger # type: ignore[attr-defined]
1014

1115

1216
class MetricsProfiler:
@@ -30,7 +34,7 @@ def __init__(self, output_path: str):
3034
self.output_path = output_path
3135

3236
@contextmanager
33-
def start_profiler(self, name: str, revision: str, logger: LoggerType):
37+
def start_profiler(self, name: str, revision: str, logger: "Logger"):
3438
"""
3539
Starts a new profiling session for a given profile name.
3640
Returns a MetricsProfile instance that you can use to mark measurements.
@@ -55,14 +59,17 @@ class MetricsProfile:
5559
to a CSV file.
5660
"""
5761

62+
if TYPE_CHECKING:
63+
logger: "Logger"
64+
measurements: list[dict[str, Any]]
65+
5866
def __init__(
59-
self, name: str, revision: str, output_path: str | None, logger: LoggerType
67+
self, name: str, revision: str, output_path: str | None, logger: "Logger"
6068
):
6169
self.name = name
6270
self.revision = revision
6371
self.output_path = output_path
6472
self.logger = logger
65-
self.measurements = []
6673

6774
# Capture initial metrics.
6875
self.start_time = time.perf_counter()
@@ -152,6 +159,8 @@ def _write_csv(self, measurement: dict[str, Any]):
152159
return
153160

154161
file_exists = os.path.isfile(self.output_path)
162+
if not file_exists:
163+
Path(self.output_path).parent.mkdir(parents=True, exist_ok=True)
155164

156165
with open(self.output_path, mode="a", newline="") as csv_file:
157166
writer = csv.DictWriter(csv_file, fieldnames=measurement.keys())

codegen-on-oss/codegen_on_oss/parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, *args, projects: "list[ProjectConfig]", **kwargs):
7676
raise ParseRunError(validation_status)
7777

7878
ProfiledCodebase.from_repo(
79-
repo_name, tmp_dir=self.repo_dir, commit=commit_hash
79+
repo_name, tmp_dir=str(self.repo_dir.absolute()), commit=commit_hash
8080
)
8181

8282
def gc(self):

codegen-on-oss/codegen_on_oss/sources/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class RepoSource(Generic[SettingsType]):
3333
"""
3434

3535
source_type: ClassVar[str]
36-
settings_cls: ClassVar[type[SettingsType]]
36+
settings_cls: ClassVar[type[SourceSettings]]
3737

3838
if TYPE_CHECKING:
3939
settings: SourceSettings

codegen-on-oss/codegen_on_oss/sources/github_source.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Iterator
2-
from typing import ClassVar, Literal
2+
from typing import TYPE_CHECKING, ClassVar, Literal
33

44
from github import Auth, Github # nosemgrep
55

@@ -26,9 +26,16 @@ class GithubSettings(SourceSettings, env_prefix="GITHUB_"):
2626

2727

2828
class GithubSource(RepoSource[GithubSettings]):
29+
"""
30+
Source for Github repositories via Github Search API
31+
"""
32+
33+
if TYPE_CHECKING:
34+
github_client: Github
35+
settings: GithubSettings
36+
2937
source_type: ClassVar[str] = "github"
30-
settings_cls: ClassVar[type[SourceSettings]] = GithubSettings
31-
github_client: Github
38+
settings_cls: ClassVar[type[GithubSettings]] = GithubSettings
3239

3340
def __init__(self) -> None:
3441
super().__init__()

codegen-on-oss/modal_run.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def parse_repo_on_modal(
6666
except Exception as e:
6767
logger.exception(f"Error parsing repository {repo_url}: {e}")
6868

69-
BucketStore(bucket_name=os.getenv("BUCKET_NAME")).upload_run(
69+
BucketStore(bucket_name=os.getenv("BUCKET_NAME", "codegen-oss-parse")).upload_run(
7070
repo_source,
7171
log_output_path,
7272
metrics_output_path,

0 commit comments

Comments
 (0)