|
1 | 1 | # |
2 | 2 | # Copyright (c) 2023 Airbyte, Inc., all rights reserved. |
3 | 3 | # |
4 | | - |
| 4 | +import importlib |
5 | 5 | import logging |
| 6 | +import sys |
| 7 | +from dataclasses import dataclass |
6 | 8 | from enum import Flag, auto |
| 9 | +from functools import lru_cache |
| 10 | +from pathlib import Path |
| 11 | +from tempfile import TemporaryDirectory |
7 | 12 | from typing import Any, Callable, Dict, Generator, Mapping, Optional, cast |
8 | 13 |
|
| 14 | +from datamodel_code_generator import DataModelType, InputFileType, generate |
9 | 15 | from jsonschema import Draft7Validator, RefResolver, ValidationError, Validator, validators |
| 16 | +from pydantic import BaseModel |
10 | 17 |
|
11 | 18 | MAX_NESTING_DEPTH = 3 |
12 | 19 | json_to_python_simple = { |
@@ -275,3 +282,33 @@ def _get_type_structure(self, input_data: Any, current_depth: int = 0) -> Any: |
275 | 282 |
|
276 | 283 | else: |
277 | 284 | return python_to_json[type(input_data)] |
| 285 | + |
| 286 | + |
@dataclass(frozen=True)
class PydanticTypeTransformer:
    """Normalize records by passing them through a Pydantic model generated from a JSON schema.

    The model source is produced by ``datamodel_code_generator.generate``, written to a
    temporary ``models.py``, imported dynamically, and its ``NormalizationModel`` class is
    used to validate and dump each record.

    The class has no fields and is frozen, so all instances compare equal and hash alike;
    the ``lru_cache`` on :meth:`stream_model` is therefore effectively shared by every
    instance.
    """

    @staticmethod
    def _serialize_schema(schema: Mapping[str, Any]) -> str:
        """Return a canonical text form of ``schema`` for generation and cache lookup.

        JSON with sorted keys is preferred: it is unambiguous for the generator and maps
        dicts that differ only in key order to the same cache key. If the schema cannot
        be JSON-encoded, fall back to ``str(schema)`` (the previous behaviour).
        """
        import json

        try:
            return json.dumps(schema, sort_keys=True)
        except (TypeError, ValueError):
            return str(schema)

    @staticmethod
    def _module_name(schema_text: str) -> str:
        """Return a per-schema module name so generated modules never overwrite each other."""
        import hashlib

        digest = hashlib.sha256(schema_text.encode("utf-8", "backslashreplace")).hexdigest()[:16]
        return f"_normalization_models_{digest}"

    @lru_cache
    def stream_model(self, json_schema: str) -> "type[BaseModel]":
        """Generate, import, and return the ``NormalizationModel`` class for a schema.

        Results are memoized per schema text, so generation and import run once per
        distinct schema.

        Args:
            json_schema: The schema as text (must be hashable for the cache).

        Returns:
            The generated ``NormalizationModel`` class (a class, not an instance).

        Raises:
            ImportError: If no import spec/loader can be built for the generated file.
            AttributeError: If the generated module has no ``NormalizationModel``.
        """
        # ``import importlib`` alone does not guarantee the ``util`` submodule is loaded.
        import importlib.util

        schema_text = str(json_schema)
        module_name = self._module_name(schema_text)

        with TemporaryDirectory() as temporary_directory_name:
            output = Path(temporary_directory_name) / "models.py"
            generate(
                schema_text,
                input_file_type=InputFileType.Auto,
                input_filename="example.json",
                output=output,
                class_name="NormalizationModel",
                output_model_type=DataModelType.PydanticV2BaseModel,
            )

            # Load the generated models.py dynamically, while the file still exists.
            spec = importlib.util.spec_from_file_location(module_name, output)
            if spec is None or spec.loader is None:
                raise ImportError(f"Could not load generated models from {output}")
            module = importlib.util.module_from_spec(spec)

            # Register under a unique name rather than the generic "models", which would
            # shadow any real module of that name and be overwritten by every new schema.
            sys.modules[module_name] = module
            try:
                spec.loader.exec_module(module)
            except BaseException:
                # Do not leave a half-initialised module importable by name.
                sys.modules.pop(module_name, None)
                raise

        return getattr(module, "NormalizationModel")

    def transform(self, record: Dict[str, Any], schema: Mapping[str, Any]) -> None:
        """Validate ``record`` against ``schema`` and update it in place with the dumped values.

        Args:
            record: The record to normalize; mutated via ``dict.update``.
            schema: The JSON schema describing the record.
        """
        model = self.stream_model(self._serialize_schema(schema))
        record.update(model(**record).model_dump())
0 commit comments