Skip to content

Commit 0a20c28

Browse files
committed
[DOP-25452] Allow creating HDFS and Hive connections without credentials
1 parent 59ad3c6 commit 0a20c28

27 files changed

+713
-142
lines changed

syncmaster/db/repositories/credentials_repository.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@
44

55
from typing import TYPE_CHECKING, NoReturn
66

7-
from sqlalchemy import ScalarResult, insert, select
8-
from sqlalchemy.exc import DBAPIError, IntegrityError, NoResultFound
7+
from sqlalchemy import ScalarResult, delete, insert, select
8+
from sqlalchemy.exc import DBAPIError, IntegrityError
99
from sqlalchemy.ext.asyncio import AsyncSession
1010

1111
from syncmaster.db.models import AuthData
1212
from syncmaster.db.repositories.base import Repository
1313
from syncmaster.db.repositories.utils import decrypt_auth_data, encrypt_auth_data
1414
from syncmaster.exceptions import SyncmasterError
15-
from syncmaster.exceptions.credentials import AuthDataNotFoundError
1615

1716
if TYPE_CHECKING:
1817
from syncmaster.scheduler.settings import SchedulerAppSettings
@@ -33,13 +32,13 @@ def __init__(
3332
async def read(
3433
self,
3534
connection_id: int,
36-
) -> dict:
35+
) -> dict | None:
3736
query = select(AuthData).where(AuthData.connection_id == connection_id)
38-
try:
39-
result: ScalarResult[AuthData] = await self._session.scalars(query)
40-
return decrypt_auth_data(result.one().value, settings=self._settings)
41-
except NoResultFound as e:
42-
raise AuthDataNotFoundError(f"Connection id = {connection_id}") from e
37+
result: ScalarResult[AuthData] = await self._session.scalars(query)
38+
result_row = result.one_or_none()
39+
if not result_row:
40+
return None
41+
return decrypt_auth_data(result_row.value, settings=self._settings)
4342

