Skip to content

Commit 2f4d486

Browse files
committed
fix test and filter url kwargs by none
1 parent 1689cd5 commit 2f4d486

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

patchwork/step.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ def __init__(self, inputs: DataPoint):
7373
self.run = self.__managed_run
7474

7575
def __init_subclass__(cls, input_class: Optional[Type] = None, output_class: Optional[Type] = None, **kwargs):
76+
if cls.__name__ == "PreparePR":
77+
print(1)
7678
input_class = input_class or getattr(cls, "input_class", None)
7779
if input_class is not None and not is_typeddict(input_class):
7880
input_class = None
@@ -81,14 +83,14 @@ def __init_subclass__(cls, input_class: Optional[Type] = None, output_class: Opt
8183
if output_class is not None and not is_typeddict(output_class):
8284
output_class = None
8385

84-
cls.__input_class = input_class
85-
cls.__output_class = output_class
86+
cls._input_class = input_class
87+
cls._output_class = output_class
8688

8789
@classmethod
8890
def find_missing_inputs(cls, inputs: DataPoint) -> Collection:
89-
if getattr(cls, "__input_class", None) is None:
91+
if getattr(cls, "_input_class", None) is None:
9092
return []
91-
return cls.__input_class.__required_keys__.difference(inputs.keys())
93+
return cls._input_class.__required_keys__.difference(inputs.keys())
9294

9395
def __managed_run(self, *args, **kwargs) -> Any:
9496
self.debug(self.inputs)

patchwork/steps/CallSQL/CallSQL.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,16 @@ def __build_engine(self, inputs: dict):
2020
driver = inputs.get("db_driver")
2121
dialect_plus_driver = f"{dialect}+{driver}" if driver is not None else dialect
2222
kwargs = dict(
23-
username=inputs["db_username"],
23+
username=inputs.get("db_username"),
2424
host=inputs.get("db_host", "localhost"),
2525
port=inputs.get("db_port", 5432),
26+
password=inputs.get("db_password"),
27+
database=inputs.get("db_database"),
28+
query=inputs.get("db_params"),
2629
)
27-
if inputs.get("db_password") is not None:
28-
kwargs["password"] = inputs.get("db_password")
29-
if inputs.get("db_name") is not None:
30-
kwargs["database"] = inputs.get("db_name")
31-
if inputs.get("db_params") is not None:
32-
kwargs["query"] = inputs.get("db_params")
3330
connection_url = URL.create(
3431
dialect_plus_driver,
35-
**kwargs,
32+
**{k: v for k, v in kwargs.items() if v is not None},
3633
)
3734

3835
connect_args = None

0 commit comments

Comments
 (0)