Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions kedro/validation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
"""Kedro validation framework."""

from .exceptions import ModelInstantiationError, ValidationError
from .source_filters import ParameterSourceFilter, SourceFilter
from .utils import is_pydantic_class, is_pydantic_model

__all__ = [
"ModelInstantiationError",
"ParameterSourceFilter",
"SourceFilter",
"ValidationError",
"is_pydantic_class",
"is_pydantic_model",
Expand Down
37 changes: 37 additions & 0 deletions kedro/validation/source_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Source filters for type extraction from different source types."""

from __future__ import annotations

from abc import ABC, abstractmethod


class SourceFilter(ABC):
"""Abstract base class for source-specific filtering and key extraction."""

@abstractmethod
def should_process(self, source_name: str) -> bool:
"""Determine if this filter should process the given source name."""

@abstractmethod
def extract_key(self, source_name: str) -> str:
"""Extract the key from the source name."""

@abstractmethod
def get_log_message(self, key: str, type_name: str) -> str:
"""Generate appropriate log message for this source type."""


class ParameterSourceFilter(SourceFilter):
"""Filter for parameter sources (``params:*``)."""

def should_process(self, source_name: str) -> bool:
"""Check if source is a parameter source."""
return isinstance(source_name, str) and source_name.startswith("params:")

def extract_key(self, source_name: str) -> str:
"""Extract parameter key from ``params:key`` format."""
return source_name.split(":", 1)[1]

def get_log_message(self, key: str, type_name: str) -> str:
"""Generate parameter-specific log message."""
return f"Found parameter requirement: {key} -> {type_name}"
9 changes: 9 additions & 0 deletions tests/validation/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@

import dataclasses

import pytest

from kedro.validation.source_filters import ParameterSourceFilter


@dataclasses.dataclass
class SampleDataclass:
Expand All @@ -21,3 +25,8 @@ class SamplePydanticModel(BaseModel):
PYDANTIC_AVAILABLE = True
except ImportError:
PYDANTIC_AVAILABLE = False


@pytest.fixture
def source_filter():
return ParameterSourceFilter()
28 changes: 28 additions & 0 deletions tests/validation/test_source_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Tests for kedro.validation.source_filters."""

from __future__ import annotations


class TestParameterSourceFilter:
def test_should_process_params_prefix(self, source_filter):
assert source_filter.should_process("params:model_options") is True

def test_should_process_non_params(self, source_filter):
assert source_filter.should_process("companies") is False

def test_should_process_non_string(self, source_filter):
assert source_filter.should_process(123) is False

def test_extract_key(self, source_filter):
assert source_filter.extract_key("params:model_options") == "model_options"

def test_extract_key_nested(self, source_filter):
assert (
source_filter.extract_key("params:model_options.test_size")
== "model_options.test_size"
)

def test_get_log_message(self, source_filter):
msg = source_filter.get_log_message("model_options", "ModelOptions")
assert "model_options" in msg
assert "ModelOptions" in msg