Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 146 additions & 50 deletions src/twyn/core/config_handler.py
Original file line number Diff line number Diff line change
@@ -1,54 +1,147 @@
import logging
from dataclasses import asdict, dataclass
from enum import Enum
from os import getcwd
from pathlib import Path
from typing import Any, Optional

from tomlkit import dumps, parse
from tomlkit import TOMLDocument, dumps, parse, table

from twyn.base.constants import DEFAULT_PROJECT_TOML_FILE
from twyn.base.constants import (
DEFAULT_PROJECT_TOML_FILE,
DEFAULT_SELECTOR_METHOD,
AvailableLoggingLevels,
)
from twyn.core.exceptions import (
AllowlistPackageAlreadyExistsError,
AllowlistPackageDoesNotExistError,
TOMLError,
)

logger = logging.getLogger()


@dataclass(frozen=True)
class TwynConfiguration:
    """Immutable, ready-to-use configuration for a Twyn run.

    Every value has already been resolved from the CLI flags, the
    ``[tool.twyn]`` table and the built-in defaults, so consumers never
    have to handle a missing selector method or logging level.
    """

    # Dependency file to analyse; None when neither the CLI nor the config names one.
    dependency_file: Optional[str]
    # Name of the selector method to use.
    selector_method: str
    # Effective logging level for the run.
    logging_level: AvailableLoggingLevels
    # Package names that are exempt from the typosquat check.
    allowlist: set[str]


@dataclass(frozen=True)
class ReadTwynConfiguration:
    """Immutable snapshot of the Twyn settings as written by the user.

    Unlike the resolved configuration, no defaults are applied here:
    any option the user did not set is simply ``None``.
    """

    # Dependency file set in the config, if any.
    dependency_file: Optional[str]
    # Selector method set in the config, if any.
    selector_method: Optional[str]
    # Logging level set in the config, if any.
    # NOTE(review): _get_read_config passes the raw TOML value through here - confirm the intended type.
    logging_level: Optional[AvailableLoggingLevels]
    # Allowlisted package names; empty when the config has no allowlist.
    allowlist: set[str]


class ConfigHandler:
"""Read certain values into a central ConfigHandler object."""
"""Manage reading and writing configurations for Twyn."""

def __init__(self, file_path: Optional[str] = None, enforce_file: bool = True):
    """Remember where the configuration lives; the file is read on demand.

    Args:
        file_path: Path of the TOML file. Falls back to
            ``DEFAULT_PROJECT_TOML_FILE`` when falsy.
        enforce_file: When False, ``_read_toml`` treats a missing default
            file as an empty document instead of raising.
    """
    self._file_path = file_path or DEFAULT_PROJECT_TOML_FILE
    self._enforce_file = enforce_file
def resolve_config(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been thinking for a while, could it be that this class' responsibilities are too broad?

For instance, right now this class both provides the mechanisms to create config objects and interacts with the config file. Would it make sense to split it?

Say, we have something that acts as a factory, just creating config objects, and then we have the config objects themselves, which are the ones that read from and write to the file.

Maybe this is not the PR to make these changes, as some methods are mandatory both when creating the object and when interacting with the file (like _read_toml) but this could be adapted whenever we move this class' implementation to use the FileHandler.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall agree that it would not hurt to separate the reading and the writing. Though I imagine they will need mechanisms to share stuff.

Having said that, this class has 3 public methods, and 150 lines or so. It's not too bad. I am not a fan of the config objects writing themselves though.

I do not really know how useful the FileHandler will be for this, but I guess I'll see in the next PRs.

self,
selector_method: Optional[str] = None,
dependency_file: Optional[str] = None,
verbosity: AvailableLoggingLevels = AvailableLoggingLevels.none,
) -> TwynConfiguration:
"""Resolve the configuration for Twyn.

Given the cli flags it will return a fully resolved configuration for Twyn,
giving precedence to cli flags vs values set in the config files.

It will also handle default values, when appropriate.
"""
toml = self._read_toml()
twyn_config_data = self._get_twyn_data_from_toml(toml)