4443
async def read_bulk(
4544
self,
@@ -79,5 +78,17 @@ async def update(
7978
except IntegrityError as e:
8079
self._raise_error(e)
8180

81+
async def delete(
82+
self,
83+
connection_id: int,
84+
) -> AuthData:
85+
try:
86+
query = delete(AuthData).where(AuthData.connection_id == connection_id).returning(AuthData)
87+
result = await self._session.scalars(query)
88+
await self._session.flush()
89+
return result.one()
90+
except IntegrityError as e:
91+
self._raise_error(e)
92+
8293
def _raise_error(self, err: DBAPIError) -> NoReturn:
8394
raise SyncmasterError from err

syncmaster/db/repositories/run.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,17 @@ async def read_by_id(self, run_id: int) -> Run:
5252
async def create(
5353
self,
5454
transfer_id: int,
55-
source_creds: dict,
56-
target_creds: dict,
55+
source_auth_data: dict | None,
56+
target_auth_data: dict | None,
5757
type: RunType,
5858
) -> Run:
5959
run = Run()
6060
run.transfer_id = transfer_id
61-
run.transfer_dump = await self.read_full_serialized_transfer(transfer_id, source_creds, target_creds)
61+
run.transfer_dump = await self.read_full_serialized_transfer(
62+
transfer_id,
63+
source_auth_data,
64+
target_auth_data,
65+
)
6266
run.type = type
6367
try:
6468
self._session.add(run)
@@ -84,8 +88,8 @@ async def stop(self, run_id: int) -> Run:
8488
async def read_full_serialized_transfer(
8589
self,
8690
transfer_id: int,
87-
source_creds: dict,
88-
target_creds: dict,
91+
source_auth_data: dict | None,
92+
target_auth_data: dict | None,
8993
) -> dict[str, Any]:
9094
transfer = await self._session.scalars(
9195
select(Transfer)
@@ -116,15 +120,15 @@ async def read_full_serialized_transfer(
116120
name=transfer.source_connection.name,
117121
description=transfer.source_connection.description,
118122
data=transfer.source_connection.data,
119-
auth_data=source_creds["auth_data"],
123+
auth_data=source_auth_data,
120124
),
121125
target_connection=dict(
122126
id=transfer.target_connection.id,
123127
group_id=transfer.target_connection.group_id,
124128
name=transfer.target_connection.name,
125129
description=transfer.target_connection.description,
126130
data=transfer.target_connection.data,
127-
auth_data=target_creds["auth_data"],
131+
auth_data=target_auth_data,
128132
),
129133
)
130134

syncmaster/dto/connections.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,9 +75,9 @@ class HiveConnectionDTO(ConnectionDTO):
7575

7676
@dataclass
7777
class HDFSConnectionDTO(ConnectionDTO):
78-
user: str
79-
password: str
8078
cluster: str
79+
user: str | None = None
80+
password: str | None = None
8181
type: ClassVar[str] = "hdfs"
8282

8383

syncmaster/scheduler/transfer_job_manager.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,25 @@ async def send_job_to_celery(transfer_id: int) -> None: # noqa: WPS602, WPS217
8383
except TransferNotFoundError:
8484
return
8585

86-
credentials_source = await unit_of_work.credentials.read(transfer.source_connection_id)
87-
credentials_target = await unit_of_work.credentials.read(transfer.target_connection_id)
86+
source_auth_data: dict | None = None
87+
source_credentials = await unit_of_work.credentials.read(transfer.source_connection_id)
88+
if source_credentials:
89+
# remove secrets from the dump
90+
source_credentials_filtered = ReadAuthDataSchema.model_validate(source_credentials)
91+
source_auth_data = source_credentials_filtered.auth_data.model_dump()
92+
93+
target_auth_data: dict | None = None
94+
target_credentials = await unit_of_work.credentials.read(transfer.target_connection_id)
95+
if target_credentials:
96+
# remove secrets from the dump
97+
target_credentials_filtered = ReadAuthDataSchema.model_validate(target_credentials)
98+
target_auth_data = target_credentials_filtered.auth_data.model_dump()
8899

89100
async with unit_of_work:
90101
run = await unit_of_work.run.create(
91102
transfer_id=transfer_id,
92-
source_creds=ReadAuthDataSchema(auth_data=credentials_source).model_dump(),
93-
target_creds=ReadAuthDataSchema(auth_data=credentials_target).model_dump(),
103+
source_auth_data=source_auth_data,
104+
target_auth_data=target_auth_data,
94105
type=RunType.SCHEDULED,
95106
)
96107

syncmaster/schemas/v1/connections/hdfs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class CreateHDFSConnectionSchema(CreateConnectionBaseSchema):
3232
"Data required to connect to the HDFS cluster. These are the parameters that are specified in the URL request."
3333
),
3434
)
35-
auth_data: CreateBasicAuthSchema = Field(
35+
auth_data: CreateBasicAuthSchema | None = Field(
3636
description="Credentials for authorization",
3737
)
3838

@@ -44,6 +44,6 @@ class ReadHDFSConnectionSchema(ReadConnectionBaseSchema):
4444

4545

4646
class UpdateHDFSConnectionSchema(CreateHDFSConnectionSchema):
47-
auth_data: UpdateBasicAuthSchema = Field(
47+
auth_data: UpdateBasicAuthSchema | None = Field(
4848
description="Credentials for authorization",
4949
)

syncmaster/schemas/v1/connections/hive.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ class CreateHiveConnectionSchema(CreateConnectionBaseSchema):
3232
"Data required to connect to the database. These are the parameters that are specified in the URL request."
3333
),
3434
)
35-
auth_data: CreateBasicAuthSchema = Field(
35+
auth_data: CreateBasicAuthSchema | None = Field(
3636
description="Credentials for authorization",
3737
)
3838

@@ -44,6 +44,6 @@ class ReadHiveConnectionSchema(ReadConnectionBaseSchema):
4444

4545

4646
class UpdateHiveConnectionSchema(CreateHiveConnectionSchema):
47-
auth_data: UpdateBasicAuthSchema = Field(
47+
auth_data: UpdateBasicAuthSchema | None = Field(
4848
description="Credentials for authorization",
4949
)

syncmaster/server/api/v1/connections.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,11 @@ async def create_connection(
125125
data=connection_data.data.model_dump(),
126126
)
127127

128-
await unit_of_work.credentials.create(
129-
connection_id=connection.id,
130-
data=connection_data.auth_data.model_dump(),
131-
)
128+
if connection_data.auth_data:
129+
await unit_of_work.credentials.create(
130+
connection_id=connection.id,
131+
data=connection_data.auth_data.model_dump(),
132+
)
132133

133134
credentials = await unit_of_work.credentials.read(connection.id)
134135
return TypeAdapter(ReadConnectionSchema).validate_python(
@@ -183,7 +184,7 @@ async def read_connection(
183184

184185

185186
@router.put("/connections/{connection_id}")
186-
async def update_connection( # noqa: WPS217, WPS238
187+
async def update_connection( # noqa: WPS217, WPS238, WPS231
187188
connection_id: int,
188189
connection_data: UpdateConnectionSchema,
189190
current_user: User = Depends(get_user(is_active=True)),
@@ -200,36 +201,49 @@ async def update_connection( # noqa: WPS217, WPS238
200201
if resource_role < Permission.WRITE:
201202
raise ActionNotAllowedError
202203

203-
async with unit_of_work:
204-
existing_connection: Connection = await unit_of_work.connection.read_by_id(connection_id=connection_id)
205-
if connection_data.type != existing_connection.type:
206-
linked_transfers: Sequence[Transfer] = await unit_of_work.transfer.list_by_connection_id(connection_id)
207-
if linked_transfers:
208-
raise ConnectionTypeUpdateError
209-
210-
existing_credentials = await unit_of_work.credentials.read(connection_id=connection_id)
211-
auth_data = connection_data.auth_data.model_dump()
204+
existing_connection: Connection = await unit_of_work.connection.read_by_id(connection_id=connection_id)
205+
if connection_data.type != existing_connection.type:
206+
linked_transfers: Sequence[Transfer] = await unit_of_work.transfer.list_by_connection_id(connection_id)
207+
if linked_transfers:
208+
raise ConnectionTypeUpdateError
209+
210+
existing_credentials = await unit_of_work.credentials.read(connection_id=connection_id)
211+
new_credentials: dict | None = None
212+
if connection_data.auth_data:
213+
new_credentials = connection_data.auth_data.model_dump()
212214
secret_field = connection_data.auth_data.secret_field
215+
if new_credentials[secret_field] is None:
213216

214-
if auth_data[secret_field] is None:
215-
if existing_credentials["type"] != auth_data["type"]:
217+
# We don't return secret_field to client, so default field value means using existing secret
218+
if not existing_credentials or existing_credentials["type"] != new_credentials["type"]:
216219
raise ConnectionAuthDataUpdateError
217220

218-
auth_data[secret_field] = existing_credentials[secret_field]
221+
new_credentials[secret_field] = existing_credentials[secret_field]
219222

223+
async with unit_of_work:
220224
connection = await unit_of_work.connection.update(
221225
connection_id=connection_id,
222226
name=connection_data.name,
223227
type=connection_data.type,
224228
description=connection_data.description,
225229
data=connection_data.data.model_dump(),
226230
)
227-
await unit_of_work.credentials.update(
228-
connection_id=connection_id,
229-
data=auth_data,
230-
)
231231

232-
credentials = await unit_of_work.credentials.read(connection_id)
232+
if existing_credentials and new_credentials:
233+
await unit_of_work.credentials.update(
234+
connection_id=connection_id,
235+
data=new_credentials,
236+
)
237+
elif new_credentials:
238+
await unit_of_work.credentials.create(
239+
connection_id=connection.id,
240+
data=new_credentials,
241+
)
242+
elif existing_credentials:
243+
await unit_of_work.credentials.delete(
244+
connection_id=connection_id,
245+
)
246+
233247
return TypeAdapter(ReadConnectionSchema).validate_python(
234248
{
235249
"id": connection.id,
@@ -238,7 +252,7 @@ async def update_connection( # noqa: WPS217, WPS238
238252
"description": connection.description,
239253
"type": connection.type,
240254
"data": connection.data,
241-
"auth_data": credentials,
255+
"auth_data": new_credentials,
242256
},
243257
)
244258

syncmaster/server/api/v1/runs.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -103,20 +103,25 @@ async def start_run( # noqa: WPS217
103103

104104
# The credentials.read method is used rather than credentials.read_bulk deliberately
105105
# it’s more convenient to transfer credentials in this place
106-
credentials_source = await unit_of_work.credentials.read(
107-
transfer.source_connection_id,
108-
)
109-
credentials_target = await unit_of_work.credentials.read(
110-
transfer.target_connection_id,
111-
)
106+
source_auth_data: dict | None = None
107+
source_credentials = await unit_of_work.credentials.read(transfer.source_connection_id)
108+
if source_credentials:
109+
# remove secrets from the dump
110+
source_credentials_filtered = ReadAuthDataSchema.model_validate(source_credentials)
111+
source_auth_data = source_credentials_filtered.auth_data.model_dump()
112+
113+
target_auth_data: dict | None = None
114+
target_credentials = await unit_of_work.credentials.read(transfer.target_connection_id)
115+
if target_credentials:
116+
# remove secrets from the dump
117+
target_credentials_filtered = ReadAuthDataSchema.model_validate(target_credentials)
118+
target_auth_data = target_credentials_filtered.auth_data.model_dump()
112119

113120
async with unit_of_work:
114121
run = await unit_of_work.run.create(
115122
transfer_id=create_run_data.transfer_id,
116-
# Since fields with credentials may have different names (for example, S3 and Postgres have different names)
117-
# the work of checking fields and removing passwords is delegated to the ReadAuthDataSchema class
118-
source_creds=ReadAuthDataSchema(auth_data=credentials_source).model_dump(),
119-
target_creds=ReadAuthDataSchema(auth_data=credentials_target).model_dump(),
123+
source_auth_data=source_auth_data,
124+
target_auth_data=target_auth_data,
120125
type=RunType.MANUAL,
121126
)
122127

syncmaster/worker/controller.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,9 @@ def __init__(
156156
settings: WorkerAppSettings,
157157
run: Run,
158158
source_connection: Connection,
159-
source_auth_data: dict,
159+
source_auth_data: dict | None,
160160
target_connection: Connection,
161-
target_auth_data: dict,
161+
target_auth_data: dict | None,
162162
):
163163
self.temp_dir = TemporaryDirectory(prefix=f"syncmaster_{run.id}_")
164164

@@ -213,7 +213,7 @@ def perform_transfer(self) -> None:
213213
def get_handler(
214214
self,
215215
connection_data: dict[str, Any],
216-
connection_auth_data: dict,
216+
connection_auth_data: dict | None,
217217
run_data: dict[str, Any],
218218
transfer_id: int,
219219
transfer_params: dict[str, Any],
@@ -222,7 +222,7 @@ def get_handler(
222222
transformations: list[dict],
223223
temp_dir: TemporaryDirectory,
224224
) -> Handler:
225-
connection_data.update(connection_auth_data)
225+
connection_data.update(connection_auth_data or {})
226226
connection_data.pop("type")
227227
handler_type = transfer_params.pop("type", None)
228228

syncmaster/worker/handlers/file/hdfs.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ def connect(self, spark: SparkSession):
2323
spark=spark,
2424
).check()
2525

26-
self.file_connection = HDFS(
27-
cluster=self.connection_dto.cluster,
28-
).check()
26+
if self.connection_dto.user and self.connection_dto.password:
27+
self.file_connection = HDFS(
28+
cluster=self.connection_dto.cluster,
29+
user=self.connection_dto.user,
30+
password=self.connection_dto.password,
31+
).check()
32+
else:
33+
self.file_connection = HDFS(
34+
cluster=self.connection_dto.cluster,
35+
).check()

0 commit comments

Comments (0)