Skip to content

Commit 99ceb04

Browse files
committed
[DOP-29475] Add table name validation
1 parent dc693c3 commit 99ceb04

File tree

5 files changed

+116
-39
lines changed

5 files changed

+116
-39
lines changed

syncmaster/schemas/v1/transfers/db.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
# SPDX-License-Identifier: Apache-2.0
33
from __future__ import annotations
44

5-
from pydantic import BaseModel
5+
import re
6+
from typing import ClassVar
7+
8+
from pydantic import BaseModel, Field, field_validator
69

710
from syncmaster.schemas.v1.connection_types import (
811
CLICKHOUSE_TYPE,
@@ -16,7 +19,16 @@
1619

1720

1821
class DBTransfer(BaseModel):
19-
table_name: str
22+
TABLE_NAME_PATTERN: ClassVar[str] = r"^[\w\d]+\.[\w\d]+$"
23+
table_name: str = Field(description="Table name", json_schema_extra={"pattern": TABLE_NAME_PATTERN})
24+
25+
# make error message more user friendly
26+
@field_validator("table_name", mode="before")
27+
@classmethod
28+
def _table_name_is_qualified(cls, value):
29+
if not re.match(cls.TABLE_NAME_PATTERN, value):
30+
raise ValueError("Table name should be in format myschema.mytable")
31+
return value
2032

2133

2234
class HiveTransferSourceOrTarget(DBTransfer):
@@ -46,6 +58,17 @@ class MySQLTransferSourceOrTarget(DBTransfer):
4658
class IcebergTransferSourceOrTarget(DBTransfer):
4759
type: ICEBERG_TYPE
4860

61+
TABLE_NAME_PATTERN: ClassVar[str] = r"^[\w\d]+(\.[\w\d]+)+$"
62+
table_name: str = Field(description="Table name", json_schema_extra={"pattern": TABLE_NAME_PATTERN})
63+
64+
# make error message more user friendly
65+
@field_validator("table_name", mode="before")
66+
@classmethod
67+
def _table_name_is_qualified(cls, value):
68+
if not re.match(cls.TABLE_NAME_PATTERN, value):
69+
raise ValueError("Table name should be in format myschema.mytable")
70+
return value
71+
4972

5073
DBTransferSourceOrTarget = (
5174
ClickhouseTransferSourceOrTarget

syncmaster/schemas/v1/transfers/file/base.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,10 @@
2626
# At the moment the CreateTransferSourceParams and CreateTransferTargetParams
2727
# classes are identical but may change in the future
2828
class FileTransferSource(BaseModel):
29-
directory_path: str
29+
directory_path: str = Field(
30+
description="Absolute path to directory",
31+
json_schema_extra={"pattern": r"^/[\w\d ]+(/[\w\d ]+)*$"},
32+
)
3033
file_format: SOURCE_FILE_FORMAT = Field(discriminator="type")
3134
options: dict[str, Any] = Field(default_factory=dict)
3235

@@ -41,18 +44,22 @@ def _directory_path_is_valid_path(cls, value):
4144

4245

4346
class FileTransferTarget(BaseModel):
44-
directory_path: str
47+
FILE_NAME_PATTERN: ClassVar[str] = r"^[\w.{}-]+$"
48+
49+
directory_path: str = Field(
50+
description="Absolute path to directory",
51+
json_schema_extra={"pattern": r"^/[\w\d ]+(/[\w\d ]+)*$"},
52+
)
4553
file_format: TARGET_FILE_FORMAT = Field(discriminator="type")
4654
file_name_template: str = Field(
4755
default="{run_created_at}-{index}.{extension}",
4856
description="Template for file naming with required placeholders 'index' and 'extension'",
57+
json_schema_extra={"pattern": FILE_NAME_PATTERN},
4958
)
5059
options: dict[str, Any] = Field(default_factory=dict)
5160

5261
model_config = ConfigDict(arbitrary_types_allowed=True)
5362

54-
FILE_NAME_PATTERN: ClassVar[re.Pattern] = re.compile(r"^[\w.{}-]+$")
55-
5663
@field_validator("directory_path", mode="before")
5764
@classmethod
5865
def _directory_path_is_valid_path(cls, value):
@@ -63,7 +70,8 @@ def _directory_path_is_valid_path(cls, value):
6370
@field_validator("file_name_template")
6471
@classmethod
6572
def _validate_file_name_template(cls, value: str) -> str: # noqa: WPS238
66-
if not cls.FILE_NAME_PATTERN.match(value):
73+
# make error message more user friendly
74+
if not re.match(cls.FILE_NAME_PATTERN, value):
6775
raise ValueError("Template contains invalid characters. Allowed: letters, numbers, '.', '_', '-', '{', '}'")
6876

6977
required_keys = {"index", "extension"}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import pytest
2+
from httpx import AsyncClient
3+
4+
from syncmaster.db.models import Queue
5+
from tests.mocks import MockConnection, MockGroup, UserTestRoles
6+
7+
pytestmark = [pytest.mark.asyncio, pytest.mark.server]
8+
9+
10+
@pytest.mark.parametrize(
11+
"create_connection_data",
12+
[
13+
{
14+
"type": "postgres",
15+
"host": "localhost",
16+
"port": 5432,
17+
},
18+
],
19+
indirect=True,
20+
)
21+
@pytest.mark.parametrize(
22+
"target_source_params",
23+
[
24+
{
25+
"type": "postgres",
26+
"table_name": "table",
27+
},
28+
],
29+
)
30+
async def test_cannot_create_db_transfer_with_short_table_name(
31+
client: AsyncClient,
32+
two_group_connections: tuple[MockConnection, MockConnection],
33+
group_queue: Queue,
34+
mock_group: MockGroup,
35+
target_source_params: dict,
36+
create_connection_data: dict,
37+
):
38+
first_connection, second_connection = two_group_connections
39+
user = mock_group.get_member_of_role(UserTestRoles.Developer)
40+
41+
response = await client.post(
42+
"v1/transfers",
43+
headers={"Authorization": f"Bearer {user.token}"},
44+
json={
45+
"group_id": mock_group.group.id,
46+
"name": "new test transfer",
47+
"source_connection_id": first_connection.id,
48+
"target_connection_id": second_connection.id,
49+
"source_params": target_source_params,
50+
"target_params": target_source_params,
51+
"queue_id": group_queue.id,
52+
},
53+
)
54+
55+
assert response.status_code == 422, response.text
56+
assert response.json() == {
57+
"error": {
58+
"code": "invalid_request",
59+
"message": "Invalid request",
60+
"details": [
61+
{
62+
"context": {},
63+
"input": "table",
64+
"location": ["body", "source_params", "postgres", "table_name"],
65+
"message": "Value error, Table name should be in format myschema.mytable",
66+
"code": "value_error",
67+
},
68+
{
69+
"context": {},
70+
"input": "table",
71+
"location": ["body", "target_params", "postgres", "table_name"],
72+
"message": "Value error, Table name should be in format myschema.mytable",
73+
"code": "value_error",
74+
},
75+
],
76+
},
77+
}

tests/test_unit/test_transfers/test_file_transfers/test_create_transfer.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -400,37 +400,6 @@ async def test_developer_plus_can_create_hdfs_transfer(
400400
"type": "csv",
401401
},
402402
},
403-
{
404-
"type": "s3",
405-
"directory_path": "some/path",
406-
"file_format": {
407-
"type": "excel",
408-
"include_header": True,
409-
},
410-
},
411-
{
412-
"type": "s3",
413-
"directory_path": "some/path",
414-
"file_format": {
415-
"type": "xml",
416-
"root_tag": "data",
417-
"row_tag": "record",
418-
},
419-
},
420-
{
421-
"type": "s3",
422-
"directory_path": "some/path",
423-
"file_format": {
424-
"type": "orc",
425-
},
426-
},
427-
{
428-
"type": "s3",
429-
"directory_path": "some/path",
430-
"file_format": {
431-
"type": "parquet",
432-
},
433-
},
434403
],
435404
)
436405
async def test_cannot_create_file_transfer_with_relative_path(

tests/test_unit/test_transfers/test_file_transfers/test_update_transfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@
143143
],
144144
indirect=["create_connection_data"],
145145
)
146-
async def test_developer_plus_can_update_s3_transfer(
146+
async def test_developer_plus_can_update_ftp_transfer(
147147
client: AsyncClient,
148148
group_transfer: MockTransfer,
149149
role_developer_plus: UserTestRoles,

0 commit comments

Comments
 (0)