# Resolve the configuration so that it is ready to be used by Twyn,
# handling defaults etc
return TwynConfiguration(
dependency_file=dependency_file or twyn_config_data.get("dependency_file"),
selector_method=selector_method or twyn_config_data.get("selector_method", DEFAULT_SELECTOR_METHOD),
logging_level=_get_logging_level(verbosity, twyn_config_data.get("logging_level")),
allowlist=set(twyn_config_data.get("allowlist", set())),
)

def add_package_to_allowlist(self, package_name: str) -> None:
    """Add a package to the allowlist configuration in the toml file.

    The file is re-read on every call so the check and the write operate
    on its current content rather than on state cached at construction.

    Args:
        package_name: Name of the package to allowlist.

    Raises:
        AllowlistPackageAlreadyExistsError: If the package is already
            present in the allowlist.
    """
    toml = self._read_toml()
    config = self._get_read_config(toml)
    if package_name in config.allowlist:
        raise AllowlistPackageAlreadyExistsError(package_name)

    # The configuration is frozen, so build a copy with the extended allowlist.
    new_config = ReadTwynConfiguration(
        dependency_file=config.dependency_file,
        selector_method=config.selector_method,
        logging_level=config.logging_level,
        allowlist=config.allowlist | {package_name},
    )
    self._write_config(toml, new_config)
    logger.info(f"Package '{package_name}' successfully added to allowlist")

def remove_package_from_allowlist(self, package_name: str) -> None:
    """Remove a package from the allowlist configuration in the toml file.

    The file is re-read on every call so the check and the write operate
    on its current content rather than on state cached at construction.

    Args:
        package_name: Name of the package to drop from the allowlist.

    Raises:
        AllowlistPackageDoesNotExistError: If the package is not present
            in the allowlist.
    """
    toml = self._read_toml()
    config = self._get_read_config(toml)
    if package_name not in config.allowlist:
        raise AllowlistPackageDoesNotExistError(package_name)

    # The configuration is frozen, so build a copy with the reduced allowlist.
    new_config = ReadTwynConfiguration(
        dependency_file=config.dependency_file,
        selector_method=config.selector_method,
        logging_level=config.logging_level,
        allowlist=config.allowlist - {package_name},
    )
    self._write_config(toml, new_config)
    logger.info(f"Package '{package_name}' successfully removed from allowlist")

def _get_read_config(self, toml: TOMLDocument) -> ReadTwynConfiguration:
    """Build the user-set Twyn configuration out of a toml document.

    Options missing from the document come back as ``None``; a missing
    allowlist becomes an empty set.
    """
    section = self._get_twyn_data_from_toml(toml)
    optional_values = {
        option: section.get(option)
        for option in ("dependency_file", "selector_method", "logging_level")
    }
    allowed_packages = set(section.get("allowlist", set()))
    return ReadTwynConfiguration(allowlist=allowed_packages, **optional_values)

def _write_config(self, toml: TOMLDocument, config: ReadTwynConfiguration) -> None:
    """Store *config* under ``tool.twyn`` in *toml* and write the file.

    Unset options are left out of the toml file rather than written as
    nulls. The ``twyn`` table is replaced as a whole with the serialized
    configuration.
    """
    serialized = asdict(config, dict_factory=_serialize_config)

    # Make sure the parent tables exist before assigning into them.
    if "tool" not in toml:
        toml.add("tool", table())
    if "twyn" not in toml["tool"]:  # type: ignore[operator]
        toml["tool"]["twyn"] = {}  # type: ignore[index]

    toml["tool"]["twyn"] = serialized  # type: ignore[index]
    self._write_toml(toml)

def _read_toml(self) -> TOMLDocument:
    """Read and parse the configured toml file.

    Returns:
        The parsed document, or an empty ``TOMLDocument`` when the file
        is missing, it is the default project file, and ``enforce_file``
        is False.

    Raises:
        TOMLError: If the file cannot be located and the fallback above
            does not apply.
    """
    try:
        fp = self._get_toml_file_pointer()
    except FileNotFoundError:
        # Only the default file is optional; an explicitly given path must exist.
        if not self._enforce_file and self._file_path == DEFAULT_PROJECT_TOML_FILE:
            return TOMLDocument()
        raise TOMLError(f"Error reading toml from {self._file_path}") from None

    # TOML documents are UTF-8 by specification, so do not rely on the
    # platform default encoding. A single parse is enough: dumping and
    # re-parsing the document yields the same content.
    with open(fp, "r", encoding="utf-8") as f:
        return parse(f.read())

def _get_twyn_data(self) -> dict[str, Any]:
return self._toml.get("tool", {}).get("twyn", {})
def _get_twyn_data_from_toml(self, toml: TOMLDocument) -> dict[str, Any]:
    """Return the ``tool.twyn`` table of *toml*, or an empty dict if absent."""
    tool_section = toml.get("tool", {})
    return tool_section.get("twyn", {})

def _get_toml_file_pointer(self) -> Path:
"""Create a path for the toml file with the format <current working directory>/self.file_path."""
Expand All @@ -59,32 +152,35 @@ def _get_toml_file_pointer(self) -> Path:

return fp

def _write_toml(self) -> None:
def _write_toml(self, toml: TOMLDocument) -> None:
    """Serialize *toml* and write it to the configured file.

    The document is serialized before the file is opened for writing:
    opening in ``"w"`` mode truncates the file, so a serialization error
    raised afterwards would leave the user with an empty file.

    Raises:
        TOMLError: If serializing the document or writing it fails.
    """
    fp = self._get_toml_file_pointer()

    try:
        serialized = dumps(toml)
    except Exception:
        logger.exception("Error writing toml file")
        raise TOMLError(f"Error writing toml to {self._file_path}") from None

    # TOML documents are UTF-8 by specification.
    with open(fp, "w", encoding="utf-8") as f:
        try:
            f.write(serialized)
        except Exception:
            logger.exception("Error writing toml file")
            raise TOMLError(f"Error writing toml to {self._file_path}") from None


def _get_logging_level(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should this be inside the class?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it does not use self at all, I just put it here.

cli_verbosity: AvailableLoggingLevels,
config_logging_level: Optional[str],
) -> AvailableLoggingLevels:
"""Return the appropriate logging level, considering that the one in config has less priority than the one passed directly."""
if cli_verbosity is AvailableLoggingLevels.none:
if config_logging_level:
return AvailableLoggingLevels[config_logging_level.lower()]
else:
# default logging level
return AvailableLoggingLevels.warning
return cli_verbosity


def _serialize_config(x):
def _value_to_for_config(v):
if isinstance(v, Enum):
return v.name
elif isinstance(v, set):
return list(v)
return v

return {k: _value_to_for_config(v) for (k, v) in x if v is not None and v != set()}
6 changes: 5 additions & 1 deletion src/twyn/core/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from twyn.base.exceptions import TwynError


class TOMLError(TwynError):
    """Error raised when the toml configuration cannot be read or written."""

    def __init__(self, message: str) -> None:
        super().__init__(message)


class AllowlistError(TwynError):
def __init__(self, package_name: str = ""):
message = self.message.format(package_name) if package_name else self.message
Expand All @@ -13,4 +18,3 @@ class AllowlistPackageAlreadyExistsError(AllowlistError):

class AllowlistPackageDoesNotExistError(AllowlistError):
    """Raised when the package to remove is not in the allowlist."""

    message = "Package '{}' is not present in the allowlist. Skipping."

60 changes: 10 additions & 50 deletions src/twyn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from rich.progress import track

from twyn.base.constants import (
DEFAULT_SELECTOR_METHOD,
SELECTOR_METHOD_MAPPING,
AvailableLoggingLevels,
)
Expand Down Expand Up @@ -36,17 +35,20 @@ def check_dependencies(
verbosity: AvailableLoggingLevels = AvailableLoggingLevels.none,
) -> bool:
"""Check if dependencies could be typosquats."""
config = get_configuration(config_file, dependency_file, selector_method, verbosity)
config = ConfigHandler(file_path=config_file, enforce_file=False).resolve_config(
verbosity=verbosity, selector_method=selector_method, dependency_file=dependency_file
)
_set_logging_level(config.logging_level)

trusted_packages = TrustedPackages(
names=TopPyPiReference().get_packages(),
algorithm=EditDistance(),
selector=get_candidate_selector(config.selector_method),
threshold_class=SimilarityThreshold,
)
normalized_allowlist_packages = normalize_packages(config.allowlist)
normalized_allowlist_packages = _normalize_packages(config.allowlist)
dependencies = dependencies_cli if dependencies_cli else get_parsed_dependencies_from_file(config.dependency_file)
normalized_dependencies = normalize_packages(dependencies)
normalized_dependencies = _normalize_packages(dependencies)

errors: list[TyposquatCheckResult] = []
for dependency in track(normalized_dependencies, description="Processing..."):
Expand All @@ -67,56 +69,14 @@ def check_dependencies(
return bool(errors)


def get_configuration(
config_file: Optional[str],
dependency_file: Optional[str],
selector_method: Optional[str],
verbosity: AvailableLoggingLevels,
) -> ConfigHandler:
"""Read configuration and return configuration object.

Selects the appropriate values based on priorities between those in the file, and those directly provided.
"""
# Read config from file
config = ConfigHandler(file_path=config_file, enforce_file=False)

# Set logging level
config.logging_level = get_logging_level(
logging_level=verbosity,
config_logging_level=config.logging_level,
)
set_logging_level(config.logging_level)
# Set selector method according to priority order
config.selector_method = selector_method or config.selector_method or DEFAULT_SELECTOR_METHOD

# Set dependency file according to priority order
config.dependency_file = dependency_file or config.dependency_file or None
return config


def get_logging_level(
logging_level: AvailableLoggingLevels,
config_logging_level: Optional[str],
) -> AvailableLoggingLevels:
"""Return the appropriate logging level, considering that the one in config has less priority than the one passed directly."""
if logging_level is AvailableLoggingLevels.none:
if config_logging_level:
return AvailableLoggingLevels[config_logging_level.lower()]
else:
# default logging level
return AvailableLoggingLevels.warning

return logging_level


def set_logging_level(logging_level: AvailableLoggingLevels) -> None:
def _set_logging_level(logging_level: AvailableLoggingLevels) -> None:
    """Apply *logging_level* to the module logger and report it at debug level."""
    level_value = logging_level.value
    logger.setLevel(level_value)
    logger.debug(f"Logging level: {level_value}")


def get_candidate_selector(selector_method_name: Optional[str]) -> AbstractSelector:
def get_candidate_selector(selector_method_name: str) -> AbstractSelector:
    """Instantiate the selector registered under *selector_method_name*.

    The name is expected to be already resolved (defaults applied) by the
    caller, so no fallback is performed here.

    Args:
        selector_method_name: Key into ``SELECTOR_METHOD_MAPPING``.

    Returns:
        A new instance of the mapped selector class.
    """
    logger.debug(f"Selector method received {selector_method_name}")
    selector_method = SELECTOR_METHOD_MAPPING[selector_method_name]()
    logger.debug(f"Instantiated {selector_method} selector")
    return selector_method
Expand All @@ -129,6 +89,6 @@ def get_parsed_dependencies_from_file(dependency_file: Optional[str] = None) ->
return dependencies


def normalize_packages(packages: set[str]) -> set[str]:
def _normalize_packages(packages: set[str]) -> set[str]:
"""Normalize dependency names according to PyPi https://packaging.python.org/en/latest/specifications/name-normalization/."""
return {re.sub(r"[-_.]+", "-", name).lower() for name in packages}
Loading
Loading